This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -8,7 +8,6 @@ conversion (channels-first → channels-last) is needed.
"""
import logging
import math
import mlx.core as mx
import mlx.nn as nn
@@ -19,23 +18,111 @@ logger = logging.getLogger(__name__)
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
VAE22_MEAN = mx.array([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
])
VAE22_MEAN = mx.array(
[
-0.2289,
-0.0052,
-0.1323,
-0.2339,
-0.2799,
0.0174,
0.1838,
0.1557,
-0.1382,
0.0542,
0.2813,
0.0891,
0.1570,
-0.0098,
0.0375,
-0.1825,
-0.2246,
-0.1207,
-0.0698,
0.5109,
0.2665,
-0.2108,
-0.2158,
0.2502,
-0.2055,
-0.0322,
0.1109,
0.1567,
-0.0729,
0.0899,
-0.2799,
-0.1230,
-0.0313,
-0.1649,
0.0117,
0.0723,
-0.2839,
-0.2083,
-0.0520,
0.3748,
0.0152,
0.1957,
0.1433,
-0.2944,
0.3573,
-0.0548,
-0.1681,
-0.0667,
]
)
VAE22_STD = mx.array([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
])
VAE22_STD = mx.array(
[
0.4765,
1.0364,
0.4514,
1.1677,
0.5313,
0.4990,
0.4818,
0.5013,
0.8158,
1.0344,
0.5894,
1.0901,
0.6885,
0.6165,
0.8454,
0.4978,
0.5759,
0.3523,
0.7135,
0.6804,
0.5833,
1.4146,
0.8986,
0.5659,
0.7069,
0.5338,
0.4889,
0.4917,
0.4069,
0.4999,
0.6866,
0.4093,
0.5709,
0.6065,
0.6415,
0.4944,
0.5726,
1.2042,
0.5458,
1.6887,
0.3971,
1.0600,
0.3943,
0.5537,
0.5444,
0.4089,
0.7468,
0.7744,
]
)
class CausalConv3d(nn.Module):
@@ -65,9 +152,9 @@ class CausalConv3d(nn.Module):
self._pad_w = padding[2]
# Weight: [O, D, H, W, I] for MLX
self.weight = mx.zeros((
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
))
self.weight = mx.zeros(
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
)
self.bias = mx.zeros((out_channels,))
def __call__(self, x, cache_x=None):
@@ -96,8 +183,16 @@ class CausalConv3d(nn.Module):
# Spatial padding
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
(self._pad_w, self._pad_w), (0, 0)])
x = mx.pad(
x,
[
(0, 0),
(0, 0),
(self._pad_h, self._pad_h),
(self._pad_w, self._pad_w),
(0, 0),
],
)
T_padded = x.shape[1]
H_padded, W_padded = x.shape[2], x.shape[3]
@@ -113,8 +208,9 @@ class CausalConv3d(nn.Module):
for d in range(kd):
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
conv_out = mx.conv_general(frame, w2d,
stride=(self.stride[1], self.stride[2]))
conv_out = mx.conv_general(
frame, w2d, stride=(self.stride[1], self.stride[2])
)
accum = conv_out if accum is None else accum + conv_out
outputs.append(accum + self.bias)
@@ -126,7 +222,7 @@ class RMS_norm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.scale = dim**0.5
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
self.gamma = mx.ones((dim,))
@@ -134,7 +230,9 @@ class RMS_norm(nn.Module):
# x: [..., C] (channels-last)
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
return (
x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
)
class ResidualBlock(nn.Module):
@@ -145,11 +243,7 @@ class ResidualBlock(nn.Module):
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
# We store as named layers matching PyTorch's indices
self.residual = ResidualBlockLayers(in_dim, out_dim)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim
else None
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
def __call__(self, x, feat_cache=None, feat_idx=None):
h = self.shortcut(x) if self.shortcut is not None else x
@@ -182,9 +276,7 @@ class ResidualBlockLayers(nn.Module):
# Save last CACHE_T frames before conv (for next chunk's context)
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
out = conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -231,7 +323,9 @@ class AttentionBlock(nn.Module):
x = self.norm(x)
# QKV via 1x1 conv2d (equivalent to linear on last dim)
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C]
qkv = (
mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias
) # [BT, H, W, 3C]
qkv = qkv.reshape(B * T, H * W, 3 * C)
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
@@ -240,8 +334,10 @@ class AttentionBlock(nn.Module):
k = k[:, None, :, :]
v = v[:, None, :, :]
scale = C ** -0.5
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C]
scale = C**-0.5
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale
) # [BT, 1, HW, C]
out = out.squeeze(1).reshape(B * T, H, W, C)
# Project output
@@ -270,16 +366,24 @@ class DupUp3D(nn.Module):
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
x = x.reshape(
B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s
)
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
x = x.reshape(
B,
T * self.factor_t,
H * self.factor_s,
W * self.factor_s,
self.out_channels,
)
if first_chunk:
x = x[:, self.factor_t - 1:, :, :, :]
x = x[:, self.factor_t - 1 :, :, :, :]
return x
@@ -348,7 +452,9 @@ class Resample(nn.Module):
self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else:
raise ValueError(f"Unsupported mode: {mode}")
@@ -369,7 +475,9 @@ class Resample(nn.Module):
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
return (
mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
)
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
# x: [B, T, H, W, C]
@@ -444,14 +552,17 @@ class Resample(nn.Module):
class Up_ResidualBlock(nn.Module):
"""Upsampling residual block with optional DupUp3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
def __init__(
self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False
):
super().__init__()
self.up_flag = up_flag
# DupUp3D shortcut (no learnable params)
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim, out_dim,
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
@@ -490,13 +601,21 @@ class Up_ResidualBlock(nn.Module):
class Down_ResidualBlock(nn.Module):
"""Downsampling residual block with AvgDown3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
def __init__(
self,
in_dim,
out_dim,
num_res_blocks,
temperal_downsample=False,
down_flag=False,
):
super().__init__()
self.down_flag = down_flag
# AvgDown3D shortcut (no learnable params, always present)
self.avg_shortcut = AvgDown3D(
in_dim, out_dim,
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
@@ -562,13 +681,15 @@ class Decoder3d(nn.Module):
self.upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
self.upsamples.append(Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks + 1,
temperal_upsample=t_up,
up_flag=(i != len(dim_mult) - 1),
))
self.upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks + 1,
temperal_upsample=t_up,
up_flag=(i != len(dim_mult) - 1),
)
)
# Output head: [RMS_norm, SiLU, CausalConv3d]
self.head = Head22(dims[-1])
@@ -612,13 +733,15 @@ class Encoder3d(nn.Module):
for i in range(len(dim_mult)):
in_d, out_d = dims[i], dims[i + 1]
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
self.downsamples.append(Down_ResidualBlock(
in_dim=in_d,
out_dim=out_d,
num_res_blocks=num_res_blocks,
temperal_downsample=t_down,
down_flag=(i < len(dim_mult) - 1),
))
self.downsamples.append(
Down_ResidualBlock(
in_dim=in_d,
out_dim=out_d,
num_res_blocks=num_res_blocks,
temperal_downsample=t_down,
down_flag=(i < len(dim_mult) - 1),
)
)
# Middle blocks (same as decoder)
out_dim = dims[-1]
@@ -658,9 +781,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -700,9 +821,7 @@ class Head22(nn.Module):
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
x = self.layer_2(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -768,7 +887,7 @@ class Wan22VAEEncoder(nn.Module):
if i == 0:
chunk = x[:, :1]
else:
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None:
out = chunk_out
@@ -778,7 +897,7 @@ class Wan22VAEEncoder(nn.Module):
# conv1 (pointwise) + split into mu, log_var
out = self.conv1(out)
mu = out[:, :, :, :, :self.z_dim]
mu = out[:, :, :, :, : self.z_dim]
# Normalize
mu = normalize_latents(mu)
@@ -885,8 +1004,8 @@ class Wan22VAEDecoder(nn.Module):
decoder_fn=tile_decode,
latents=z_cf,
tiling_config=tiling_config,
spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
causal_temporal=True,
)