This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -5,8 +5,8 @@ from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .normalization import NormType
from .resnet import ResnetBlock
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in MLX conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def __call__(self, x: mx.array) -> mx.array:
"""
@@ -116,10 +118,14 @@ def build_downsampling_path(
)
block_in = block_out
if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
stage["attn"][i_block] = make_attn(
block_in, attn_type=attn_type, norm_type=norm_type
)
if i_level != num_resolutions - 1:
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
stage["downsample"] = Downsample(
block_in, resamp_with_conv, causality_axis=causality_axis
)
curr_res = curr_res // 2
down_modules[i_level] = stage