"""Wan2.2 VAE Decoder (compression 4×16×16, z_dim=48).

Architecture differs from Wan2.1 VAE: uses RMS_norm, DupUp3D shortcuts,
spatial patchify (2×2), and a different temporal upsampling pattern.

Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed.
"""

import math

import mlx.core as mx
import mlx.nn as nn
import numpy as np

CACHE_T = 2

# Per-channel normalization for z_dim=48 latent space
VAE22_MEAN = mx.array([
    -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
    -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
    -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
    -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
    -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
    0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
])

VAE22_STD = mx.array([
    0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
    0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
    0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
    0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
    0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
    0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
])


class CausalConv3d(nn.Module):
    """3D causal convolution. Input/output: [B, T, H, W, C] (channels-last).

    Decomposes the 3D conv into per-frame 2D convolutions to avoid excessive
    memory usage from MLX's conv3d implementation.

    Temporal padding is always ``kernel_size[0] - 1`` zero frames on the left
    (``padding[0]`` is ignored). For every kernel-3 / padding-1 conv in this
    file that equals the PyTorch reference's ``2 * padding[0]``. The strided
    downsample ``time_conv`` (built with padding 0) gets two extra leading zero
    frames; ``Resample`` discards the affected first output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        self.kernel_size = kernel_size
        self.stride = stride
        self._causal_pad_t = kernel_size[0] - 1
        self._pad_h = padding[1]
        self._pad_w = padding[2]
        # Weight: [O, D, H, W, I] for MLX
        self.weight = mx.zeros((
            out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
        ))
        self.bias = mx.zeros((out_channels,))

    def __call__(self, x):
        # x: [B, T, H, W, C]
        B, T, H, W, C = x.shape
        kd, kh, kw = self.kernel_size

        # Pointwise 1x1x1 kernel: a single 2D conv over all frames at once.
        if kd == 1 and kh == 1 and kw == 1:
            x_flat = x.reshape(B * T, H, W, C)
            w2d = self.weight[:, 0, :, :, :]  # [O, 1, 1, I]
            y = mx.conv_general(x_flat, w2d) + self.bias
            return y.reshape(B, T, y.shape[1], y.shape[2], -1)

        # Causal (left-only) temporal padding plus symmetric spatial padding.
        # mx.pad keeps x's dtype; concatenating a default-float32 zero block
        # would silently promote half-precision activations to float32.
        pad_t, pad_h, pad_w = self._causal_pad_t, self._pad_h, self._pad_w
        if pad_t > 0 or pad_h > 0 or pad_w > 0:
            x = mx.pad(x, [(0, 0), (pad_t, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)])

        st, sh, sw = self.stride
        T_out = (x.shape[1] - kd) // st + 1

        # Split the [O, kd, kh, kw, I] weight into kd 2D kernels [O, kh, kw, I].
        kernels = [self.weight[:, d, :, :, :] for d in range(kd)]

        outputs = []
        for t in range(T_out):
            t_start = t * st
            # Sum the 2D convs for each temporal kernel position.
            accum = None
            for d, w2d in enumerate(kernels):
                conv_out = mx.conv_general(x[:, t_start + d], w2d, stride=(sh, sw))
                accum = conv_out if accum is None else accum + conv_out
            outputs.append(accum + self.bias)
        return mx.stack(outputs, axis=1)  # [B, T_out, H_out, W_out, O]


class RMS_norm(nn.Module):
    """Scaled L2 normalization along the channel (last) dimension.

    Mirrors the PyTorch module, which uses F.normalize rather than a true RMS:
    ``x / max(||x||_2, 1e-12) * sqrt(dim) * gamma``. The name is kept for
    weight-key compatibility.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        # Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
        self.gamma = mx.ones((dim,))

    def __call__(self, x):
        # x: [..., C] (channels-last)
        # Accumulate in float32 so fp16 inputs neither overflow the sum of
        # squares nor flush eps**2 (1e-24) to zero; float32 inputs are unchanged.
        x32 = x if x.dtype == mx.float32 else x.astype(mx.float32)
        l2_sq = mx.sum(x32 * x32, axis=-1, keepdims=True)
        out = x32 * mx.rsqrt(mx.maximum(l2_sq, 1e-24)) * self.scale * self.gamma
        return out.astype(x.dtype)


