feat(wan): Add Wan2.1/2.2 T2V with quantization support

This commit is contained in:
Daniel
2026-02-26 16:16:07 +01:00
parent 7a74946c57
commit e64483a66a
21 changed files with 5309 additions and 35 deletions

View File

@@ -0,0 +1,2 @@
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel

View File

@@ -0,0 +1,201 @@
import mlx.core as mx
import mlx.nn as nn
from .rope import rope_apply
class WanRMSNorm(nn.Module):
"""RMS normalization with learnable scale."""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.fast.rms_norm(x, self.weight, self.eps)
class WanLayerNorm(nn.Module):
"""LayerNorm computed in float32, with optional affine."""
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if elementwise_affine:
self.weight = mx.ones((dim,))
self.bias = mx.zeros((dim,))
def __call__(self, x: mx.array) -> mx.array:
if self.elementwise_affine:
return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
else:
return mx.fast.layer_norm(x, None, None, self.eps)
class WanSelfAttention(nn.Module):
"""Self-attention with QK normalization and 3-way factorized RoPE."""
def __init__(
self,
dim: int,
num_heads: int,
window_size: tuple = (-1, -1),
qk_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
assert dim % num_heads == 0
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.scale = self.head_dim**-0.5
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
def __call__(
self,
x: mx.array,
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
) -> mx.array:
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
q = self.q(x)
k = self.k(x)
if self.norm_q is not None:
q = self.norm_q(q)
if self.norm_k is not None:
k = self.norm_k(k)
q = q.reshape(b, s, n, d)
k = k.reshape(b, s, n, d)
v = self.v(x).reshape(b, s, n, d)
# Apply RoPE
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
q = q.transpose(0, 2, 1, 3)
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Build attention mask from seq_lens
max_len = s
mask = None
if any(sl < max_len for sl in seq_lens):
mask = mx.zeros((b, 1, 1, max_len), dtype=q.dtype)
for i, sl in enumerate(seq_lens):
mask[i, :, :, sl:] = -1e9
# Use memory-efficient scaled dot-product attention
# mx.fast.scaled_dot_product_attention expects [B, N, L, D]
if mask is not None:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
return self.o(out)
class WanCrossAttention(nn.Module):
"""Cross-attention: Q from hidden states, K/V from text context."""
def __init__(
self,
dim: int,
num_heads: int,
qk_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
def prepare_kv(self, context: mx.array) -> tuple:
"""Pre-compute K and V projections for caching.
Args:
context: [B, L_ctx, dim]
Returns:
(k, v) each [B, N, L_ctx, D] ready for attention
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
k = self.k(context)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
return k, v
def __call__(
self,
x: mx.array,
context: mx.array,
context_lens: list | None = None,
kv_cache: tuple | None = None,
) -> mx.array:
b = x.shape[0]
n, d = self.num_heads, self.head_dim
q = self.q(x)
if self.norm_q is not None:
q = self.norm_q(q)
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
if kv_cache is not None:
k, v = kv_cache
else:
k = self.k(context)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
# Optional context masking
mask = None
if context_lens is not None:
ctx_len = k.shape[2]
mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
for i, cl in enumerate(context_lens):
mask[i, :, :, cl:] = -1e9
if mask is not None:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
return self.o(out)

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union
from mlx_video.models.ltx.config import BaseModelConfig
@dataclass
class WanModelConfig(BaseModelConfig):
"""Configuration for Wan T2V models (supports both 2.1 and 2.2)."""
model_type: str = "t2v"
model_version: str = "2.2"
patch_size: Tuple[int, int, int] = (1, 2, 2)
text_len: int = 512
in_dim: int = 16
dim: int = 5120
ffn_dim: int = 13824
freq_dim: int = 256
text_dim: int = 4096
out_dim: int = 16
num_heads: int = 40
num_layers: int = 40
window_size: Tuple[int, int] = (-1, -1)
qk_norm: bool = True
cross_attn_norm: bool = True
eps: float = 1e-6
# VAE
vae_stride: Tuple[int, int, int] = (4, 8, 8)
vae_z_dim: int = 16
# Inference
dual_model: bool = True
boundary: float = 0.875
sample_shift: float = 12.0
sample_steps: int = 40
sample_guide_scale: Union[float, Tuple[float, float]] = (3.0, 4.0)
num_train_timesteps: int = 1000
sample_fps: int = 16
frame_num: int = 81
# T5
t5_vocab_size: int = 256384
t5_dim: int = 4096
t5_dim_attn: int = 4096
t5_dim_ffn: int = 10240
t5_num_heads: int = 64
t5_num_layers: int = 24
t5_num_buckets: int = 32
@property
def head_dim(self) -> int:
return self.dim // self.num_heads
@classmethod
def wan21_t2v_14b(cls) -> "WanModelConfig":
"""Wan2.1 T2V 14B: single model, 40 layers, dim=5120."""
return cls(
model_version="2.1",
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
)
@classmethod
def wan21_t2v_1_3b(cls) -> "WanModelConfig":
"""Wan2.1 T2V 1.3B: single model, 30 layers, dim=1536."""
return cls(
model_version="2.1",
dim=1536,
ffn_dim=8960,
num_heads=12,
num_layers=30,
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
)
@classmethod
def wan22_t2v_14b(cls) -> "WanModelConfig":
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
return cls()

View File

@@ -0,0 +1,307 @@
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .attention import WanLayerNorm
from .config import WanModelConfig
from .rope import rope_params
from .transformer import WanAttentionBlock
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
"""Compute sinusoidal positional embeddings.
Args:
dim: Embedding dimension (must be even).
position: 1D tensor of positions.
Returns:
Embeddings of shape [len(position), dim].
"""
assert dim % 2 == 0
half = dim // 2
pos = position.astype(mx.float32)
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
sinusoid = pos[:, None] * inv_freq[None, :]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
class Head(nn.Module):
"""Output projection head with learned modulation."""
def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
super().__init__()
self.out_dim = out_dim
self.patch_size = patch_size
proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim)
self.modulation = mx.random.normal((1, 2, dim)) * (dim**-0.5)
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
"""
Args:
x: [B, L, dim]
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
e_f32 = e.astype(mx.float32)
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
x_norm = self.norm(x).astype(mx.float32)
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
return self.head(x_mod.astype(x.dtype))
class WanModel(nn.Module):
"""Wan2.2 diffusion backbone for text-to-video generation."""
def __init__(self, config: WanModelConfig):
super().__init__()
self.config = config
dim = config.dim
self.dim = dim
self.num_heads = config.num_heads
self.out_dim = config.out_dim
self.patch_size = config.patch_size
self.text_len = config.text_len
self.freq_dim = config.freq_dim
# Patch embedding: Conv3d implemented as a reshaped linear
# For kernel (1,2,2) and stride (1,2,2): reshape input then linear
patch_dim = config.in_dim * math.prod(config.patch_size)
self.patch_embedding_proj = nn.Linear(patch_dim, dim)
self._patch_size = config.patch_size
# Text embedding MLP
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
self.text_embedding_act = nn.GELU(approx="precise")
self.text_embedding_1 = nn.Linear(dim, dim)
# Time embedding MLP
self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
self.time_embedding_act = nn.SiLU()
self.time_embedding_1 = nn.Linear(dim, dim)
# Time projection for modulation (6x dim)
self.time_projection_act = nn.SiLU()
self.time_projection = nn.Linear(dim, dim * 6)
# Transformer blocks
self.blocks = [
WanAttentionBlock(
dim=dim,
ffn_dim=config.ffn_dim,
num_heads=config.num_heads,
window_size=config.window_size,
qk_norm=config.qk_norm,
cross_attn_norm=config.cross_attn_norm,
eps=config.eps,
)
for _ in range(config.num_layers)
]
# Output head
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
# Precompute RoPE frequencies
d = dim // config.num_heads
d_t = d - 4 * (d // 6)
d_h = 2 * (d // 6)
d_w = 2 * (d // 6)
# Each rope_params returns [1024, d_x//2, 2]
freqs_t = rope_params(1024, d_t)
freqs_h = rope_params(1024, d_h)
freqs_w = rope_params(1024, d_w)
# Concatenate along the frequency dimension: [1024, d//2, 2]
self.freqs = mx.concatenate([freqs_t, freqs_h, freqs_w], axis=1)
def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings.
Args:
x: Video latent [C, F, H, W]
Returns:
(patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
"""
c, f, h, w = x.shape
pt, ph, pw = self._patch_size
f_out = f // pt
h_out = h // ph
w_out = w // pw
# Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
# Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw]
x = x.reshape(f_out * h_out * w_out, -1) # [L, C*pt*ph*pw]
# Project and cast to model dtype to prevent float32 cascade from input latents
patches = self.patch_embedding_proj(x) # [L, dim]
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
patches = patches[None, :, :] # [1, L, dim]
return patches, (f_out, h_out, w_out)
def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
"""Reconstruct video from patch embeddings.
Args:
x: [B, L, out_dim * prod(patch_size)]
grid_sizes: List of (F', H', W') per batch element
Returns:
List of tensors [C, F, H, W]
"""
c = self.out_dim
pt, ph, pw = self.patch_size
out = []
for i, (f, h, w) in enumerate(grid_sizes):
seq_len = f * h * w
u = x[i, :seq_len] # [L, out_dim * pt * ph * pw]
u = u.reshape(f, h, w, pt, ph, pw, c)
# Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
u = u.transpose(6, 0, 3, 1, 4, 2, 5) # [C, F', pt, H', ph, W', pw]
u = u.reshape(c, f * pt, h * ph, w * pw)
out.append(u)
return out
def embed_text(self, context: list) -> mx.array:
"""Precompute text embeddings (call once, reuse across steps).
Args:
context: List of text embeddings [L_text, text_dim]
Returns:
Embedded context [B, text_len, dim] in model dtype
"""
model_dtype = self.patch_embedding_proj.weight.dtype
context_padded = []
for ctx in context:
pad_len = self.text_len - ctx.shape[0]
if pad_len > 0:
ctx = mx.concatenate(
[ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
axis=0,
)
context_padded.append(ctx)
context_batch = mx.stack(context_padded) # [B, text_len, text_dim]
context_batch = self.text_embedding_1(
self.text_embedding_act(self.text_embedding_0(context_batch))
)
return context_batch.astype(model_dtype)
def prepare_cross_kv(self, context: mx.array) -> list:
"""Pre-compute cross-attention K/V for all blocks.
Call once before the diffusion loop to cache K/V projections,
eliminating redundant computation at each denoising step.
Args:
context: Pre-embedded text [B, text_len, dim]
Returns:
List of (k, v) tuples, one per block
"""
kv_caches = []
for block in self.blocks:
kv_caches.append(block.cross_attn.prepare_kv(context))
return kv_caches
def __call__(
self,
x_list: list,
t: mx.array,
context: list | mx.array,
seq_len: int,
cross_kv_caches: list | None = None,
) -> list:
"""Forward pass.
Args:
x_list: List of video latent tensors [C, F, H, W]
t: Timestep tensor [B]
context: List of raw text embeddings, OR pre-embedded tensor
from embed_text() [B, text_len, dim]
seq_len: Maximum sequence length for padding
cross_kv_caches: Optional list of (k, v) tuples from
prepare_cross_kv(), one per block.
Returns:
List of denoised tensors [C, F, H, W]
"""
# Patchify each video
patches = []
grid_sizes = []
seq_lens_list = []
for vid in x_list:
p, gs = self._patchify(vid) # [1, L, dim]
patches.append(p)
grid_sizes.append(gs)
seq_lens_list.append(p.shape[1])
# Pad and batch
batch_size = len(patches)
x = mx.concatenate(
[
mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
)
if p.shape[1] < seq_len
else p
for p in patches
],
axis=0,
) # [B, seq_len, dim]
# Time embedding: compute once per sample, then broadcast to all tokens
if t.ndim == 0:
t = t[None]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
model_dtype = self.patch_embedding_proj.weight.dtype
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
e = e.astype(model_dtype)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):
# Pre-embedded: expand to batch size if needed
context_batch = context
if context_batch.shape[0] == 1 and batch_size > 1:
context_batch = mx.broadcast_to(
context_batch, (batch_size,) + context_batch.shape[1:]
)
else:
context_batch = self.embed_text(context)
# Run transformer blocks
kwargs = dict(
e=e0,
seq_lens=seq_lens_list,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context_batch,
context_lens=None,
)
for i, block in enumerate(self.blocks):
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
x = block(x, cross_kv_cache=kv, **kwargs)
# Output head
x = self.head(x, e)
# Unpatchify
outputs = self.unpatchify(x, grid_sizes)
return [u.astype(mx.float32) for u in outputs]

View File

@@ -0,0 +1,100 @@
import math
import mlx.core as mx
import numpy as np
def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
"""Precompute RoPE frequency parameters as complex numbers.
Returns:
Complex frequency tensor of shape [max_seq_len, dim // 2].
"""
assert dim % 2 == 0
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
1.0
/ np.power(
theta,
np.arange(0, dim, 2, dtype=np.float64) / dim,
)
)[None, :]
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
cos_freqs = np.cos(freqs).astype(np.float32)
sin_freqs = np.sin(freqs).astype(np.float32)
return mx.array(np.stack([cos_freqs, sin_freqs], axis=-1))
def rope_apply(
x: mx.array,
grid_sizes: list,
freqs: mx.array,
) -> mx.array:
"""Apply 3-way factorized RoPE to Q or K tensor.
Args:
x: Shape [B, L, num_heads, head_dim]
grid_sizes: List of (F, H, W) tuples per batch element
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
"""
b, s, n, d = x.shape
half_d = d // 2
# Cast freqs to input dtype to prevent float32 promotion cascade
if freqs.dtype != x.dtype:
freqs = freqs.astype(x.dtype)
# Split frequency dimensions: temporal gets more capacity
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
# Split freqs along dim axis
freqs_t = freqs[:, :d_t] # [1024, d_t, 2]
freqs_h = freqs[:, d_t : d_t + d_h] # [1024, d_h, 2]
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w] # [1024, d_w, 2]
outputs = []
for i in range(b):
f, h, w = grid_sizes[i]
seq_len = f * h * w
# Reshape x to pairs for rotation: [seq_len, n, half_d, 2]
x_i = x[i, :seq_len].reshape(seq_len, n, half_d, 2)
# Build per-position frequencies by expanding along grid dims
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
ft = mx.broadcast_to(
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
)
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
fh = mx.broadcast_to(
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
)
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
fw = mx.broadcast_to(
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
)
# Concatenate: [f*h*w, half_d, 2]
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
# Apply rotation: (a + bi) * (cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
cos_f = freqs_i[..., 0] # [seq_len, 1, half_d]
sin_f = freqs_i[..., 1] # [seq_len, 1, half_d]
x_real = x_i[..., 0] # [seq_len, n, half_d]
x_imag = x_i[..., 1] # [seq_len, n, half_d]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
# Interleave back: [seq_len, n, half_d, 2] -> [seq_len, n, d]
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, d)
# Handle padding: keep non-rotated tokens after seq_len
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[i, seq_len:]], axis=0)
outputs.append(x_rotated)
return mx.stack(outputs)

View File

@@ -0,0 +1,76 @@
"""Flow matching scheduler for Wan2.2 inference."""
import numpy as np
import mlx.core as mx
class FlowMatchEulerScheduler:
"""Simple Euler scheduler for flow matching diffusion.
Implements the flow matching formulation where the model predicts
velocity (flow) and we use Euler steps to denoise.
"""
def __init__(self, num_train_timesteps: int = 1000):
self.num_train_timesteps = num_train_timesteps
self.timesteps = None
self.sigmas = None
def set_timesteps(self, num_steps: int, shift: float = 1.0):
"""Compute sigma schedule with shift.
Args:
num_steps: Number of inference steps.
shift: Noise schedule shift factor.
"""
# Linear spacing from sigma_max to sigma_min
sigmas = np.linspace(1.0, 1.0 / self.num_train_timesteps, self.num_train_timesteps)[::-1]
sigmas = 1.0 - sigmas
# Select evenly spaced subset
indices = np.linspace(0, len(sigmas) - 1, num_steps + 1).astype(int)
sigmas = sigmas[indices[:-1]]
# Apply shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
# Convert to timesteps
timesteps = sigmas * self.num_train_timesteps
self.timesteps = mx.array(timesteps.astype(np.float32))
# Append terminal sigma=0
sigmas = np.append(sigmas, 0.0)
self.sigmas = mx.array(sigmas.astype(np.float32))
self._step_index = 0
def step(
self,
model_output: mx.array,
timestep,
sample: mx.array,
) -> mx.array:
"""Euler step for flow matching.
In flow matching, model predicts velocity v, and:
x_{t-1} = sample + (sigma_{t-1} - sigma_t) * v
Args:
model_output: Predicted velocity [B, C, T, H, W]
timestep: Current timestep (unused, step index is tracked internally)
sample: Current noisy sample [B, C, T, H, W]
Returns:
Updated sample
"""
# Use Python floats to avoid creating mx.array scalars that
# could trigger type promotion (per fast-mlx guide)
dt = float(self.sigmas[self._step_index + 1].item()) - float(self.sigmas[self._step_index].item())
x_next = sample + dt * model_output
self._step_index += 1
return x_next
def reset(self):
"""Reset step counter for new generation."""
self._step_index = 0

View File

@@ -0,0 +1,234 @@
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
import math
import mlx.core as mx
import mlx.nn as nn
class T5LayerNorm(nn.Module):
"""RMS-based layer normalization (T5 style)."""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.fast.rms_norm(x, self.weight, self.eps)
class T5RelativeEmbedding(nn.Module):
"""T5-style relative position bias with bucketing."""
def __init__(
self,
num_buckets: int,
num_heads: int,
bidirectional: bool = True,
max_dist: int = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
self.embedding = nn.Embedding(num_buckets, num_heads)
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = (
max_exact
+ (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = mx.minimum(
rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
)
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
return rel_buckets
def __call__(self, lq: int, lk: int) -> mx.array:
positions_k = mx.arange(lk)[None, :] # [1, lk]
positions_q = mx.arange(lq)[:, None] # [lq, 1]
rel_pos = positions_k - positions_q # [lq, lk]
buckets = self._relative_position_bucket(rel_pos)
embeds = self.embedding(buckets) # [lq, lk, num_heads]
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
return embeds
class T5Attention(nn.Module):
"""T5-style multi-head attention (no scaling)."""
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert dim_attn % num_heads == 0
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
def __call__(
self,
x: mx.array,
context: mx.array | None = None,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
context = x if context is None else context
b, n, c = x.shape[0], self.num_heads, self.head_dim
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
v = self.v(context).reshape(b, -1, n, c)
# T5 does not use scaling
# attn = einsum('binc,bjnc->bnij', q, k)
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Combine position bias and attention mask for SDPA
attn_mask = None
if pos_bias is not None:
attn_mask = pos_bias.astype(q.dtype)
if mask is not None:
if mask.ndim == 2:
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
elif mask.ndim == 3:
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype)
attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask
# T5 uses no scaling (scale=1.0)
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=1.0, mask=attn_mask
)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c)
return self.o(out)
class T5FeedForward(nn.Module):
"""Gated feed-forward: gate(x) * fc1(x) -> fc2."""
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = nn.GELU(approx="precise")
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
class T5SelfAttentionBlock(nn.Module):
"""T5 encoder block: self-attention + FFN."""
def __init__(
self,
dim: int,
dim_attn: int,
dim_ffn: int,
num_heads: int,
num_buckets: int,
shared_pos: bool = True,
):
super().__init__()
self.shared_pos = shared_pos
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
x = x + self.ffn(self.norm2(x))
return x
class T5Encoder(nn.Module):
"""T5 Encoder (UMT5-XXL configuration)."""
def __init__(
self,
vocab_size: int = 256384,
dim: int = 4096,
dim_attn: int = 4096,
dim_ffn: int = 10240,
num_heads: int = 64,
num_layers: int = 24,
num_buckets: int = 32,
shared_pos: bool = False,
):
super().__init__()
self.dim = dim
self.token_embedding = nn.Embedding(vocab_size, dim)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
if shared_pos
else None
)
self.blocks = [
T5SelfAttentionBlock(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
)
for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
"""
Args:
ids: Token IDs [B, L]
mask: Attention mask [B, L]
Returns:
Hidden states [B, L, dim]
"""
x = self.token_embedding(ids)
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding else None
for block in self.blocks:
x = block(x, mask=mask, pos_bias=e)
x = self.norm(x)
return x

View File

@@ -0,0 +1,89 @@
import mlx.core as mx
import mlx.nn as nn
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
class WanAttentionBlock(nn.Module):
"""Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
window_size: tuple = (-1, -1),
qk_norm: bool = True,
cross_attn_norm: bool = False,
eps: float = 1e-6,
):
super().__init__()
# Self-attention
self.norm1 = WanLayerNorm(dim, eps)
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
# Cross-attention (with optional norm on context)
self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else None
)
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
# Feed-forward
self.norm2 = WanLayerNorm(dim, eps)
self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate
self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5)
def __call__(
self,
x: mx.array,
e: mx.array,
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
context: mx.array,
context_lens: list | None = None,
cross_kv_cache: tuple | None = None,
) -> mx.array:
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
e0 = mod[:, :, 0, :] # shift for self-attn
e1 = mod[:, :, 1, :] # scale for self-attn
e2 = mod[:, :, 2, :] # gate for self-attn
e3 = mod[:, :, 3, :] # shift for ffn
e4 = mod[:, :, 4, :] # scale for ffn
e5 = mod[:, :, 5, :] # gate for ffn
# Self-attention with modulation
x_mod = self.norm1(x) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
x = x + y * e2
# Cross-attention (no modulation, just norm)
x_cross = self.norm3(x) if self.norm3 is not None else x
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
# FFN with modulation
x_mod = self.norm2(x) * (1 + e4) + e3
y = self.ffn(x_mod)
x = x + y * e5
return x
class WanFFN(nn.Module):
"""Gated feed-forward network with GELU(tanh) activation."""
def __init__(self, dim: int, ffn_dim: int):
super().__init__()
self.fc1 = nn.Linear(dim, ffn_dim)
self.act = nn.GELU(approx="precise")
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.act(self.fc1(x)))

315
mlx_video/models/wan/vae.py Normal file
View File

@@ -0,0 +1,315 @@
"""3D VAE Decoder for Wan2.1/2.2 (compression 4×8×8).
Module structure mirrors original PyTorch checkpoint key hierarchy
so weights load directly without key sanitization.
"""
import mlx.core as mx
import mlx.nn as nn
import numpy as np
CACHE_T = 2
# Per-channel normalization statistics for z_dim=16
VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
]
VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
]
class CausalConv3d(nn.Module):
"""3D convolution with causal temporal padding."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple,
stride: int | tuple = 1,
padding: int | tuple = 0,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
self.kernel_size = kernel_size
self.stride = stride
self._causal_pad_t = 2 * padding[0]
self._pad_h = padding[1]
self._pad_w = padding[2]
# MLX Conv3d: weight shape [O, D, H, W, I]
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
self.bias = mx.zeros((out_channels,))
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, C, T, H, W] (channel-first)"""
b, c, t, h, w = x.shape
if self._causal_pad_t > 0:
pad_t = mx.zeros((b, c, self._causal_pad_t, h, w), dtype=x.dtype)
x = mx.concatenate([pad_t, x], axis=2)
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
out = self._conv3d(x)
return out.transpose(0, 4, 1, 2, 3) # [B, O, T', H', W']
def _conv3d(self, x: mx.array) -> mx.array:
"""3D conv via sliding window + 2D conv per time step.
x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out]
"""
b, t, h, w, c_in = x.shape
kt, kh, kw = self.kernel_size
st, sh, sw = self.stride
t_out = (t - kt) // st + 1
# Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I]
w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape(
self.weight.shape[0], kh, kw, kt * c_in
)
outputs = []
for t_i in range(t_out):
t_start = t_i * st
window = x[:, t_start : t_start + kt]
window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in)
out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias
outputs.append(out_2d)
return mx.stack(outputs, axis=1)
class RMS_norm(nn.Module):
"""Channel-first L2 normalization matching original Wan VAE.
Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm.
images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input.
images=False: gamma shape (dim, 1, 1, 1) for 5D video input.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True):
super().__init__()
self.channel_first = channel_first
self.scale = dim**0.5
if channel_first:
broadcastable = (1, 1) if images else (1, 1, 1)
self.gamma = mx.ones((dim, *broadcastable))
else:
self.gamma = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
norm_dim = 1 if self.channel_first else -1
# L2 normalize along channel dim (matches F.normalize)
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
return (x / norm) * self.scale * self.gamma
class ResidualBlock(nn.Module):
"""Residual block with causal 3D convolutions.
Uses `residual` list with None gaps to match original PyTorch
nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm,
[4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params.
"""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.residual = [
RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU
None, # [5] Dropout
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
]
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
def __call__(self, x: mx.array) -> mx.array:
h = x if self.shortcut is None else self.shortcut(x)
x = nn.silu(self.residual[0](x))
x = self.residual[2](x)
x = nn.silu(self.residual[3](x))
x = self.residual[6](x)
return x + h
class AttentionBlock(nn.Module):
"""Single-head spatial self-attention."""
def __init__(self, dim: int):
super().__init__()
self.norm = RMS_norm(dim, images=True)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, C, T, H, W]"""
identity = x
b, c, t, h, w = x.shape
# [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C]
x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.norm(x)
x = x.transpose(0, 2, 3, 1) # [BT, H, W, C]
qkv = self.to_qkv(x) # [BT, H, W, 3C]
qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q[:, None, :, :] # [BT, 1, HW, C]
k = k[:, None, :, :]
v = v[:, None, :, :]
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
out = out.squeeze(1).reshape(b * t, h, w, c) # [BT, H, W, C]
out = self.proj(out) # [BT, H, W, C]
out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3) # [B, C, T, H, W]
return out + identity
class Resample(nn.Module):
"""Upsample block matching original Wan VAE structure.
Uses `resample` list with [None, Conv2d] to match original
nn.Sequential(Upsample, Conv2d) where index 1 has the conv params.
"""
def __init__(self, dim: int, mode: str):
super().__init__()
assert mode in ("upsample2d", "upsample3d")
self.mode = mode
self.dim = dim
# resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
if mode == "upsample3d":
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, C, T, H, W]"""
b, c, t, h, w = x.shape
if self.mode == "upsample3d":
# Temporal upsample via learned conv
x_t = self.time_conv(x) # [B, 2C, T, H, W]
x_t = x_t.reshape(b, 2, c, t, h, w)
# Interleave along time: [B, C, 2T, H, W]
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
t = t * 2
# Per-frame spatial upsample: nearest 2x + Conv2d
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
x = mx.repeat(x, 2, axis=1)
x = mx.repeat(x, 2, axis=2)
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
c_out = x.shape[-1]
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
class Decoder3d(nn.Module):
"""3D VAE Decoder matching Wan2.1 architecture.
Uses flat `middle` and `upsamples` lists to match original
PyTorch nn.Sequential weight key hierarchy.
"""
def __init__(
self,
dim: int = 96,
z_dim: int = 16,
dim_mult: list = None,
num_res_blocks: int = 2,
temporal_upsample: list = None,
):
super().__init__()
if dim_mult is None:
dim_mult = [1, 2, 4, 4]
if temporal_upsample is None:
temporal_upsample = [True, True, False]
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# Middle: [ResBlock, AttentionBlock, ResBlock]
self.middle = [
ResidualBlock(dims[0], dims[0]),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0]),
]
# Flat upsample list matching original nn.Sequential indexing
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
if i in (1, 2, 3):
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim))
in_dim = out_dim
if i != len(dim_mult) - 1:
mode = "upsample3d" if temporal_upsample[i] else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
self.upsamples = upsamples
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
]
def __call__(self, x: mx.array) -> mx.array:
"""x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]"""
x = self.conv1(x)
for layer in self.middle:
x = layer(x)
for layer in self.upsamples:
x = layer(x)
x = nn.silu(self.head[0](x))
x = self.head[2](x)
return x
class WanVAE(nn.Module):
"""Wan2.1 VAE wrapper with per-channel normalization."""
def __init__(self, z_dim: int = 16):
super().__init__()
self.z_dim = z_dim
self.mean = mx.array(VAE_MEAN)
self.std = mx.array(VAE_STD)
self.inv_std = 1.0 / self.std
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim=96, z_dim=z_dim)
def decode(self, z: mx.array) -> mx.array:
"""Decode latent to video.
Args:
z: Normalized latent [B, z_dim, T, H, W]
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
mean = self.mean.reshape(1, -1, 1, 1, 1)
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
z = z / inv_std + mean
x = self.conv2(z)
out = self.decoder(x)
return mx.clip(out, -1, 1)

