Files
mlx-video/mlx_video/models/wan/vae22.py

848 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Wan2.2 VAE Decoder (compression 4×16×16, z_dim=48).
Architecture differs from Wan2.1 VAE: uses RMS_norm, DupUp3D shortcuts,
spatial patchify (2×2), and different temporal upsampling pattern.
Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed.
"""
import logging
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): CACHE_T is not referenced anywhere in the visible code;
# presumably it mirrors the temporal cache length of the reference chunked
# implementation — confirm before removing.
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
# (48 entries, one per latent channel; used as the default `mean` by
# normalize_latents / denormalize_latents).
VAE22_MEAN = mx.array([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
])
# Per-channel standard deviation (48 entries); used as the default `std` by
# normalize_latents / denormalize_latents.
VAE22_STD = mx.array([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
])
class CausalConv3d(nn.Module):
    """3D causal convolution. Input/output: [B, T, H, W, C] (channels-last).

    The 3D convolution is decomposed into per-frame 2D convolutions (one per
    temporal kernel tap, summed) to avoid excessive memory usage from MLX's
    conv3d implementation.

    Temporal padding is always ``kernel_size[0] - 1`` zero frames on the left;
    the temporal component of ``padding`` is ignored. Spatial padding is
    symmetric zero padding taken from ``padding[1]`` / ``padding[2]``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Int or ``(kd, kh, kw)``.
        stride: Int or ``(st, sh, sw)``.
        padding: Int or ``(pt, ph, pw)``; only ``ph`` and ``pw`` are used.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        # Stored as tuples so they are treated as plain configuration.
        self.kernel_size = tuple(kernel_size)
        self.stride = tuple(stride)
        # Causal temporal padding: always kernel_size-1 on the left.
        self._causal_pad_t = self.kernel_size[0] - 1
        self._pad_h = padding[1]
        self._pad_w = padding[2]
        # Weight: [O, D, H, W, I] for MLX
        self.weight = mx.zeros(
            (out_channels, *self.kernel_size, in_channels)
        )
        self.bias = mx.zeros((out_channels,))

    def __call__(self, x):
        """Apply the convolution.

        Args:
            x: Array of shape [B, T, H, W, C].

        Returns:
            Array of shape [B, T_out, H_out, W_out, O].
        """
        B, T, H, W, C = x.shape
        kd, kh, kw = self.kernel_size
        st, sh, sw = self.stride

        # Pointwise fast path. Only valid when there is nothing to stride
        # over or pad; otherwise fall through to the general path so those
        # settings are honoured instead of being silently dropped.
        is_pointwise = (kd, kh, kw) == (1, 1, 1)
        is_trivial = (st, sh, sw) == (1, 1, 1) and self._pad_h == 0 and self._pad_w == 0
        if is_pointwise and is_trivial:
            x_flat = x.reshape(B * T, H, W, C)
            w2d = self.weight[:, 0]  # [O, 1, 1, I]
            y = mx.conv_general(x_flat, w2d) + self.bias
            return y.reshape(B, T, y.shape[1], y.shape[2], -1)

        # Causal temporal padding (left only, zeros). The padding is created
        # in the input dtype so half-precision inputs are not promoted to
        # float32 by the concatenation.
        if self._causal_pad_t > 0:
            pad_t = mx.zeros((B, self._causal_pad_t, H, W, C), dtype=x.dtype)
            x = mx.concatenate([pad_t, x], axis=1)

        # Spatial padding
        if self._pad_h > 0 or self._pad_w > 0:
            x = mx.pad(
                x,
                [
                    (0, 0),
                    (0, 0),
                    (self._pad_h, self._pad_h),
                    (self._pad_w, self._pad_w),
                    (0, 0),
                ],
            )

        T_out = (x.shape[1] - kd) // st + 1

        # One 2D kernel [O, kh, kw, I] per temporal tap, sliced once up front.
        taps = [self.weight[:, d] for d in range(kd)]

        outputs = []
        for t in range(T_out):
            t_start = t * st
            accum = None
            for d, w2d in enumerate(taps):
                conv_out = mx.conv_general(x[:, t_start + d], w2d, stride=(sh, sw))
                accum = conv_out if accum is None else accum + conv_out
            outputs.append(accum + self.bias)
        return mx.stack(outputs, axis=1)  # [B, T_out, H_out, W_out, O]