class ResidualBlock(nn.Module):
    """Residual block: RMS_norm → SiLU → CausalConv3d × 2 + shortcut."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
        # stored as named layers matching PyTorch's nn.Sequential indices.
        self.residual = ResidualBlockLayers(in_dim, out_dim)
        # 1x1x1 projection only when the channel count changes.
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None

    def __call__(self, x):
        h = self.shortcut(x) if self.shortcut is not None else x
        return self.residual(x) + h


class ResidualBlockLayers(nn.Module):
    """The sequential layers inside a ResidualBlock.

    PyTorch stores these as nn.Sequential with indices 0-6:
      [0] RMS_norm, [1] SiLU, [2] CausalConv3d, [3] RMS_norm,
      [4] SiLU, [5] Dropout, [6] CausalConv3d
    Only the indices with parameters get attributes (layer_0/2/3/6), so
    checkpoint keys map one-to-one.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layer_0 = RMS_norm(in_dim)
        self.layer_2 = CausalConv3d(in_dim, out_dim, 3, padding=1)
        self.layer_3 = RMS_norm(out_dim)
        self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)

    def __call__(self, x):
        x = self.layer_2(nn.silu(self.layer_0(x)))
        mx.eval(x)  # Eval between convolutions to limit graph size
        # Dropout (index 5) is a no-op at inference time.
        return self.layer_6(nn.silu(self.layer_3(x)))


