This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -5,9 +5,9 @@ from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
self.causality_axis = causality_axis
if self.with_conv:
self.conv = make_conv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
in_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
def __call__(self, x: mx.array) -> mx.array:
@@ -124,10 +128,14 @@ def build_upsampling_path(
)
block_in = block_out
if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
stage["attn"][i_block] = make_attn(
block_in, attn_type=attn_type, norm_type=norm_type
)
if level != 0:
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
stage["upsample"] = Upsample(
block_in, resamp_with_conv, causality_axis=causality_axis
)
curr_res *= 2
up_modules[level] = stage