class RMS_norm(nn.Module):
    """Channel-wise normalization with a learnable per-channel gain.

    Despite the name, the input is divided by its L2 norm over the last
    (channel) axis, then multiplied by ``sqrt(dim)`` and by ``gamma``.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        # Gain is kept as a flat (dim,) vector.
        self.gamma = mx.ones((dim,))

    def __call__(self, x):
        """Normalize ``x`` ([..., C], channels-last) along its last axis."""
        squared_norm = mx.sum(x * x, axis=-1, keepdims=True)
        # Flooring the squared norm at 1e-24 equals flooring the norm at 1e-12.
        inv_norm = mx.rsqrt(mx.maximum(squared_norm, mx.array(1e-24)))
        return x * inv_norm * self.scale * self.gamma
class ResidualBlock(nn.Module):
    """Residual block: (RMS_norm → SiLU → CausalConv3d) × 2 plus a shortcut.

    The shortcut is the identity when ``in_dim == out_dim`` and a 1×1×1
    CausalConv3d otherwise.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Main path lives in its own module so parameter names line up with
        # the checkpoint's "residual.<index>" layout.
        self.residual = ResidualBlockLayers(in_dim, out_dim)
        if in_dim != out_dim:
            self.shortcut = CausalConv3d(in_dim, out_dim, 1)
        else:
            self.shortcut = None

    def __call__(self, x):
        """Return ``residual(x) + shortcut(x)`` for x of shape [B, T, H, W, C]."""
        if self.shortcut is None:
            skip = x
        else:
            skip = self.shortcut(x)
        return self.residual(x) + skip
class ResidualBlockLayers(nn.Module):
    """Main (non-shortcut) path of a ResidualBlock.

    Attribute names carry the index each layer has in the PyTorch
    nn.Sequential so weights can be mapped by key:
    layer_0 = RMS_norm, layer_2 = CausalConv3d, layer_3 = RMS_norm,
    layer_6 = CausalConv3d. The parameter-free entries (SiLU at 1 and 4,
    Dropout at 5) have no attribute here.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.layer_0 = RMS_norm(in_dim)
        self.layer_2 = CausalConv3d(in_dim, out_dim, 3, padding=1)
        self.layer_3 = RMS_norm(out_dim)
        self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)

    def __call__(self, x):
        """Run norm → SiLU → conv twice on x ([B, T, H, W, C])."""
        h = self.layer_2(nn.silu(self.layer_0(x)))
        # Materialize between the two convolutions to keep the graph small.
        mx.eval(h)
        h = self.layer_6(nn.silu(self.layer_3(h)))
        return h
class AttentionBlock(nn.Module):
    """Single-head 2D self-attention applied to each frame independently.

    Input and output are [B, T, H, W, C]; the block adds its result to the
    input (residual connection).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.norm = RMS_norm(dim)
        # 1x1 Conv2d kernels in MLX layout [O, H, W, I]:
        # to_qkv maps dim -> 3*dim, proj maps dim -> dim.
        self.to_qkv_weight = mx.zeros((3 * dim, 1, 1, dim))
        self.to_qkv_bias = mx.zeros((3 * dim,))
        self.proj_weight = mx.zeros((dim, 1, 1, dim))
        self.proj_bias = mx.zeros((dim,))

    def __call__(self, x):
        """Attend over the H*W positions of every frame of x."""
        B, T, H, W, C = x.shape
        n_frames = B * T

        # Fold time into the batch so attention never crosses frames.
        frames = self.norm(x.reshape(n_frames, H, W, C))

        # Joint Q/K/V projection, then flatten space into a token axis.
        qkv = mx.conv_general(frames, self.to_qkv_weight) + self.to_qkv_bias
        tokens = qkv.reshape(n_frames, H * W, 3 * C)
        # Each of q, k, v becomes [BT, 1, HW, C] (one attention head).
        q, k, v = (part[:, None] for part in mx.split(tokens, 3, axis=-1))

        attended = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** -0.5)
        attended = attended.squeeze(1).reshape(n_frames, H, W, C)

        projected = mx.conv_general(attended, self.proj_weight) + self.proj_bias
        return projected.reshape(B, T, H, W, C) + x