class AttentionBlock(nn.Module):
    """Single-head 2D self-attention applied per frame. Input: [B, T, H, W, C]."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = RMS_norm(dim)
        # 1x1 Conv2d weights in MLX layout [O, 1, 1, I]:
        # to_qkv: dim -> 3*dim, proj: dim -> dim
        self.to_qkv_weight = mx.zeros((3 * dim, 1, 1, dim))
        self.to_qkv_bias = mx.zeros((3 * dim,))
        self.proj_weight = mx.zeros((dim, 1, 1, dim))
        self.proj_bias = mx.zeros((dim,))

    def __call__(self, x):
        identity = x
        B, T, H, W, C = x.shape
        # Attention is per frame: fold T into the batch axis.
        x = self.norm(x.reshape(B * T, H, W, C))

        # QKV via 1x1 conv2d (equivalent to a linear layer on the last dim).
        qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias  # [BT, H, W, 3C]
        qkv = qkv.reshape(B * T, H * W, 3 * C)
        q, k, v = mx.split(qkv, 3, axis=-1)  # each [BT, HW, C]

        # Add a singleton head axis: [BT, 1, HW, C]
        out = mx.fast.scaled_dot_product_attention(
            q[:, None], k[:, None], v[:, None], scale=C ** -0.5
        )
        out = out.squeeze(1).reshape(B * T, H, W, C)

        out = mx.conv_general(out, self.proj_weight) + self.proj_bias
        return out.reshape(B, T, H, W, C) + identity


class DupUp3D(nn.Module):
    """Upsample by duplicating channels and reshaping. No learnable parameters."""

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s
        self.repeats = out_channels * self.factor // in_channels

    def __call__(self, x, first_chunk=False):
        B, T, H, W, C = x.shape
        ft, fs, oc = self.factor_t, self.factor_s, self.out_channels
        # Repeat each channel consecutively (like torch.repeat_interleave).
        x = mx.repeat(x, self.repeats, axis=-1)  # [B, T, H, W, C*repeats]
        x = x.reshape(B, T, H, W, oc, ft, fs, fs)
        # Interleave factors next to their axes: [B, T, ft, H, fs, W, fs, oc]
        x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
        x = x.reshape(B, T * ft, H * fs, W * fs, oc)
        if first_chunk:
            # The first latent frame only expands to one output frame.
            x = x[:, ft - 1:, :, :, :]
        return x


class AvgDown3D(nn.Module):
    """Downsample by grouping channels across spatial/temporal factors and averaging.

    Inverse of DupUp3D. No learnable parameters.
    Input: [B, T, H, W, C_in] → Output: [B, ceil(T/ft), H//fs, W//fs, C_out]
    """

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s
        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def __call__(self, x):
        B, T, H, W, C = x.shape
        ft, fs = self.factor_t, self.factor_s

        # Left-pad time with zeros until divisible by factor_t.
        pad_t = (ft - T % ft) % ft
        if pad_t > 0:
            x = mx.pad(x, [(0, 0), (pad_t, 0), (0, 0), (0, 0), (0, 0)])
            T += pad_t

        T2, H2, W2 = T // ft, H // fs, W // fs
        x = x.reshape(B, T2, ft, H2, fs, W2, fs, C)
        # Move factors next to channels: [B, T', H', W', C, ft, fs, fs]
        x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6)
        x = x.reshape(B, T2, H2, W2, self.out_channels, self.group_size)
        return x.mean(axis=-1)


class Resample(nn.Module):
    """Spatial 2× up/downsampling with optional temporal 2× up/downsampling.

    Input/output: [B, T, H, W, C] channels-last; the channel count is preserved.
    """

    def __init__(self, dim, mode):
        super().__init__()
        self.dim = dim
        self.mode = mode
        if mode not in ("upsample2d", "upsample3d", "downsample2d", "downsample3d"):
            raise ValueError(f"Unsupported mode: {mode}")
        # PyTorch resample.1 is a 3x3 Conv2d (dim -> dim), stored [O, H, W, I].
        # Upsample modes pad by 1; downsample modes use ZeroPad2d + stride 2.
        self.resample_weight = mx.zeros((dim, 3, 3, dim))
        self.resample_bias = mx.zeros((dim,))
        if mode == "upsample3d":
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample3d":
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
            )

    def _upsample2x(self, x):
        """Nearest-neighbor 2x spatial upsample. x: [N, H, W, C]."""
        x = mx.repeat(x, repeats=2, axis=1)  # [N, 2H, W, C]
        return mx.repeat(x, repeats=2, axis=2)  # [N, 2H, 2W, C]

    def _conv2d(self, x):
        """Apply the Conv2d with padding=1. x: [N, H, W, C]."""
        x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
        return mx.conv_general(x, self.resample_weight) + self.resample_bias

    def _downsample_conv2d(self, x):
        """Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
        # ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
        x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
        return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias

    def _time_conv_interleave(self, x):
        """Double the frame count with time_conv: [B, T, H, W, C] → [B, 2T, H, W, C]."""
        B, T, H, W, C = x.shape
        # Output channels are laid out as (2, C): two C-channel streams.
        tc_out = self.time_conv(x).reshape(B, T, H, W, 2, C)
        # Interleave the streams in time: t0s0, t0s1, t1s0, t1s1, ...
        x = mx.stack([tc_out[:, :, :, :, 0, :], tc_out[:, :, :, :, 1, :]], axis=2)
        return x.reshape(B, T * 2, H, W, C)

    def _temporal_upsample(self, x, first_chunk):
        """Temporal part of upsample3d, matching the chunked reference decode."""
        if not first_chunk:
            return self._time_conv_interleave(x)
        # The first frame bypasses time_conv entirely (reference "Rep" cache
        # path). This also covers T == 1: the frame count must stay 1 so it
        # lines up with the DupUp3D shortcut, which trims to 1 frame.
        if x.shape[1] == 1:
            return x
        # Causal zero-padding gives rest[0] no prior context, as in the
        # reference frame-by-frame decode → 1 + 2*(T-1) = 2T-1 frames.
        rest = self._time_conv_interleave(x[:, 1:])
        return mx.concatenate([x[:, 0:1], rest], axis=1)

    def _spatial_upsample(self, x, chunk_size=8):
        """Upsample + conv per frame, in temporal chunks to limit peak memory."""
        B, T, H, W, C = x.shape
        chunks = []
        for t_start in range(0, T, chunk_size):
            x_chunk = x[:, t_start:t_start + chunk_size]
            n_t = x_chunk.shape[1]
            y = self._conv2d(self._upsample2x(x_chunk.reshape(B * n_t, H, W, C)))
            # Restore [B, t, ...] before concatenating along time so frames of
            # different batch items are never interleaved when B > 1.
            y = y.reshape(B, n_t, y.shape[1], y.shape[2], y.shape[3])
            mx.eval(y)
            chunks.append(y)
        return mx.concatenate(chunks, axis=1)

    def _spatial_downsample(self, x):
        """Per-frame strided Conv2d: [B, T, H, W, C] → [B, T, H', W', C]."""
        B, T, H, W, C = x.shape
        y = self._downsample_conv2d(x.reshape(B * T, H, W, C))
        mx.eval(y)
        return y.reshape(B, T, y.shape[1], y.shape[2], y.shape[3])

    def _temporal_downsample(self, x):
        """Strided time_conv, matching the chunked reference encode.

        The left causal padding makes output 0 time_conv(0, 0, x0) and output
        k >= 1 time_conv(x[2k-2], x[2k-1], x[2k]). The reference encoder passes
        the first frame through unchanged (its first chunk only fills the
        cache), so output 0 is replaced with x0.
        """
        tc_out = self.time_conv(x)
        return mx.concatenate([x[:, :1], tc_out[:, 1:]], axis=1)

    def __call__(self, x, first_chunk=False):
        # x: [B, T, H, W, C]
        if self.mode == "upsample3d":
            x = self._temporal_upsample(x, first_chunk)
            mx.eval(x)

        if self.mode in ("upsample2d", "upsample3d"):
            return self._spatial_upsample(x)

        # Downsample: spatial first, then temporal (reference order).
        x = self._spatial_downsample(x)
        if self.mode == "downsample3d" and x.shape[1] > 1:
            # Skip for T=1: the reference first chunk never applies time_conv.
            x = self._temporal_downsample(x)
            mx.eval(x)
        return x


class Up_ResidualBlock(nn.Module):
    """Upsampling residual block with optional DupUp3D shortcut."""

    def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
        super().__init__()
        self.up_flag = up_flag
        # DupUp3D shortcut (no learnable params), only when upsampling.
        self.avg_shortcut = (
            DupUp3D(in_dim, out_dim, factor_t=2 if temperal_upsample else 1, factor_s=2)
            if up_flag
            else None
        )

        # Main path: ResidualBlocks + optional Resample
        blocks = [ResidualBlock(in_dim, out_dim)]
        blocks += [ResidualBlock(out_dim, out_dim) for _ in range(num_res_blocks - 1)]
        if up_flag:
            mode = "upsample3d" if temperal_upsample else "upsample2d"
            blocks.append(Resample(out_dim, mode=mode))
        self.upsamples = blocks

    def __call__(self, x, first_chunk=False):
        x_main = x
        for module in self.upsamples:
            if isinstance(module, Resample):
                x_main = module(x_main, first_chunk)
            else:
                x_main = module(x_main)
            mx.eval(x_main)  # Limit graph size per sub-block

        if self.avg_shortcut is None:
            return x_main
        x_shortcut = self.avg_shortcut(x, first_chunk)
        mx.eval(x_shortcut)
        return x_main + x_shortcut


class Down_ResidualBlock(nn.Module):
    """Downsampling residual block with AvgDown3D shortcut."""

    def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
        super().__init__()
        self.down_flag = down_flag
        # AvgDown3D shortcut (no learnable params, always present)
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path: ResidualBlocks + optional Resample
        blocks = [ResidualBlock(in_dim, out_dim)]
        blocks += [ResidualBlock(out_dim, out_dim) for _ in range(num_res_blocks - 1)]
        if down_flag:
            mode = "downsample3d" if temperal_downsample else "downsample2d"
            blocks.append(Resample(out_dim, mode=mode))
        self.downsamples = blocks

    def __call__(self, x):
        x_shortcut = self.avg_shortcut(x)
        mx.eval(x_shortcut)
        for module in self.downsamples:
            x = module(x)
            mx.eval(x)
        return x + x_shortcut


class Decoder3d(nn.Module):
    """Wan2.2 3D VAE Decoder."""

    def __init__(
        self,
        dim=256,
        z_dim=48,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temperal_upsample=(True, True, False),
    ):
        super().__init__()
        # Channel widths from deepest to shallowest, e.g. [1024, 1024, 1024, 512, 256].
        dims = [dim * dim_mult[-1]] + [dim * m for m in reversed(dim_mult)]

        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        self.middle = [
            ResidualBlock(dims[0], dims[0]),
            AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0]),
        ]

        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
            upsamples.append(Up_ResidualBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks + 1,
                temperal_upsample=t_up,
                up_flag=(i != len(dim_mult) - 1),
            ))
        self.upsamples = upsamples

        # Output head: [RMS_norm, SiLU, CausalConv3d] → 12 packed channels
        self.head = Head22(dims[-1])

    def __call__(self, x, first_chunk=False):
        # x: [B, T, H, W, C=z_dim]
        x = self.conv1(x)
        for layer in self.middle:
            x = layer(x)
            mx.eval(x)  # Evaluate to limit graph size
        for layer in self.upsamples:
            x = layer(x, first_chunk)
            mx.eval(x)  # Evaluate after each upsample block
        return self.head(x)


