format
This commit is contained in:
@@ -5,9 +5,9 @@ from typing import Set, Tuple
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..config import CausalityAxis
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
@@ -124,10 +128,14 @@ def build_upsampling_path(
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
stage["attn"][i_block] = make_attn(
|
||||
block_in, attn_type=attn_type, norm_type=norm_type
|
||||
)
|
||||
|
||||
if level != 0:
|
||||
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
stage["upsample"] = Upsample(
|
||||
block_in, resamp_with_conv, causality_axis=causality_axis
|
||||
)
|
||||
curr_res *= 2
|
||||
|
||||
up_modules[level] = stage
|
||||
|
||||
Reference in New Issue
Block a user