class DupUp3D(nn.Module):
    """Parameter-free upsampling by channel duplication and reshaping.

    Channels are repeated, then the repeated channel axis is split into
    (out_channels, factor_t, factor_s, factor_s) and the three factor axes
    are folded into time, height and width respectively.
    """

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s
        self.repeats = out_channels * self.factor // in_channels

    def __call__(self, x, first_chunk=False):
        """Upsample x ([B, T, H, W, C]).

        Returns [B, T*factor_t, H*factor_s, W*factor_s, out_channels]; when
        ``first_chunk`` is true the first ``factor_t - 1`` frames are dropped.
        """
        B, T, H, W, _ = x.shape
        ft, fs = self.factor_t, self.factor_s

        dup = mx.repeat(x, self.repeats, axis=-1)
        # Split channels: [B, T, H, W, out_C, ft, fs, fs]
        dup = dup.reshape(B, T, H, W, self.out_channels, ft, fs, fs)
        # Put each factor next to its axis: [B, T, ft, H, fs, W, fs, out_C]
        dup = dup.transpose(0, 1, 5, 2, 6, 3, 7, 4)
        out = dup.reshape(B, T * ft, H * fs, W * fs, self.out_channels)

        if not first_chunk:
            return out
        return out[:, ft - 1:, :, :, :]
class AvgDown3D(nn.Module):
    """Parameter-free downsampling by folding and group-averaging channels.

    Temporal/spatial factors are folded into the channel axis and the
    resulting channels are averaged in groups of ``group_size``.
    Input: [B, T, H, W, C_in] → Output: [B, T//ft, H//fs, W//fs, C_out]
    (T is first left-padded with zeros to a multiple of ft).
    """

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = factor_t * factor_s * factor_s
        assert in_channels * self.factor % out_channels == 0
        self.group_size = in_channels * self.factor // out_channels

    def __call__(self, x):
        """Downsample x ([B, T, H, W, C])."""
        B, T, H, W, C = x.shape
        ft, fs = self.factor_t, self.factor_s

        # Left-pad time so it divides evenly by the temporal factor.
        missing = (ft - T % ft) % ft
        if missing > 0:
            x = mx.pad(x, [(0, 0), (missing, 0), (0, 0), (0, 0), (0, 0)])
            T += missing

        t_out, h_out, w_out = T // ft, H // fs, W // fs
        # Expose the factors as their own axes ...
        folded = x.reshape(B, t_out, ft, h_out, fs, w_out, fs, C)
        # ... and move them behind channels: [B, T', H', W', C, ft, fs, fs]
        folded = folded.transpose(0, 1, 3, 5, 7, 2, 4, 6)
        # Flatten (C, ft, fs, fs) and regroup as (out_channels, group_size).
        grouped = folded.reshape(
            B, t_out, h_out, w_out, self.out_channels, self.group_size
        )
        return grouped.mean(axis=-1)