class Encoder3d(nn.Module):
    """Wan2.2 3D VAE Encoder. Mirror of Decoder3d with downsampling."""

    def __init__(
        self,
        dim=160,
        z_dim=96,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temperal_downsample=(False, True, True),
    ):
        super().__init__()
        # Channel dimensions: [160, 160, 320, 640, 640] for the defaults
        dims = [dim * m for m in [1] + list(dim_mult)]

        # Initial conv: patchified input (12 ch) → first dim
        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)

        downsamples = []
        for i in range(len(dim_mult)):
            t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
            downsamples.append(Down_ResidualBlock(
                in_dim=dims[i],
                out_dim=dims[i + 1],
                num_res_blocks=num_res_blocks,
                temperal_downsample=t_down,
                down_flag=(i < len(dim_mult) - 1),
            ))
        self.downsamples = downsamples

        # Middle blocks (same as decoder)
        out_dim = dims[-1]
        self.middle = [
            ResidualBlock(out_dim, out_dim),
            AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim),
        ]

        # Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
        self.head = Head22(out_dim, out_channels=z_dim)

    def __call__(self, x):
        # x: [B, T, H, W, 12] (patchified)
        x = self.conv1(x)
        for layer in self.downsamples:
            x = layer(x)
        for layer in self.middle:
            x = layer(x)
            mx.eval(x)
        return self.head(x)


