format
This commit is contained in:
@@ -5,8 +5,8 @@ from typing import Set, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in MLX conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -116,10 +118,14 @@ def build_downsampling_path(
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
stage["attn"][i_block] = make_attn(
|
||||
block_in, attn_type=attn_type, norm_type=norm_type
|
||||
)
|
||||
|
||||
if i_level != num_resolutions - 1:
|
||||
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
stage["downsample"] = Downsample(
|
||||
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||
)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
down_modules[i_level] = stage
|
||||
|
||||
Reference in New Issue
Block a user