class Resample(nn.Module):
    """Spatial up/downsampling with optional temporal up/downsampling.

    Supported modes:
        "upsample2d":   nearest 2x spatial upsample + 3x3 Conv2d.
        "upsample3d":   temporal 2x upsample via ``time_conv`` followed by
                        the "upsample2d" spatial path.
        "downsample2d": pad right/bottom by one + stride-2 3x3 Conv2d.
        "downsample3d": strided temporal ``time_conv`` followed by the
                        "downsample2d" spatial path.

    Raises:
        ValueError: If ``mode`` is not one of the above.
    """

    def __init__(self, dim, mode):
        super().__init__()
        self.dim = dim
        self.mode = mode
        if mode not in ("upsample2d", "upsample3d", "downsample2d", "downsample3d"):
            raise ValueError(f"Unsupported mode: {mode}")

        # Every mode owns one 3x3 Conv2d (checkpoint key "resample.1"),
        # stored in MLX layout [O, H, W, I].
        self.resample_weight = mx.zeros((dim, 3, 3, dim))
        self.resample_bias = mx.zeros((dim,))

        if mode == "upsample3d":
            # Produces 2*dim channels that are split into two frame streams.
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample3d":
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
            )

    def _upsample2x(self, x):
        """Nearest-neighbor 2x spatial upsample. x: [N, H, W, C]."""
        x = mx.repeat(x, repeats=2, axis=1)  # [N, 2H, W, C]
        x = mx.repeat(x, repeats=2, axis=2)  # [N, 2H, 2W, C]
        return x

    def _conv2d(self, x):
        """Apply the Conv2d with padding=1. x: [N, H, W, C]."""
        x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
        return mx.conv_general(x, self.resample_weight) + self.resample_bias

    def _downsample_conv2d(self, x):
        """Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
        # Pad one pixel on the bottom and right only.
        x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
        return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias

    def _temporal_upsample(self, x):
        """Double the frame count of x ([B, T, H, W, C]) using time_conv.

        time_conv emits 2C channels per frame; the two C-sized halves are
        interleaved along time, giving [B, 2T, H, W, C].
        """
        B, T, H, W, C = x.shape
        tc_out = self.time_conv(x).reshape(B, T, H, W, 2, C)
        stream0 = tc_out[:, :, :, :, 0, :]
        stream1 = tc_out[:, :, :, :, 1, :]
        return mx.stack([stream0, stream1], axis=2).reshape(B, T * 2, H, W, C)

    def __call__(self, x, first_chunk=False):
        """Resample x ([B, T, H, W, C]) according to ``self.mode``.

        Args:
            x: Input array [B, T, H, W, C].
            first_chunk: Only used by "upsample3d". When true the first frame
                skips ``time_conv`` so the output has ``2T - 1`` frames, the
                same count the DupUp3D shortcut produces for a first chunk.

        Returns:
            Array [B, T', H', W', C].
        """
        B, T, H, W, C = x.shape

        if self.mode == "upsample3d":
            if first_chunk:
                # The first frame bypasses time_conv entirely; the remaining
                # frames see zero causal context before them. A lone frame
                # therefore passes through unchanged (2*1 - 1 = 1 frame),
                # rather than being doubled and then silently broadcast
                # against a 1-frame shortcut.
                if T > 1:
                    x = mx.concatenate(
                        [x[:, 0:1], self._temporal_upsample(x[:, 1:])], axis=1
                    )
            else:
                x = self._temporal_upsample(x)
            mx.eval(x)
            T = x.shape[1]

        if self.mode == "downsample3d" and T > 1:
            # Temporal downsample via strided CausalConv3d; skipped for a
            # single frame.
            # NOTE(review): CausalConv3d left-pads kernel-1 zero frames here;
            # confirm this matches the reference encoder's cached-frame path.
            x = self.time_conv(x)
            mx.eval(x)
            T = x.shape[1]

        if self.mode in ("upsample2d", "upsample3d"):
            # Spatial upsample in temporal chunks to limit peak memory.
            chunk_size = 8
            chunks = []
            for t_start in range(0, T, chunk_size):
                t_end = min(t_start + chunk_size, T)
                frames = x[:, t_start:t_end].reshape(-1, H, W, C)
                frames = self._conv2d(self._upsample2x(frames))
                # Restore the batch axis per chunk so chunks can be joined
                # along time; joining flattened chunks on axis 0 would mix
                # frames between batch items when B > 1.
                frames = frames.reshape(
                    B, t_end - t_start, frames.shape[1], frames.shape[2], C
                )
                mx.eval(frames)
                chunks.append(frames)
            x = mx.concatenate(chunks, axis=1)
        else:
            # "downsample2d" / "downsample3d": per-frame strided Conv2d.
            x_flat = self._downsample_conv2d(x.reshape(B * T, H, W, C))
            mx.eval(x_flat)
            x = x_flat.reshape(B, T, x_flat.shape[1], x_flat.shape[2], C)
        return x
class Up_ResidualBlock(nn.Module):
    """Decoder stage: residual blocks, optional Resample, optional shortcut.

    When ``up_flag`` is set the stage ends with a Resample layer and adds a
    parameter-free DupUp3D shortcut of the stage input to its output.
    """

    def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
        super().__init__()
        self.up_flag = up_flag

        # Shortcut exists only for stages that actually upsample.
        self.avg_shortcut = (
            DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,
            )
            if up_flag
            else None
        )

        # Only the first residual block changes the channel count.
        layers = [
            ResidualBlock(in_dim if idx == 0 else out_dim, out_dim)
            for idx in range(num_res_blocks)
        ]
        if up_flag:
            layers.append(
                Resample(out_dim, mode="upsample3d" if temperal_upsample else "upsample2d")
            )
        self.upsamples = layers

    def __call__(self, x, first_chunk=False):
        """Run the stage on x ([B, T, H, W, C]).

        ``first_chunk`` is forwarded to the Resample layer and the shortcut.
        """
        main = x
        for layer in self.upsamples:
            if isinstance(layer, Resample):
                main = layer(main, first_chunk)
            else:
                main = layer(main)
            mx.eval(main)  # keep the lazy graph small, one sub-block at a time

        if self.avg_shortcut is None:
            return main

        skip = self.avg_shortcut(x, first_chunk)
        mx.eval(skip)
        return main + skip
class Down_ResidualBlock(nn.Module):
    """Encoder stage: residual blocks, optional Resample, AvgDown3D shortcut.

    The parameter-free AvgDown3D shortcut is always present; its factors are
    1 when the stage does not downsample along that axis.
    """

    def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
        super().__init__()
        self.down_flag = down_flag

        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Only the first residual block changes the channel count.
        layers = [
            ResidualBlock(in_dim if idx == 0 else out_dim, out_dim)
            for idx in range(num_res_blocks)
        ]
        if down_flag:
            layers.append(
                Resample(out_dim, mode="downsample3d" if temperal_downsample else "downsample2d")
            )
        self.downsamples = layers

    def __call__(self, x):
        """Run the stage on x ([B, T, H, W, C]) and add the shortcut."""
        skip = self.avg_shortcut(x)
        mx.eval(skip)

        main = x
        for layer in self.downsamples:
            main = layer(main)
            mx.eval(main)
        return main + skip
class Decoder3d(nn.Module):
    """Wan2.2 3D VAE decoder trunk.

    Layout: input conv → middle (ResidualBlock, AttentionBlock,
    ResidualBlock) → one Up_ResidualBlock per entry of ``dim_mult`` → head.
    Every stage except the last upsamples spatially; ``temperal_upsample``
    selects which stages also upsample in time.
    """

    def __init__(
        self,
        dim=256,
        z_dim=48,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temperal_upsample=(True, True, False),
    ):
        super().__init__()
        # Channel widths: the widest one first, then dim_mult reversed.
        widths = [dim * m for m in (dim_mult[-1], *reversed(dim_mult))]
        top = widths[0]

        self.conv1 = CausalConv3d(z_dim, top, 3, padding=1)

        self.middle = [
            ResidualBlock(top, top),
            AttentionBlock(top),
            ResidualBlock(top, top),
        ]

        last_stage = len(dim_mult) - 1
        stages = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Stages beyond the end of temperal_upsample never upsample time.
            upsample_time = temperal_upsample[idx] if idx < len(temperal_upsample) else False
            stages.append(
                Up_ResidualBlock(
                    in_dim=c_in,
                    out_dim=c_out,
                    num_res_blocks=num_res_blocks + 1,
                    temperal_upsample=upsample_time,
                    up_flag=(idx != last_stage),
                )
            )
        self.upsamples = stages

        # Output head: RMS_norm → SiLU → CausalConv3d
        self.head = Head22(widths[-1])

    def __call__(self, x, first_chunk=False):
        """Decode x ([B, T, H, W, z_dim]); ``first_chunk`` goes to every stage."""
        h = self.conv1(x)
        for block in self.middle:
            h = block(h)
        mx.eval(h)  # materialize before the upsampling stages
        for stage in self.upsamples:
            h = stage(h, first_chunk)
            mx.eval(h)  # materialize after each stage
        return self.head(h)
class Encoder3d(nn.Module):
    """Wan2.2 3D VAE encoder trunk.

    Layout: input conv (12 patchified channels) → one Down_ResidualBlock per
    entry of ``dim_mult`` → middle (ResidualBlock, AttentionBlock,
    ResidualBlock) → head producing ``z_dim`` channels. Every stage except
    the last downsamples spatially; ``temperal_downsample`` selects which
    stages also downsample in time.
    """

    def __init__(
        self,
        dim=160,
        z_dim=96,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temperal_downsample=(False, True, True),
    ):
        super().__init__()
        # Channel widths: base width first, then one per stage.
        widths = [dim * m for m in (1, *dim_mult)]

        # Input is patchified, hence the fixed 12 input channels.
        self.conv1 = CausalConv3d(12, widths[0], 3, padding=1)

        n_stages = len(dim_mult)
        stages = []
        for idx in range(n_stages):
            # Stages beyond the end of temperal_downsample keep T unchanged.
            downsample_time = (
                temperal_downsample[idx] if idx < len(temperal_downsample) else False
            )
            stages.append(
                Down_ResidualBlock(
                    in_dim=widths[idx],
                    out_dim=widths[idx + 1],
                    num_res_blocks=num_res_blocks,
                    temperal_downsample=downsample_time,
                    down_flag=(idx < n_stages - 1),
                )
            )
        self.downsamples = stages

        bottom = widths[-1]
        self.middle = [
            ResidualBlock(bottom, bottom),
            AttentionBlock(bottom),
            ResidualBlock(bottom, bottom),
        ]

        # Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
        self.head = Head22(bottom, out_channels=z_dim)

    def __call__(self, x):
        """Encode x ([B, T, H, W, 12], patchified) to [B, T', H', W', z_dim]."""
        h = self.conv1(x)
        for stage in self.downsamples:
            h = stage(h)
        for block in self.middle:
            h = block(h)
        mx.eval(h)
        return self.head(h)
class Head22(nn.Module):
    """Output head: RMS_norm → SiLU → CausalConv3d(dim, out_channels, 3).

    Used by both Decoder3d (default 12 output channels) and Encoder3d.
    Attribute names follow the PyTorch nn.Sequential indices:
    layer_0 = RMS_norm, layer_2 = CausalConv3d (index 1, SiLU, has no
    parameters).
    """

    def __init__(self, dim, out_channels=12):
        super().__init__()
        self.layer_0 = RMS_norm(dim)
        self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)

    def __call__(self, x):
        """Apply the head to x ([B, T, H, W, dim])."""
        return self.layer_2(nn.silu(self.layer_0(x)))
class Wan22VAEEncoder(nn.Module):
    """Full Wan2.2 VAE encoder: patchify → Encoder3d → conv1 → normalize."""

    def __init__(self, z_dim=48, dim=160):
        super().__init__()
        self.z_dim = z_dim
        moments_dim = z_dim * 2  # the encoder emits two z_dim-sized halves
        # Top-level pointwise conv applied to the encoder output.
        self.conv1 = CausalConv3d(moments_dim, moments_dim, 1)
        self.encoder = Encoder3d(
            dim=dim,
            z_dim=moments_dim,
            dim_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            temperal_downsample=(False, True, True),
        )

    def __call__(self, img):
        """Encode image/video to latent space.

        Args:
            img: [B, T, H, W, 3] image/video in [-1, 1]
        Returns:
            mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
        """
        # [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
        patches = _patchify(img, patch_size=2)
        # → [B, T', H', W', 2*z_dim]
        moments = self.conv1(self.encoder(patches))
        # Keep the first z_dim channels (mu); the remaining half is unused.
        mu = moments[:, :, :, :, :self.z_dim]
        return normalize_latents(mu)
class Wan22VAEDecoder(nn.Module):
    """Full Wan2.2 VAE decoder: conv2 → Decoder3d → unpatchify → clip.

    ``dim`` is accepted but not used by this class; the decoder width is
    ``dec_dim``.
    """

    def __init__(self, z_dim=48, dim=160, dec_dim=256):
        super().__init__()
        self.z_dim = z_dim
        # Pointwise conv applied to the latents before the decoder trunk.
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(
            dim=dec_dim,
            z_dim=z_dim,
            dim_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            temperal_upsample=(True, True, False),
        )

    def __call__(self, z):
        """Decode latents to video.

        Args:
            z: [B, T, H, W, C=48] latent tensor (already denormalized)
        Returns:
            video: [B, T', H', W', 3] decoded RGB in [-1, 1]
        """
        hidden = self.conv2(z)
        # The whole sequence is decoded in one pass and treated as a first
        # chunk, so the temporal upsamplers drop the extra leading frames.
        packed = self.decoder(hidden, first_chunk=True)
        # 12 packed channels → 3 RGB channels at 2x spatial resolution.
        video = _unpatchify(packed, patch_size=2)
        return mx.clip(video, -1.0, 1.0)
def denormalize_latents(z, mean=None, std=None):
    """Undo per-channel normalization: ``z * std + mean``.

    Args:
        z: Latents with channels on the last of five axes.
        mean: Per-channel mean; defaults to VAE22_MEAN.
        std: Per-channel standard deviation; defaults to VAE22_STD.
    """
    mean = VAE22_MEAN if mean is None else mean
    std = VAE22_STD if std is None else std
    # Broadcast the per-channel vectors over [B, T, H, W].
    scale = std.reshape(1, 1, 1, 1, -1)
    shift = mean.reshape(1, 1, 1, 1, -1)
    return z * scale + shift
def normalize_latents(z, mean=None, std=None):
    """Apply per-channel normalization: ``(z - mean) / std``.

    Args:
        z: Latents with channels on the last of five axes.
        mean: Per-channel mean; defaults to VAE22_MEAN.
        std: Per-channel standard deviation; defaults to VAE22_STD.
    """
    mean = VAE22_MEAN if mean is None else mean
    std = VAE22_STD if std is None else std
    # Broadcast the per-channel vectors over [B, T, H, W].
    centered = z - mean.reshape(1, 1, 1, 1, -1)
    return centered / std.reshape(1, 1, 1, 1, -1)
def _unpatchify(x, patch_size=2):
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C//(p*p)]
Actually: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
PyTorch: b (c r q) f h w -> b c f (h q) (w r) with q=p, r=p
In channels-last: [B, T, H, W, C*r*q] -> [B, T, H*q, W*r, C]
"""
if patch_size == 1:
return x
B, T, H, W, Cpacked = x.shape
C = Cpacked // (patch_size * patch_size)
# Reshape: [B, T, H, W, r, q, C] then rearrange to [B, T, H*q, W*r, C]
# PyTorch patchify: "b c f (h q) (w r) -> b (c r q) f h w" — so c is packed as (c, r, q)
# Unpatchify reverses: [B, T, H, W, (C, r, q)] -> [B, T, H, q, W, r, C]
x = x.reshape(B, T, H, W, C, patch_size, patch_size)
# Rearrange: put q next to H, r next to W
x = x.transpose(0, 1, 2, 6, 3, 5, 4) # [B, T, H, q, W, r, C]
x = x.reshape(B, T, H * patch_size, W * patch_size, C)
return x
def _patchify(x, patch_size=2):
"""Convert spatial to packed channels: [B, T, H*p, W*p, C] → [B, T, H, W, C*p*p]
Inverse of _unpatchify.
PyTorch: b c f (h q) (w r) -> b (c r q) f h w
In channels-last: [B, T, H*q, W*r, C] → [B, T, H, W, C*r*q]
"""
if patch_size == 1:
return x
B, T, Hfull, Wfull, C = x.shape
H = Hfull // patch_size
W = Wfull // patch_size
# [B, T, H, q, W, r, C]
x = x.reshape(B, T, H, patch_size, W, patch_size, C)
# Rearrange to pack q,r into channels: [B, T, H, W, C, r, q]
x = x.transpose(0, 1, 2, 4, 6, 5, 3) # [B, T, H, W, C, r, q]
x = x.reshape(B, T, H, W, C * patch_size * patch_size)
return x
def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> dict:
    """Convert PyTorch Wan2.2 VAE weights to MLX format.

    By default keeps decoder + conv2 weights only. Set include_encoder=True
    to also keep encoder + conv1 weights (needed for I2V encoding).

    For every kept entry:
      * nn.Sequential indices are mapped to the attribute names used here
        (``.residual.N.`` → ``.residual.layer_N.``, ``.head.N.`` →
        ``.head.layer_N.``, ``.resample.1.*`` → ``.resample_*``,
        ``.to_qkv.*`` → ``.to_qkv_*``, ``.proj.*`` → ``.proj_*``);
      * 5-D / 4-D conv weights are transposed from channels-first to
        channels-last;
      * any tensor whose key contains "gamma" is flattened to 1-D
        (``(dim, 1, 1, 1)`` or ``(dim, 1, 1)`` → ``(dim,)``).

    Values that are already ``mx.array`` are processed with MLX ops so that
    dtypes numpy cannot represent (e.g. bfloat16) are supported; anything
    else goes through numpy as before.

    Args:
        weights: Mapping of checkpoint key → tensor.
        include_encoder: Keep keys starting with "encoder." or "conv1.".

    Returns:
        New dict of sanitized key → value. Values that need no conversion are
        passed through unchanged.
    """

    def _permute(value, axes):
        # Stay in MLX when possible; the numpy round trip cannot handle bfloat16.
        if isinstance(value, mx.array):
            return mx.transpose(value, axes)
        return mx.array(np.transpose(np.array(value), axes))

    def _flatten(value):
        if isinstance(value, mx.array):
            return value.reshape(-1)
        return mx.array(np.array(value).reshape(-1))

    # Plain substring renames, applied in order.
    renames = [
        (f".residual.{idx}.", f".residual.layer_{idx}.") for idx in ("0", "2", "3", "6")
    ]
    renames += [(f".head.{idx}.", f".head.layer_{idx}.") for idx in ("0", "2")]
    renames += [
        (".resample.1.weight", ".resample_weight"),
        (".resample.1.bias", ".resample_bias"),
        (".to_qkv.weight", ".to_qkv_weight"),
        (".to_qkv.bias", ".to_qkv_bias"),
    ]

    sanitized = {}
    for key, value in weights.items():
        # Skip encoder and conv1 unless requested
        if not include_encoder and key.startswith(("encoder.", "conv1.")):
            continue

        new_key = key
        for old, new in renames:
            new_key = new_key.replace(old, new)
        # Attention output projection; keys mentioning time_projection are
        # left alone.
        if "time_projection" not in new_key:
            new_key = new_key.replace(".proj.weight", ".proj_weight")
            new_key = new_key.replace(".proj.bias", ".proj_bias")

        # Transpose conv weights to channels-last
        if new_key.endswith((".weight", "_weight")):
            if value.ndim == 5:
                # Conv3d: [O, I, D, H, W] → [O, D, H, W, I]
                value = _permute(value, (0, 2, 3, 4, 1))
            elif value.ndim == 4:
                # Conv2d: [O, I, H, W] → [O, H, W, I]
                value = _permute(value, (0, 2, 3, 1))

        # RMS_norm gamma: drop the trailing singleton axes.
        if "gamma" in new_key:
            value = _flatten(value)

        sanitized[new_key] = value
    return sanitized