feat(wan): Add Wan2.2 I2V support

This commit is contained in:
Daniel
2026-02-27 13:46:23 +01:00
parent 93da550f65
commit 2bb95c61ed
26 changed files with 4401 additions and 2968 deletions

View File

@@ -71,8 +71,12 @@ class WanSelfAttention(nn.Module):
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
q = self.q(x)
k = self.k(x)
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
x_w = x.astype(w_dtype)
q = self.q(x_w)
k = self.k(x_w)
if self.norm_q is not None:
q = self.norm_q(q)
if self.norm_k is not None:
@@ -80,15 +84,15 @@ class WanSelfAttention(nn.Module):
q = q.reshape(b, s, n, d)
k = k.reshape(b, s, n, d)
v = self.v(x).reshape(b, s, n, d)
v = self.v(x_w).reshape(b, s, n, d)
# Apply RoPE
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
q = q.transpose(0, 2, 1, 3)
k = k.transpose(0, 2, 1, 3)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Build attention mask from seq_lens
@@ -149,11 +153,14 @@ class WanCrossAttention(nn.Module):
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
k = self.k(context)
# Cast to weight dtype for efficient matmul
w_dtype = self.k.weight.dtype
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
return k, v
def __call__(
@@ -166,7 +173,9 @@ class WanCrossAttention(nn.Module):
b = x.shape[0]
n, d = self.num_heads, self.head_dim
q = self.q(x)
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
q = self.q(x.astype(w_dtype))
if self.norm_q is not None:
q = self.norm_q(q)
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
@@ -174,11 +183,12 @@ class WanCrossAttention(nn.Module):
if kv_cache is not None:
k, v = kv_cache
else:
k = self.k(context)
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
# Optional context masking
mask = None

View File

@@ -90,3 +90,24 @@ class WanModelConfig(BaseModelConfig):
def wan22_t2v_14b(cls) -> "WanModelConfig":
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
return cls()
@classmethod
def wan22_ti2v_5b(cls) -> "WanModelConfig":
"""Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""
return cls(
model_type="ti2v",
dim=3072,
ffn_dim=14336,
in_dim=48,
out_dim=48,
num_heads=24,
num_layers=30,
vae_z_dim=48,
vae_stride=(4, 16, 16),
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
sample_fps=24,
)

View File

@@ -0,0 +1,58 @@
"""Image-to-Video utility functions for Wan2.2."""
import mlx.core as mx
import numpy as np
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
"""Load, resize, center-crop, and normalize an image for I2V.
Args:
image_path: Path to input image
width: Target width
height: Target height
Returns:
Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
"""
from PIL import Image
img = Image.open(image_path).convert("RGB")
# Resize so that the image covers the target size (LANCZOS)
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
# Center crop
x1 = (img.width - width) // 2
y1 = (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
# To tensor: [H, W, 3] float32 in [-1, 1]
arr = np.array(img, dtype=np.float32) / 255.0
arr = arr * 2.0 - 1.0 # [0,1] → [-1,1]
return mx.array(arr[None, None]) # [1, 1, H, W, 3]
def build_i2v_mask(z_shape, patch_size):
"""Build temporal mask for I2V: first frame = 0, rest = 1.
Args:
z_shape: Latent shape (C, T, H, W) in channels-first
patch_size: (pt, ph, pw) patch size
Returns:
mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
"""
C, T, H, W = z_shape
mask = mx.ones(z_shape)
# Zero out the first temporal position
mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)
# Token-level mask for per-token timesteps: subsample to patch grid
# mask shape [C, T, H, W] → take first channel, subsample by patch_size
pt, ph, pw = patch_size
mask_tokens = mask[0, ::pt, ::ph, ::pw] # [T', H', W']
mask_tokens = mask_tokens.reshape(1, -1) # [1, L]
return mask, mask_tokens

View File

@@ -0,0 +1,154 @@
"""Wan model loading utilities."""
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
"""Load and initialize WanModel, with optional quantization support.
Args:
model_path: Path to model safetensors file
config: WanModelConfig
quantization: Optional dict with 'bits' and 'group_size' keys.
If provided, creates QuantizedLinear stubs before loading.
"""
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
if quantization:
from mlx_video.convert_wan import _quantize_predicate
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
weights = mx.load(str(model_path))
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
return model
def load_t5_encoder(model_path: Path, config):
"""Load T5 text encoder.
Weights are upcast to float32 for maximum precision — the T5 encoder
only runs once per generation, so performance impact is negligible.
This matches the official which computes softmax in float32 explicitly.
"""
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=config.t5_vocab_size,
dim=config.t5_dim,
dim_attn=config.t5_dim_attn,
dim_ffn=config.t5_dim_ffn,
num_heads=config.t5_num_heads,
num_layers=config.t5_num_layers,
num_buckets=config.t5_num_buckets,
shared_pos=False,
)
weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items()))
mx.eval(encoder.parameters())
return encoder
def load_vae_decoder(model_path: Path, config=None):
"""Load VAE decoder (skips encoder weights with strict=False).
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
For Wan2.1 (vae_z_dim=16), uses WanVAE.
"""
is_wan22 = config is not None and config.vae_z_dim == 48
if is_wan22:
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48)
else:
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path))
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
vae.load_weights(list(weights.items()), strict=False)
mx.eval(vae.parameters())
return vae
def load_vae_encoder(model_path: Path, config=None):
"""Load VAE encoder for I2V image encoding.
Only supports Wan2.2 (vae_z_dim=48).
"""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
encoder = Wan22VAEEncoder(z_dim=config.vae_z_dim)
weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items()), strict=False)
mx.eval(encoder.parameters())
return encoder
def _clean_text(text: str) -> str:
"""Clean text matching official Wan2.2 tokenizer preprocessing.
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
double HTML unescape, and whitespace normalization. Critical for
correct tokenization of the Chinese negative prompt.
"""
import html
import re
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text).strip()
return text
def encode_text(
encoder,
tokenizer,
prompt: str,
text_len: int = 512,
) -> mx.array:
"""Encode text prompt using T5 encoder.
Args:
encoder: T5Encoder model
tokenizer: HuggingFace tokenizer
prompt: Text prompt
text_len: Maximum text length
Returns:
Text embeddings [L, dim]
"""
prompt = _clean_text(prompt)
tokens = tokenizer(
prompt,
max_length=text_len,
padding="max_length",
truncation=True,
return_tensors="np",
)
ids = mx.array(tokens["input_ids"])
mask = mx.array(tokens["attention_mask"])
embeddings = encoder(ids, mask=mask)
# Return only non-padding tokens
seq_len = int(mask.sum().item())
return embeddings[0, :seq_len]

View File

@@ -15,17 +15,17 @@ def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
Args:
dim: Embedding dimension (must be even).
position: 1D tensor of positions.
position: Tensor of positions — 1D [L] or 2D [B, L].
Returns:
Embeddings of shape [len(position), dim].
Embeddings of shape [L, dim] or [B, L, dim].
"""
assert dim % 2 == 0
half = dim // 2
pos = position.astype(mx.float32)
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
sinusoid = pos[:, None] * inv_freq[None, :]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
sinusoid = pos[..., None] * inv_freq # [..., half]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
class Head(nn.Module):
@@ -44,16 +44,17 @@ class Head(nn.Module):
"""
Args:
x: [B, L, dim]
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
e_f32 = e.astype(mx.float32)
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
# modulation [1, 2, dim] broadcasts with e [B, 1/L, dim] via unsqueeze
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x).astype(mx.float32)
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1
return self.head(x_mod.astype(x.dtype))
@@ -261,18 +262,30 @@ class WanModel(nn.Module):
axis=0,
) # [B, seq_len, dim]
# Time embedding: compute once per sample, then broadcast to all tokens
# Time embedding
if t.ndim == 0:
t = t[None]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
model_dtype = self.patch_embedding_proj.weight.dtype
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
e = e.astype(model_dtype)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
# Keep e and e0 in float32 — official asserts float32 for modulation
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
e = e.astype(mx.float32)
else:
# I2V: per-token timesteps [B, L]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, L, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
# Keep e and e0 in float32 — official asserts float32 for modulation
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
e = e.astype(mx.float32)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):

View File

@@ -187,7 +187,7 @@ class FlowUniPCScheduler:
solver_order: int = 2,
lower_order_final: bool = True,
disable_corrector: list | None = None,
use_corrector: bool = False,
use_corrector: bool = True,
):
self.num_train_timesteps = num_train_timesteps
self.solver_order = solver_order

View File

@@ -49,9 +49,9 @@ class WanAttentionBlock(nn.Module):
context_lens: list | None = None,
cross_kv_cache: tuple | None = None,
) -> mx.array:
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
# Modulation in float32 (matching official torch.amp.autocast float32)
e_f32 = e.astype(mx.float32)
mod = self.modulation.astype(mx.float32) + e_f32
e0 = mod[:, :, 0, :] # shift for self-attn
e1 = mod[:, :, 1, :] # scale for self-attn
e2 = mod[:, :, 2, :] # gate for self-attn
@@ -59,19 +59,19 @@ class WanAttentionBlock(nn.Module):
e4 = mod[:, :, 4, :] # scale for ffn
e5 = mod[:, :, 5, :] # gate for ffn
# Self-attention with modulation
x_mod = self.norm1(x) * (1 + e1) + e0
# Self-attention with modulation (norm output in float32)
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
x = x + y * e2
x = x.astype(mx.float32) + y.astype(mx.float32) * e2
# Cross-attention (no modulation, just norm)
x_cross = self.norm3(x) if self.norm3 is not None else x
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
# FFN with modulation
x_mod = self.norm2(x) * (1 + e4) + e3
# FFN with modulation (norm output in float32)
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3
y = self.ffn(x_mod)
x = x + y * e5
x = x + y.astype(mx.float32) * e5
return x
@@ -86,4 +86,6 @@ class WanFFN(nn.Module):
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(self.fc1.weight.dtype)
return self.fc2(self.act(self.fc1(x_w)))

View File

@@ -53,7 +53,9 @@ class CausalConv3d(nn.Module):
self.kernel_size = kernel_size
self.stride = stride
self._causal_pad_t = 2 * padding[0]
# Causal temporal padding: always kernel_size-1 on the left.
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...).
self._causal_pad_t = kernel_size[0] - 1
self._pad_h = padding[1]
self._pad_w = padding[2]
@@ -250,6 +252,46 @@ class DupUp3D(nn.Module):
return x
class AvgDown3D(nn.Module):
"""Downsample by grouping channels across spatial/temporal factors and averaging.
Inverse of DupUp3D. No learnable parameters.
Input: [B, T, H, W, C_in] → Output: [B, T//ft, H//fs, W//fs, C_out]
"""
def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = factor_t * factor_s * factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def __call__(self, x):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
# Pad temporal if not divisible by factor_t
pad_t = (self.factor_t - T % self.factor_t) % self.factor_t
if pad_t > 0:
x = mx.pad(x, [(0, 0), (pad_t, 0), (0, 0), (0, 0), (0, 0)])
T = T + pad_t
ft, fs = self.factor_t, self.factor_s
# Reshape to split spatial/temporal dims
x = x.reshape(B, T // ft, ft, H // fs, fs, W // fs, fs, C)
# Move factors next to channels
x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6) # [B, T', H', W', C, ft, fs, fs]
# Expand channels
x = x.reshape(B, T // ft, H // fs, W // fs, C * self.factor)
# Group and average
x = x.reshape(B, T // ft, H // fs, W // fs, self.out_channels, self.group_size)
x = x.mean(axis=-1)
return x
class Resample(nn.Module):
"""Spatial up/downsampling with optional temporal up/downsampling."""
@@ -267,6 +309,15 @@ class Resample(nn.Module):
self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,))
elif mode == "downsample3d":
self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
raise ValueError(f"Unsupported mode: {mode}")
@@ -283,6 +334,12 @@ class Resample(nn.Module):
x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight) + self.resample_bias
def _downsample_conv2d(self, x):
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
def __call__(self, x, first_chunk=False):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
@@ -320,20 +377,37 @@ class Resample(nn.Module):
mx.eval(x)
T = x.shape[1]
# Spatial upsample in temporal chunks to limit peak memory
chunk_size = 8
chunks = []
for t_start in range(0, T, chunk_size):
t_end = min(t_start + chunk_size, T)
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
x_chunk = self._upsample2x(x_chunk)
x_chunk = self._conv2d(x_chunk)
mx.eval(x_chunk)
chunks.append(x_chunk)
if self.mode == "downsample3d" and T > 1:
# Temporal downsample via strided CausalConv3d
# Skip for T=1 (single frame) — matches official chunked encoding
# where first chunk stores cache but doesn't apply time_conv
x = self.time_conv(x)
mx.eval(x)
T = x.shape[1]
if self.mode in ("upsample2d", "upsample3d"):
# Spatial upsample in temporal chunks to limit peak memory
chunk_size = 8
chunks = []
for t_start in range(0, T, chunk_size):
t_end = min(t_start + chunk_size, T)
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
x_chunk = self._upsample2x(x_chunk)
x_chunk = self._conv2d(x_chunk)
mx.eval(x_chunk)
chunks.append(x_chunk)
x = mx.concatenate(chunks, axis=0)
H2, W2 = x.shape[1], x.shape[2]
x = x.reshape(B, T, H2, W2, C)
elif self.mode in ("downsample2d", "downsample3d"):
# Spatial downsample: per-frame strided Conv2d
x_flat = x.reshape(B * T, H, W, C)
x_flat = self._downsample_conv2d(x_flat)
mx.eval(x_flat)
H2, W2 = x_flat.shape[1], x_flat.shape[2]
x = x_flat.reshape(B, T, H2, W2, C)
x = mx.concatenate(chunks, axis=0)
H2, W2 = x.shape[1], x.shape[2]
x = x.reshape(B, T, H2, W2, C)
return x
@@ -383,6 +457,44 @@ class Up_ResidualBlock(nn.Module):
return x_main
class Down_ResidualBlock(nn.Module):
"""Downsampling residual block with AvgDown3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
super().__init__()
self.down_flag = down_flag
# AvgDown3D shortcut (no learnable params, always present)
self.avg_shortcut = AvgDown3D(
in_dim, out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path: ResidualBlocks + optional Resample
blocks = []
dim_in = in_dim
for _ in range(num_res_blocks):
blocks.append(ResidualBlock(dim_in, out_dim))
dim_in = out_dim
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
blocks.append(Resample(out_dim, mode=mode))
self.downsamples = blocks
def __call__(self, x):
x_shortcut = self.avg_shortcut(x)
mx.eval(x_shortcut)
for module in self.downsamples:
x = module(x)
mx.eval(x)
return x + x_shortcut
class Decoder3d(nn.Module):
"""Wan2.2 3D VAE Decoder."""
@@ -439,6 +551,63 @@ class Decoder3d(nn.Module):
return x
class Encoder3d(nn.Module):
"""Wan2.2 3D VAE Encoder. Mirror of Decoder3d with downsampling."""
def __init__(
self,
dim=160,
z_dim=96,
dim_mult=(1, 2, 4, 4),
num_res_blocks=2,
temperal_downsample=(False, True, True),
):
super().__init__()
# Channel dimensions: [160, 160, 320, 640, 640]
dims = [dim * m for m in [1] + list(dim_mult)]
# Initial conv: patchified input (12 ch) → first dim
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# Downsample blocks
self.downsamples = []
for i in range(len(dim_mult)):
in_d, out_d = dims[i], dims[i + 1]
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
self.downsamples.append(Down_ResidualBlock(
in_dim=in_d,
out_dim=out_d,
num_res_blocks=num_res_blocks,
temperal_downsample=t_down,
down_flag=(i < len(dim_mult) - 1),
))
# Middle blocks (same as decoder)
out_dim = dims[-1]
self.middle = [
ResidualBlock(out_dim, out_dim),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim),
]
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
self.head = Head22(out_dim, out_channels=z_dim)
def __call__(self, x):
# x: [B, T, H, W, 12] (patchified)
x = self.conv1(x)
for layer in self.downsamples:
x = layer(x)
for layer in self.middle:
x = layer(x)
mx.eval(x)
x = self.head(x)
return x
class Head22(nn.Module):
"""Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).
@@ -460,6 +629,46 @@ class Head22(nn.Module):
return x
class Wan22VAEEncoder(nn.Module):
"""Full Wan2.2 VAE encoder with patchify and normalization."""
def __init__(self, z_dim=48, dim=160):
super().__init__()
self.z_dim = z_dim
# conv1: top-level 1x1x1 conv after encoder (z_dim*2 → z_dim*2)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.encoder = Encoder3d(
dim=dim,
z_dim=z_dim * 2, # Encoder outputs z_dim*2, split into mu + log_var
dim_mult=(1, 2, 4, 4),
num_res_blocks=2,
temperal_downsample=(False, True, True),
)
def __call__(self, img):
"""Encode image/video to latent space.
Args:
img: [B, T, H, W, 3] image/video in [-1, 1]
Returns:
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
"""
# Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
x = _patchify(img, patch_size=2)
# Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2]
out = self.encoder(x)
# conv1 (pointwise) + split into mu, log_var
out = self.conv1(out)
mu = out[:, :, :, :, :self.z_dim]
# Normalize
mu = normalize_latents(mu)
return mu
class Wan22VAEDecoder(nn.Module):
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
@@ -507,6 +716,15 @@ def denormalize_latents(z, mean=None, std=None):
return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
def normalize_latents(z, mean=None, std=None):
"""Normalize latents: z_norm = (z - mean) / std. Inverse of denormalize_latents."""
if mean is None:
mean = VAE22_MEAN
if std is None:
std = VAE22_STD
return (z - mean.reshape(1, 1, 1, 1, -1)) / std.reshape(1, 1, 1, 1, -1)
def _unpatchify(x, patch_size=2):
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C//(p*p)]
Actually: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
@@ -527,10 +745,30 @@ def _unpatchify(x, patch_size=2):
return x
def sanitize_wan22_vae_weights(weights: dict) -> dict:
def _patchify(x, patch_size=2):
"""Convert spatial to packed channels: [B, T, H*p, W*p, C] → [B, T, H, W, C*p*p]
Inverse of _unpatchify.
PyTorch: b c f (h q) (w r) -> b (c r q) f h w
In channels-last: [B, T, H*q, W*r, C] → [B, T, H, W, C*r*q]
"""
if patch_size == 1:
return x
B, T, Hfull, Wfull, C = x.shape
H = Hfull // patch_size
W = Wfull // patch_size
# [B, T, H, q, W, r, C]
x = x.reshape(B, T, H, patch_size, W, patch_size, C)
# Rearrange to pack q,r into channels: [B, T, H, W, C, r, q]
x = x.transpose(0, 1, 2, 4, 6, 5, 3) # [B, T, H, W, C, r, q]
x = x.reshape(B, T, H, W, C * patch_size * patch_size)
return x
def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> dict:
"""Convert PyTorch Wan2.2 VAE weights to MLX format.
Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
By default keeps decoder + conv2 weights only. Set include_encoder=True
to also keep encoder + conv1 weights (needed for I2V encoding).
Transposes conv weights from channels-first to channels-last.
Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
Maps PyTorch nn.Sequential indices to our named layers.
@@ -538,9 +776,10 @@ def sanitize_wan22_vae_weights(weights: dict) -> dict:
sanitized = {}
for key, value in weights.items():
# Skip encoder and conv1 (encoder-only)
if key.startswith("encoder.") or key.startswith("conv1."):
continue
# Skip encoder and conv1 unless requested
if not include_encoder:
if key.startswith("encoder.") or key.startswith("conv1."):
continue
new_key = key