class Head22(nn.Module):
    """Output head: RMS_norm → SiLU → CausalConv3d(dim, out_channels, 3).

    PyTorch key mapping: head.0 = RMS_norm, head.2 = CausalConv3d
    (index 1 = SiLU has no params).
    """

    def __init__(self, dim, out_channels=12):
        super().__init__()
        self.layer_0 = RMS_norm(dim)
        self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)

    def __call__(self, x):
        return self.layer_2(nn.silu(self.layer_0(x)))


class Wan22VAEEncoder(nn.Module):
    """Full Wan2.2 VAE encoder with patchify and normalization."""

    def __init__(self, z_dim=48, dim=160):
        super().__init__()
        self.z_dim = z_dim
        # conv1: top-level 1x1x1 conv after encoder (z_dim*2 → z_dim*2)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.encoder = Encoder3d(
            dim=dim,
            z_dim=z_dim * 2,  # Encoder outputs z_dim*2, split into mu + log_var
            dim_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            temperal_downsample=(False, True, True),
        )

    def __call__(self, img):
        """Encode image/video to latent space.

        Args:
            img: [B, T, H, W, 3] image/video in [-1, 1]

        Returns:
            mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
        """
        # Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
        x = _patchify(img, patch_size=2)
        # Encoder + pointwise conv1: → [B, T', H', W', z_dim*2]
        out = self.conv1(self.encoder(x))
        # The first z_dim channels are mu; the remainder (log_var) is unused.
        mu = out[:, :, :, :, :self.z_dim]
        return normalize_latents(mu)


class Wan22VAEDecoder(nn.Module):
    """Full Wan2.2 VAE decoder with normalization and unpatchify."""

    def __init__(self, z_dim=48, dim=160, dec_dim=256):
        super().__init__()
        self.z_dim = z_dim
        # conv2: 1x1x1 conv before decoder
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dim=dec_dim,
            z_dim=z_dim,
            dim_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            temperal_upsample=(True, True, False),
        )

    def __call__(self, z):
        """Decode latents to video.

        Args:
            z: [B, T, H, W, C=48] latent tensor (already denormalized)

        Returns:
            video: [B, 4T-3, 16H, 16W, 3] decoded RGB in [-1, 1]
        """
        x = self.conv2(z)
        # All-at-once decode with first_chunk=True so the first latent frame
        # expands to a single output frame (matches PyTorch's chunked behavior).
        out = self.decoder(x, first_chunk=True)
        # Unpatchify: 12 channels → 3 RGB (spatial 2×2)
        out = _unpatchify(out, patch_size=2)
        return mx.clip(out, -1.0, 1.0)


def denormalize_latents(z, mean=None, std=None):
    """Denormalize latents: z * std + mean (i.e. z / (1/std) + mean).

    Defaults to the Wan2.2 per-channel statistics; broadcasts over the last axis.
    """
    if mean is None:
        mean = VAE22_MEAN
    if std is None:
        std = VAE22_STD
    return z * std.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)


def normalize_latents(z, mean=None, std=None):
    """Normalize latents: z_norm = (z - mean) / std. Inverse of denormalize_latents."""
    if mean is None:
        mean = VAE22_MEAN
    if std is None:
        std = VAE22_STD
    return (z - mean.reshape(1, 1, 1, 1, -1)) / std.reshape(1, 1, 1, 1, -1)


