"""3D VAE Decoder for Wan2.1/2.2 (compression 4×8×8). Module structure mirrors original PyTorch checkpoint key hierarchy so weights load directly without key sanitization. """ import mlx.core as mx import mlx.nn as nn import numpy as np CACHE_T = 2 # Per-channel normalization statistics for z_dim=16 VAE_MEAN = [ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921, ] VAE_STD = [ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160, ] class CausalConv3d(nn.Module): """3D convolution with causal temporal padding.""" def __init__( self, in_channels: int, out_channels: int, kernel_size: int | tuple, stride: int | tuple = 1, padding: int | tuple = 0, ): super().__init__() if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(stride, int): stride = (stride, stride, stride) if isinstance(padding, int): padding = (padding, padding, padding) self.kernel_size = kernel_size self.stride = stride # Causal padding: match reference formula dilation*(k-1) + (1-stride) # With dilation=1: k-stride (pads left only, no future context) self._causal_pad_t = kernel_size[0] - stride[0] self._pad_h = padding[1] self._pad_w = padding[2] # MLX Conv3d: weight shape [O, D, H, W, I] self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)) self.bias = mx.zeros((out_channels,)) def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array: """x: [B, C, T, H, W] (channel-first)""" b, c, t, h, w = x.shape causal_pad = self._causal_pad_t if cache_x is not None and causal_pad > 0: x = mx.concatenate([cache_x, x], axis=2) causal_pad = max(0, causal_pad - cache_x.shape[2]) if causal_pad > 0: pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype) x = mx.concatenate([pad_t, x], axis=2) if self._pad_h > 0 or self._pad_w > 0: x = mx.pad(x, [(0, 0), (0, 0), (0, 0), (self._pad_h, self._pad_h), (self._pad_w, self._pad_w)]) x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C] out = self._conv3d(x) return out.transpose(0, 4, 1, 2, 3) # [B, O, T', H', W'] def _conv3d(self, x: mx.array) -> mx.array: """3D conv via sliding window + 2D conv per time step. x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out] """ b, t, h, w, c_in = x.shape kt, kh, kw = self.kernel_size st, sh, sw = self.stride t_out = (t - kt) // st + 1 # Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I] w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape( self.weight.shape[0], kh, kw, kt * c_in ) outputs = [] for t_i in range(t_out): t_start = t_i * st window = x[:, t_start : t_start + kt] window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in) out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias outputs.append(out_2d) return mx.stack(outputs, axis=1) class RMS_norm(nn.Module): """Channel-first L2 normalization matching original Wan VAE. Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm. images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input. images=False: gamma shape (dim, 1, 1, 1) for 5D video input. """ def __init__(self, dim: int, channel_first: bool = True, images: bool = True): super().__init__() self.channel_first = channel_first self.scale = dim**0.5 if channel_first: broadcastable = (1, 1) if images else (1, 1, 1) self.gamma = mx.ones((dim, *broadcastable)) else: self.gamma = mx.ones((dim,)) def __call__(self, x: mx.array) -> mx.array: norm_dim = 1 if self.channel_first else -1 # L2 normalize along channel dim (matches F.normalize) norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None)) return (x / norm) * self.scale * self.gamma class ResidualBlock(nn.Module): """Residual block with causal 3D convolutions. Uses `residual` list with None gaps to match original PyTorch nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm, [4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params. """ def __init__(self, in_dim: int, out_dim: int): super().__init__() self.residual = [ RMS_norm(in_dim, images=False), # [0] None, # [1] SiLU CausalConv3d(in_dim, out_dim, 3, padding=1), # [2] RMS_norm(out_dim, images=False), # [3] None, # [4] SiLU None, # [5] Dropout CausalConv3d(out_dim, out_dim, 3, padding=1), # [6] ] self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array: h = x if self.shortcut is None else self.shortcut(x) if feat_cache is not None: # First conv: norm -> silu -> [cache] -> conv x = nn.silu(self.residual[0](x)) idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:] if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2) x = self.residual[2](x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 # Second conv: norm -> silu -> [cache] -> conv x = nn.silu(self.residual[3](x)) idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:] if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2) x = self.residual[6](x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = nn.silu(self.residual[0](x)) x = self.residual[2](x) x = nn.silu(self.residual[3](x)) x = self.residual[6](x) return x + h class AttentionBlock(nn.Module): """Single-head spatial self-attention.""" def __init__(self, dim: int): super().__init__() self.norm = RMS_norm(dim, images=True) self.to_qkv = nn.Conv2d(dim, dim * 3, 1) self.proj = nn.Conv2d(dim, dim, 1) def __call__(self, x: mx.array) -> mx.array: """x: [B, C, T, H, W]""" identity = x b, c, t, h, w = x.shape # [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C] x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w) x = self.norm(x) x = x.transpose(0, 2, 3, 1) # [BT, H, W, C] qkv = self.to_qkv(x) # [BT, H, W, 3C] qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3) q, k, v = qkv[0], qkv[1], qkv[2] q = q[:, None, :, :] # [BT, 1, HW, C] k = k[:, None, :, :] v = v[:, None, :, :] out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5) out = out.squeeze(1).reshape(b * t, h, w, c) # [BT, H, W, C] out = self.proj(out) # [BT, H, W, C] out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3) # [B, C, T, H, W] return out + identity class Resample(nn.Module): """Resample block matching original Wan VAE structure. Supports both upsampling (decoder) and downsampling (encoder). Uses list-based param storage to match original nn.Sequential key hierarchy. """ def __init__(self, dim: int, mode: str): super().__init__() assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d") self.mode = mode self.dim = dim if mode.startswith("upsample"): # resample.0 = Upsample (no params), resample.1 = Conv2d self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)] if mode == "upsample3d": self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) else: # resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2) self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)] if mode == "downsample3d": self.time_conv = CausalConv3d( dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array: """x: [B, C, T, H, W]""" b, c, t, h, w = x.shape if self.mode == "upsample3d": # Temporal upsample via learned conv x_t = self.time_conv(x) # [B, 2C, T, H, W] x_t = x_t.reshape(b, 2, c, t, h, w) x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w) t = t * 2 if self.mode.startswith("upsample"): # Per-frame spatial upsample: nearest 2x + Conv2d x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C] x = mx.repeat(x, 2, axis=1) x = mx.repeat(x, 2, axis=2) x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2] c_out = x.shape[-1] return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3) else: # Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2) x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C] x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1) x = self.resample[1](x) # Conv2d stride=2 c_out = x.shape[-1] h_out, w_out = x.shape[1], x.shape[2] x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3) if self.mode == "downsample3d": if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: # First chunk: save x, skip time_conv feat_cache[idx] = x feat_idx[0] += 1 else: # Subsequent chunks: use cached frame as temporal context cache_x = x[:, :, -1:] x = self.time_conv( x, cache_x=feat_cache[idx][:, :, -1:]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = self.time_conv(x) return x class Decoder3d(nn.Module): """3D VAE Decoder matching Wan2.1 architecture. Uses flat `middle` and `upsamples` lists to match original PyTorch nn.Sequential weight key hierarchy. """ def __init__( self, dim: int = 96, z_dim: int = 16, dim_mult: list = None, num_res_blocks: int = 2, temporal_upsample: list = None, ): super().__init__() if dim_mult is None: dim_mult = [1, 2, 4, 4] if temporal_upsample is None: temporal_upsample = [True, True, False] dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) # Middle: [ResBlock, AttentionBlock, ResBlock] self.middle = [ ResidualBlock(dims[0], dims[0]), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0]), ] # Flat upsample list matching original nn.Sequential indexing upsamples = [] for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): if i in (1, 2, 3): in_dim = in_dim // 2 for _ in range(num_res_blocks + 1): upsamples.append(ResidualBlock(in_dim, out_dim)) in_dim = out_dim if i != len(dim_mult) - 1: mode = "upsample3d" if temporal_upsample[i] else "upsample2d" upsamples.append(Resample(out_dim, mode=mode)) self.upsamples = upsamples # Output head: [RMS_norm, SiLU (no params), CausalConv3d] self.head = [ RMS_norm(dims[-1], images=False), # [0] None, # [1] SiLU CausalConv3d(dims[-1], 3, 3, padding=1), # [2] ] def __call__(self, x: mx.array) -> mx.array: """x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]""" x = self.conv1(x) for layer in self.middle: x = layer(x) for layer in self.upsamples: x = layer(x) x = nn.silu(self.head[0](x)) x = self.head[2](x) return x class Encoder3d(nn.Module): """3D VAE Encoder matching Wan2.1 architecture. Mirror of Decoder3d with downsampling instead of upsampling. Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy. """ def __init__( self, dim: int = 96, z_dim: int = 16, dim_mult: list = None, num_res_blocks: int = 2, temporal_downsample: list = None, ): super().__init__() if dim_mult is None: dim_mult = [1, 2, 4, 4] if temporal_downsample is None: temporal_downsample = [False, True, True] dims = [dim * u for u in [1] + dim_mult] self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) # Flat downsample list matching original nn.Sequential indexing downsamples = [] for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): for _ in range(num_res_blocks): downsamples.append(ResidualBlock(in_dim, out_dim)) in_dim = out_dim if i != len(dim_mult) - 1: mode = "downsample3d" if temporal_downsample[i] else "downsample2d" downsamples.append(Resample(out_dim, mode=mode)) self.downsamples = downsamples # Middle: [ResBlock, AttentionBlock, ResBlock] self.middle = [ ResidualBlock(dims[-1], dims[-1]), AttentionBlock(dims[-1]), ResidualBlock(dims[-1], dims[-1]), ] # Output head: [RMS_norm, SiLU (no params), CausalConv3d] self.head = [ RMS_norm(dims[-1], images=False), None, # SiLU CausalConv3d(dims[-1], z_dim, 3, padding=1), ] def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array: """x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]""" if feat_cache is not None: # conv1 with caching idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:] if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: cache_x = mx.concatenate( [feat_cache[idx][:, :, -1:], cache_x], axis=2) x = self.conv1(x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = self.conv1(x) for layer in self.downsamples: if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)): x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) else: x = layer(x) for layer in self.middle: if feat_cache is not None and isinstance(layer, ResidualBlock): x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) else: x = layer(x) if feat_cache is not None: # Head: norm -> silu -> [cache] -> conv x = nn.silu(self.head[0](x)) idx = feat_idx[0] cache_x = x[:, :, -CACHE_T:] if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None: cache_x = mx.concatenate( [feat_cache[idx][:, :, -1:], cache_x], axis=2) x = self.head[2](x, cache_x=feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 else: x = nn.silu(self.head[0](x)) x = self.head[2](x) return x class WanVAE(nn.Module): """Wan2.1 VAE wrapper with per-channel normalization. Supports both encode (for I2V) and decode (for all models). """ def __init__(self, z_dim: int = 16, encoder: bool = False): super().__init__() self.z_dim = z_dim self.mean = mx.array(VAE_MEAN) self.std = mx.array(VAE_STD) self.inv_std = 1.0 / self.std self.conv2 = CausalConv3d(z_dim, z_dim, 1) self.decoder = Decoder3d(dim=96, z_dim=z_dim) if encoder: self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) def encode(self, x: mx.array) -> mx.array: """Encode video to normalized latent using chunked encoding. Uses chunked encoding with temporal caching to match reference behavior. First frame encoded alone, then 4-frame chunks with cached context. Args: x: Video [B, 3, T, H, W] in [-1, 1] Returns: Normalized latent [B, z_dim, T_lat, H_lat, W_lat] """ # Count cacheable CausalConv3d slots in encoder num_slots = self._count_encoder_cache_slots() feat_cache = [None] * num_slots t = x.shape[2] num_chunks = 1 + (t - 1) // 4 out = None for i in range(num_chunks): feat_idx = [0] if i == 0: chunk = x[:, :, :1] else: chunk = x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i] chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx) if out is None: out = chunk_out else: out = mx.concatenate([out, chunk_out], axis=2) mu, _ = mx.split(self.conv1(out), 2, axis=1) # Normalize: (mu - mean) * inv_std mean = self.mean.reshape(1, -1, 1, 1, 1) inv_std = self.inv_std.reshape(1, -1, 1, 1, 1) return (mu - mean) * inv_std def _count_encoder_cache_slots(self) -> int: """Count CausalConv3d that participate in chunked encoding cache.""" count = 1 # encoder.conv1 for layer in self.encoder.downsamples: if isinstance(layer, ResidualBlock): count += 2 # two convs in residual path elif isinstance(layer, Resample) and layer.mode == "downsample3d": count += 1 # time_conv for layer in self.encoder.middle: if isinstance(layer, ResidualBlock): count += 2 count += 1 # encoder.head CausalConv3d return count def decode(self, z: mx.array) -> mx.array: """Decode latent to video. Args: z: Normalized latent [B, z_dim, T, H, W] Returns: Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1] """ mean = self.mean.reshape(1, -1, 1, 1, 1) inv_std = self.inv_std.reshape(1, -1, 1, 1, 1) z = z / inv_std + mean x = self.conv2(z) out = self.decoder(x) return mx.clip(out, -1, 1) def decode_tiled(self, z: mx.array, tiling_config=None) -> mx.array: """Decode latent to video using tiling to reduce memory usage. Splits the latent tensor into overlapping spatial/temporal tiles, decodes each tile independently, and blends them with trapezoidal masks. Reuses the LTX-2 tiling infrastructure. Args: z: Normalized latent [B, z_dim, T, H, W] tiling_config: Optional TilingConfig. If None, uses default. Returns: Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1] """ from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling if tiling_config is None: tiling_config = TilingConfig.default() # Check if tiling is actually needed _, _, f, h, w = z.shape needs_tiling = False if tiling_config.spatial_config is not None: s_tile = tiling_config.spatial_config.tile_size_in_pixels // 8 if h > s_tile or w > s_tile: needs_tiling = True if tiling_config.temporal_config is not None: t_tile = tiling_config.temporal_config.tile_size_in_frames // 4 if f > t_tile: needs_tiling = True if not needs_tiling: return self.decode(z) # Denormalize once (small tensor), then tile the denormalized latents mean = self.mean.reshape(1, -1, 1, 1, 1) inv_std = self.inv_std.reshape(1, -1, 1, 1, 1) z_denorm = z / inv_std + mean def tile_decode(tile_latents, **kwargs): x = self.conv2(tile_latents) out = self.decoder(x) return mx.clip(out, -1, 1) return decode_with_tiling( decoder_fn=tile_decode, latents=z_denorm, tiling_config=tiling_config, spatial_scale=8, # 3× spatial 2× upsamples = 8× temporal_scale=4, # 2× temporal upsamples × 2 = 4× causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T) )