View File

@@ -0,0 +1,584 @@
"""Wan2.2 VAE Decoder (compression 4×16×16, z_dim=48).
Architecture differs from Wan2.1 VAE: uses RMS_norm, DupUp3D shortcuts,
spatial patchify (2×2), and different temporal upsampling pattern.
Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
conversion (channels-first → channels-last) is needed.
"""
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
VAE22_MEAN = mx.array([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
])
VAE22_STD = mx.array([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
])
class CausalConv3d(nn.Module):
"""3D causal convolution. Input/output: [B, T, H, W, C] (channels-last).
Decomposes the 3D conv into per-frame 2D convolutions to avoid
excessive memory usage from MLX's conv3d implementation.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
self.kernel_size = kernel_size
self.stride = stride
self._causal_pad_t = 2 * padding[0]
self._pad_h = padding[1]
self._pad_w = padding[2]
# Weight: [O, D, H, W, I] for MLX
self.weight = mx.zeros((
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
))
self.bias = mx.zeros((out_channels,))
def __call__(self, x):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
kd, kh, kw = self.kernel_size
# For 1x1x1 kernel or kernel_d==1, use direct conv
if kd == 1 and kh == 1 and kw == 1:
# Simple pointwise: reshape to [B*T, 1, 1, C] → conv2d
x_flat = x.reshape(B * T, H, W, C)
w2d = self.weight[:, 0, :, :, :] # [O, kH, kW, I]
y = mx.conv_general(x_flat, w2d) + self.bias
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
# Causal temporal padding (left only)
if self._causal_pad_t > 0:
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
x = mx.concatenate([pad_t, x], axis=1)
# Spatial padding
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
(self._pad_w, self._pad_w), (0, 0)])
T_padded = x.shape[1]
H_padded, W_padded = x.shape[2], x.shape[3]
T_out = (T_padded - kd) // self.stride[0] + 1
# Decompose 3D conv into sum of 2D convolutions over temporal kernel
# weight shape: [O, kd, kh, kw, I] → split into kd 2D kernels [O, kh, kw, I]
outputs = []
for t in range(T_out):
t_start = t * self.stride[0]
# Sum 2D convs for each temporal kernel position
accum = None
for d in range(kd):
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
conv_out = mx.conv_general(frame, w2d,
stride=(self.stride[1], self.stride[2]))
accum = conv_out if accum is None else accum + conv_out
outputs.append(accum + self.bias)
return mx.stack(outputs, axis=1) # [B, T_out, H_out, W_out, O]
class RMS_norm(nn.Module):
"""RMS normalization along channel dimension."""
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
self.gamma = mx.ones((dim,))
def __call__(self, x):
# x: [..., C] (channels-last)
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
class ResidualBlock(nn.Module):
"""Residual block: RMS_norm → SiLU → CausalConv3d × 2 + shortcut."""
def __init__(self, in_dim, out_dim):
super().__init__()
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
# We store as named layers matching PyTorch's indices
self.residual = ResidualBlockLayers(in_dim, out_dim)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim
else None
)
def __call__(self, x):
h = self.shortcut(x) if self.shortcut is not None else x
return self.residual(x) + h
class ResidualBlockLayers(nn.Module):
"""The sequential layers inside a ResidualBlock.
PyTorch stores these as nn.Sequential with indices 0-6:
[0] RMS_norm, [1] SiLU, [2] CausalConv3d, [3] RMS_norm, [4] SiLU, [5] Dropout, [6] CausalConv3d
We use matching attribute names for weight compatibility.
"""
def __init__(self, in_dim, out_dim):
super().__init__()
# Indices match PyTorch nn.Sequential indices for weight key compat
# Index 0: RMS_norm
self.layer_0 = RMS_norm(in_dim)
# Index 2: CausalConv3d
self.layer_2 = CausalConv3d(in_dim, out_dim, 3, padding=1)
# Index 3: RMS_norm
self.layer_3 = RMS_norm(out_dim)
# Index 6: CausalConv3d
self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)
def __call__(self, x):
x = self.layer_0(x)
x = nn.silu(x)
x = self.layer_2(x)
mx.eval(x) # Eval between convolutions to limit graph size
x = self.layer_3(x)
x = nn.silu(x)
x = self.layer_6(x)
return x
class AttentionBlock(nn.Module):
"""2D self-attention applied per frame. Input: [B, T, H, W, C]."""
def __init__(self, dim):
super().__init__()
self.dim = dim
self.norm = RMS_norm(dim)
# Conv2d as linear per spatial position — weight [O, H, W, I] for MLX
# to_qkv: dim -> 3*dim, proj: dim -> dim (1x1 conv2d)
self.to_qkv_weight = mx.zeros((3 * dim, 1, 1, dim))
self.to_qkv_bias = mx.zeros((3 * dim,))
self.proj_weight = mx.zeros((dim, 1, 1, dim))
self.proj_bias = mx.zeros((dim,))
def __call__(self, x):
# x: [B, T, H, W, C]
identity = x
B, T, H, W, C = x.shape
# Apply per frame: merge B and T
x = x.reshape(B * T, H, W, C)
x = self.norm(x)
# QKV via 1x1 conv2d (equivalent to linear on last dim)
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C]
qkv = qkv.reshape(B * T, H * W, 3 * C)
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
# Single-head attention
q = q[:, None, :, :] # [BT, 1, HW, C]
k = k[:, None, :, :]
v = v[:, None, :, :]
scale = C ** -0.5
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C]
out = out.squeeze(1).reshape(B * T, H, W, C)
# Project output
out = mx.conv_general(out, self.proj_weight) + self.proj_bias # [BT, H, W, C]
out = out.reshape(B, T, H, W, C)
return out + identity
class DupUp3D(nn.Module):
"""Upsample by duplicating channels and reshaping. No learnable parameters."""
def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = factor_t * factor_s * factor_s
self.repeats = out_channels * self.factor // in_channels
def __call__(self, x, first_chunk=False):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
# Repeat channels
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
if first_chunk:
x = x[:, self.factor_t - 1:, :, :, :]
return x
class Resample(nn.Module):
"""Spatial up/downsampling with optional temporal up/downsampling."""
def __init__(self, dim, mode):
super().__init__()
self.dim = dim
self.mode = mode
if mode == "upsample2d":
# resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample_weight = mx.zeros((dim, 3, 3, dim)) # Conv2d [O, H, W, I]
self.resample_bias = mx.zeros((dim,))
elif mode == "upsample3d":
self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
else:
raise ValueError(f"Unsupported mode: {mode}")
def _upsample2x(self, x):
"""Nearest-neighbor 2x spatial upsample. x: [N, H, W, C]."""
N, H, W, C = x.shape
# Repeat along H and W axes separately
x = mx.repeat(x, repeats=2, axis=1) # [N, 2H, W, C]
x = mx.repeat(x, repeats=2, axis=2) # [N, 2H, 2W, C]
return x
def _conv2d(self, x):
"""Apply the Conv2d with padding=1. x: [N, H, W, C]."""
x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight) + self.resample_bias
def __call__(self, x, first_chunk=False):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
if self.mode == "upsample3d":
# Temporal upsample via time_conv
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
# Split into two interleaved temporal streams
tc_out = tc_out.reshape(B, T, H, W, 2, C)
# Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C]
stream0 = tc_out[:, :, :, :, 0, :] # [B, T, H, W, C]
stream1 = tc_out[:, :, :, :, 1, :] # [B, T, H, W, C]
x = mx.stack([stream0, stream1], axis=2) # [B, T, 2, H, W, C]
x = x.reshape(B, T * 2, H, W, C)
if first_chunk:
# PyTorch skips time_conv for first chunk entirely. In all-at-once
# mode, we trim the first frame to match (the first interleaved
# frame is from zero-padded causal context and shouldn't be kept).
x = x[:, 1:, :, :, :]
mx.eval(x)
T = x.shape[1]
# Spatial upsample in temporal chunks to limit peak memory
chunk_size = 8
chunks = []
for t_start in range(0, T, chunk_size):
t_end = min(t_start + chunk_size, T)
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
x_chunk = self._upsample2x(x_chunk)
x_chunk = self._conv2d(x_chunk)
mx.eval(x_chunk)
chunks.append(x_chunk)
x = mx.concatenate(chunks, axis=0)
H2, W2 = x.shape[1], x.shape[2]
x = x.reshape(B, T, H2, W2, C)
return x
class Up_ResidualBlock(nn.Module):
"""Upsampling residual block with optional DupUp3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
super().__init__()
self.up_flag = up_flag
# DupUp3D shortcut (no learnable params)
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim, out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
else:
self.avg_shortcut = None
# Main path: ResidualBlocks + optional Resample
blocks = []
dim_in = in_dim
for _ in range(num_res_blocks):
blocks.append(ResidualBlock(dim_in, out_dim))
dim_in = out_dim
if up_flag:
mode = "upsample3d" if temperal_upsample else "upsample2d"
blocks.append(Resample(out_dim, mode=mode))
self.upsamples = blocks
def __call__(self, x, first_chunk=False):
x_main = x
for module in self.upsamples:
if isinstance(module, Resample):
x_main = module(x_main, first_chunk)
else:
x_main = module(x_main)
mx.eval(x_main) # Limit graph size per sub-block
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
mx.eval(x_shortcut)
return x_main + x_shortcut
return x_main
class Decoder3d(nn.Module):
"""Wan2.2 3D VAE Decoder."""
def __init__(
self,
dim=256,
z_dim=48,
dim_mult=(1, 2, 4, 4),
num_res_blocks=2,
temperal_upsample=(True, True, False),
):
super().__init__()
# Compute layer dimensions
dims = [dim * dim_mult[-1]] + [dim * m for m in reversed(dim_mult)]
# Initial conv
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# Middle blocks
self.middle = [
ResidualBlock(dims[0], dims[0]),
AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0]),
]
# Upsample blocks
self.upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
self.upsamples.append(Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks + 1,
temperal_upsample=t_up,
up_flag=(i != len(dim_mult) - 1),
))
# Output head: [RMS_norm, SiLU, CausalConv3d]
self.head = Head22(dims[-1])
def __call__(self, x, first_chunk=False):
# x: [B, T, H, W, C=z_dim]
x = self.conv1(x)
for layer in self.middle:
x = layer(x)
mx.eval(x) # Evaluate to limit graph size
for i, layer in enumerate(self.upsamples):
x = layer(x, first_chunk)
mx.eval(x) # Evaluate after each upsample block
x = self.head(x)
return x
class Head22(nn.Module):
"""Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).
PyTorch key mapping: head.0 = RMS_norm, head.2 = CausalConv3d
(index 1 = SiLU has no params)
"""
def __init__(self, dim, out_channels=12):
super().__init__()
# Index 0: RMS_norm
self.layer_0 = RMS_norm(dim)
# Index 2: CausalConv3d
self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)
def __call__(self, x):
x = self.layer_0(x)
x = nn.silu(x)
x = self.layer_2(x)
return x
class Wan22VAEDecoder(nn.Module):
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
def __init__(self, z_dim=48, dim=160, dec_dim=256):
super().__init__()
self.z_dim = z_dim
# conv2: 1x1x1 conv before decoder
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(
dim=dec_dim,
z_dim=z_dim,
dim_mult=(1, 2, 4, 4),
num_res_blocks=2,
temperal_upsample=(True, True, False),
)
def __call__(self, z):
"""Decode latents to video.
Args:
z: [B, T, H, W, C=48] latent tensor (already denormalized)
Returns:
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
"""
x = self.conv2(z)
# All-at-once decode with first_chunk=True to trim extra temporal
# frames from causal padding (matches PyTorch's chunked behavior)
out = self.decoder(x, first_chunk=True)
# Unpatchify: 12 channels → 3 RGB (spatial 2×2)
out = _unpatchify(out, patch_size=2)
return mx.clip(out, -1.0, 1.0)
def denormalize_latents(z, mean=None, std=None):
"""Denormalize latents: z = z / (1/std) + mean."""
if mean is None:
mean = VAE22_MEAN
if std is None:
std = VAE22_STD
inv_scale = std # scale was 1/std, so divide by scale = multiply by std
return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
def _unpatchify(x, patch_size=2):
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C//(p*p)]
Actually: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
PyTorch: b (c r q) f h w -> b c f (h q) (w r) with q=p, r=p
In channels-last: [B, T, H, W, C*r*q] -> [B, T, H*q, W*r, C]
"""
if patch_size == 1:
return x
B, T, H, W, Cpacked = x.shape
C = Cpacked // (patch_size * patch_size)
# Reshape: [B, T, H, W, r, q, C] then rearrange to [B, T, H*q, W*r, C]
# PyTorch patchify: "b c f (h q) (w r) -> b (c r q) f h w" — so c is packed as (c, r, q)
# Unpatchify reverses: [B, T, H, W, (C, r, q)] -> [B, T, H, q, W, r, C]
x = x.reshape(B, T, H, W, C, patch_size, patch_size)
# Rearrange: put q next to H, r next to W
x = x.transpose(0, 1, 2, 6, 3, 5, 4) # [B, T, H, q, W, r, C]
x = x.reshape(B, T, H * patch_size, W * patch_size, C)
return x
def sanitize_wan22_vae_weights(weights: dict) -> dict:
"""Convert PyTorch Wan2.2 VAE weights to MLX format.
Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
Transposes conv weights from channels-first to channels-last.
Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
Maps PyTorch nn.Sequential indices to our named layers.
"""
sanitized = {}
for key, value in weights.items():
# Skip encoder and conv1 (encoder-only)
if key.startswith("encoder.") or key.startswith("conv1."):
continue
new_key = key
# Map nn.Sequential indexed layers to our named attributes
# ResidualBlockLayers: indices 0, 2, 3, 6 → _layer_0, _layer_2, _layer_3, _layer_6
# Head22: indices 0, 2 → _layer_0, _layer_2
for idx in ["0", "2", "3", "6"]:
# Match patterns like "residual.0.gamma" → "residual.layer_0.gamma"
# or "head.0.gamma" → "head.layer_0.gamma"
old_pattern = f".residual.{idx}."
new_pattern = f".residual.layer_{idx}."
new_key = new_key.replace(old_pattern, new_pattern)
# Head layer mapping: head.0.gamma → head.layer_0.gamma, head.2.weight → head.layer_2.weight
for idx in ["0", "2"]:
old_pattern = f".head.{idx}."
new_pattern = f".head.layer_{idx}."
new_key = new_key.replace(old_pattern, new_pattern)
# Map Resample Conv2d: resample.1.weight → resample_weight, resample.1.bias → resample_bias
if ".resample.1.weight" in new_key:
new_key = new_key.replace(".resample.1.weight", ".resample_weight")
elif ".resample.1.bias" in new_key:
new_key = new_key.replace(".resample.1.bias", ".resample_bias")
# Map AttentionBlock Conv2d weights
if ".to_qkv.weight" in new_key:
new_key = new_key.replace(".to_qkv.weight", ".to_qkv_weight")
elif ".to_qkv.bias" in new_key:
new_key = new_key.replace(".to_qkv.bias", ".to_qkv_bias")
elif ".proj.weight" in new_key and "time_projection" not in new_key:
new_key = new_key.replace(".proj.weight", ".proj_weight")
elif ".proj.bias" in new_key and "time_projection" not in new_key:
new_key = new_key.replace(".proj.bias", ".proj_bias")
# Transpose conv weights to channels-last
is_weight = new_key.endswith(".weight") or new_key.endswith("_weight")
if is_weight:
if value.ndim == 5:
# Conv3d: [O, I, D, H, W] → [O, D, H, W, I]
value = np.transpose(np.array(value), (0, 2, 3, 4, 1))
value = mx.array(value)
elif value.ndim == 4:
# Conv2d: [O, I, H, W] → [O, H, W, I]
value = np.transpose(np.array(value), (0, 2, 3, 1))
value = mx.array(value)
# Squeeze RMS_norm gamma: (dim, 1, 1, 1) or (dim, 1, 1) → (dim,)
if "gamma" in new_key:
value = mx.array(np.array(value).squeeze())
sanitized[new_key] = value
return sanitized