def _unpatchify(x, patch_size=2):
    """Convert packed channels back to pixels.

    [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C]; for p=2: [..., 12] → [..., 3].
    PyTorch: b (c r q) f h w -> b c f (h q) (w r), with q pairing with height
    and r with width.
    """
    if patch_size == 1:
        return x
    B, T, H, W, Cpacked = x.shape
    p = patch_size
    C = Cpacked // (p * p)
    # Channels are packed as (C, r, q): [B, T, H, W, C, r, q]
    x = x.reshape(B, T, H, W, C, p, p)
    # Put q next to H and r next to W: [B, T, H, q, W, r, C]
    x = x.transpose(0, 1, 2, 6, 3, 5, 4)
    return x.reshape(B, T, H * p, W * p, C)


def _patchify(x, patch_size=2):
    """Convert pixels to packed channels: [B, T, H*p, W*p, C] → [B, T, H, W, C*p*p].

    Inverse of _unpatchify.
    PyTorch: b c f (h q) (w r) -> b (c r q) f h w
    """
    if patch_size == 1:
        return x
    B, T, Hfull, Wfull, C = x.shape
    p = patch_size
    H, W = Hfull // p, Wfull // p
    # [B, T, H, q, W, r, C]
    x = x.reshape(B, T, H, p, W, p, C)
    # Pack q, r into channels as (C, r, q): [B, T, H, W, C, r, q]
    x = x.transpose(0, 1, 2, 4, 6, 5, 3)
    return x.reshape(B, T, H, W, C * p * p)


def _as_mx_array(value):
    """Return value as an mx.array without copying if it already is one."""
    return value if isinstance(value, mx.array) else mx.array(np.asarray(value))


def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> dict:
    """Convert PyTorch Wan2.2 VAE weights to MLX format.

    By default keeps decoder + conv2 weights only. Set include_encoder=True to
    also keep encoder + conv1 weights (needed for I2V encoding).

    - Maps PyTorch nn.Sequential indices to our named layers
      (residual.N → residual.layer_N, head.N → head.layer_N).
    - Maps Resample / AttentionBlock Conv2d params to flat attribute names.
    - Transposes conv weights from channels-first to channels-last.
    - Flattens RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).

    Conversions use MLX ops (not a NumPy round-trip) so bfloat16 checkpoints
    work and no extra host copy is made.
    """
    sanitized = {}
    for key, value in weights.items():
        # Skip encoder and conv1 unless requested
        if not include_encoder and key.startswith(("encoder.", "conv1.")):
            continue

        new_key = key
        # ResidualBlockLayers: indices 0, 2, 3, 6 → layer_0, layer_2, layer_3, layer_6
        for idx in ("0", "2", "3", "6"):
            new_key = new_key.replace(f".residual.{idx}.", f".residual.layer_{idx}.")
        # Head22: indices 0, 2 → layer_0, layer_2
        for idx in ("0", "2"):
            new_key = new_key.replace(f".head.{idx}.", f".head.layer_{idx}.")

        # Resample Conv2d: resample.1.{weight,bias} → resample_{weight,bias}
        if ".resample.1.weight" in new_key:
            new_key = new_key.replace(".resample.1.weight", ".resample_weight")
        elif ".resample.1.bias" in new_key:
            new_key = new_key.replace(".resample.1.bias", ".resample_bias")

        # AttentionBlock Conv2d weights
        if ".to_qkv.weight" in new_key:
            new_key = new_key.replace(".to_qkv.weight", ".to_qkv_weight")
        elif ".to_qkv.bias" in new_key:
            new_key = new_key.replace(".to_qkv.bias", ".to_qkv_bias")
        elif ".proj.weight" in new_key and "time_projection" not in new_key:
            new_key = new_key.replace(".proj.weight", ".proj_weight")
        elif ".proj.bias" in new_key and "time_projection" not in new_key:
            new_key = new_key.replace(".proj.bias", ".proj_bias")

        # Transpose conv weights to channels-last
        if new_key.endswith((".weight", "_weight")) and value.ndim in (4, 5):
            if value.ndim == 5:
                axes = (0, 2, 3, 4, 1)  # Conv3d: [O, I, D, H, W] → [O, D, H, W, I]
            else:
                axes = (0, 2, 3, 1)  # Conv2d: [O, I, H, W] → [O, H, W, I]
            value = mx.transpose(_as_mx_array(value), axes)

        # Flatten RMS_norm gamma to (dim,); reshape keeps (1,) when dim == 1
        if "gamma" in new_key:
            value = _as_mx_array(value).reshape(-1)

        sanitized[new_key] = value
    return sanitized