feat(wan): Add Wan2.2 I2V support
This commit is contained in:
@@ -71,8 +71,12 @@ class WanSelfAttention(nn.Module):
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = self.q.weight.dtype
|
||||
x_w = x.astype(w_dtype)
|
||||
|
||||
q = self.q(x_w)
|
||||
k = self.k(x_w)
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
if self.norm_k is not None:
|
||||
@@ -80,15 +84,15 @@ class WanSelfAttention(nn.Module):
|
||||
|
||||
q = q.reshape(b, s, n, d)
|
||||
k = k.reshape(b, s, n, d)
|
||||
v = self.v(x).reshape(b, s, n, d)
|
||||
v = self.v(x_w).reshape(b, s, n, d)
|
||||
|
||||
# Apply RoPE
|
||||
q = rope_apply(q, grid_sizes, freqs)
|
||||
k = rope_apply(k, grid_sizes, freqs)
|
||||
# RoPE in float32 for precision (official uses float64)
|
||||
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
|
||||
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
|
||||
|
||||
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
|
||||
q = q.transpose(0, 2, 1, 3)
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Build attention mask from seq_lens
|
||||
@@ -149,11 +153,14 @@ class WanCrossAttention(nn.Module):
|
||||
"""
|
||||
b = context.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
k = self.k(context)
|
||||
# Cast to weight dtype for efficient matmul
|
||||
w_dtype = self.k.weight.dtype
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
return k, v
|
||||
|
||||
def __call__(
|
||||
@@ -166,7 +173,9 @@ class WanCrossAttention(nn.Module):
|
||||
b = x.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x)
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = self.q.weight.dtype
|
||||
q = self.q(x.astype(w_dtype))
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
@@ -174,11 +183,12 @@ class WanCrossAttention(nn.Module):
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache
|
||||
else:
|
||||
k = self.k(context)
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
# Optional context masking
|
||||
mask = None
|
||||
|
||||
@@ -90,3 +90,24 @@ class WanModelConfig(BaseModelConfig):
|
||||
def wan22_t2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
|
||||
return cls()
|
||||
|
||||
@classmethod
def wan22_ti2v_5b(cls) -> "WanModelConfig":
    """Return the preset for Wan2.2 TI2V 5B (text+image to video).

    The preset overrides the class defaults with a single (non-dual) model of
    30 layers and dim=3072, a 48-channel latent space, and the sampling
    defaults listed below.
    """
    # Transformer backbone.
    backbone = dict(
        model_type="ti2v",
        dim=3072,
        ffn_dim=14336,
        in_dim=48,
        out_dim=48,
        num_heads=24,
        num_layers=30,
    )
    # Latent space produced/consumed by the VAE.
    latent = dict(
        vae_z_dim=48,
        vae_stride=(4, 16, 16),
    )
    # Single-model pipeline: dual-model switching is off and boundary is 0.0.
    pipeline = dict(
        dual_model=False,
        boundary=0.0,
    )
    # Default sampler settings.
    sampling = dict(
        sample_shift=5.0,
        sample_steps=50,
        sample_guide_scale=5.0,
        sample_fps=24,
    )
    return cls(**backbone, **latent, **pipeline, **sampling)
|
||||
|
||||
58
mlx_video/models/wan/i2v_utils.py
Normal file
58
mlx_video/models/wan/i2v_utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Image-to-Video utility functions for Wan2.2."""
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
    """Load an image and turn it into a normalized frame for I2V.

    The image is converted to RGB, resized with LANCZOS so that it covers the
    ``width`` x ``height`` target, center-cropped to exactly that size and
    rescaled from ``[0, 255]`` to ``[-1, 1]``.

    Args:
        image_path: Path to input image
        width: Target width
        height: Target height

    Returns:
        Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
    """
    from PIL import Image

    # Release the file handle as soon as the pixels are decoded: convert()
    # loads the data and returns an independent in-memory image, so nothing
    # below needs the file to stay open.
    with Image.open(image_path) as source:
        img = source.convert("RGB")

    # Resize so that the image covers the target size (LANCZOS). Using the
    # larger of the two ratios guarantees neither side ends up too small.
    scale = max(width / img.width, height / img.height)
    resized_size = (round(img.width * scale), round(img.height * scale))
    img = img.resize(resized_size, Image.LANCZOS)

    # Center crop to exactly (width, height).
    x1 = (img.width - width) // 2
    y1 = (img.height - height) // 2
    img = img.crop((x1, y1, x1 + width, y1 + height))

    # To tensor: [H, W, 3] float32, [0, 255] -> [0, 1] -> [-1, 1].
    arr = np.array(img, dtype=np.float32) / 255.0
    arr = arr * 2.0 - 1.0
    return mx.array(arr[None, None])  # [1, 1, H, W, 3]
|
||||
|
||||
|
||||
def build_i2v_mask(z_shape, patch_size):
    """Create the I2V temporal masks: 0 on the first frame, 1 everywhere else.

    Args:
        z_shape: Latent shape (C, T, H, W) in channels-first
        patch_size: (pt, ph, pw) patch size

    Returns:
        mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
        mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
    """
    channels, _, height, width = z_shape

    # Replace the first temporal slice of an all-ones mask with zeros.
    first_frame = mx.zeros((channels, 1, height, width))
    later_frames = mx.ones(z_shape)[:, 1:]
    mask = mx.concatenate([first_frame, later_frames], axis=1)

    # Token-level mask: take the first channel and keep one entry per patch
    # by striding each axis with the patch size, then flatten to [1, L].
    step_t, step_h, step_w = patch_size
    per_patch = mask[0, ::step_t, ::step_h, ::step_w]  # [T', H', W']
    return mask, per_patch.reshape(1, -1)
|
||||
154
mlx_video/models/wan/loading.py
Normal file
154
mlx_video/models/wan/loading.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Wan model loading utilities."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
    """Build a WanModel and populate it from a weights file.

    Args:
        model_path: Path to model safetensors file
        config: WanModelConfig
        quantization: Optional dict with 'bits' and 'group_size' keys.
            When given, the model is quantized *before* the weights are
            loaded so that the quantized layers exist to receive them.

    Returns:
        The WanModel with its parameters evaluated.
    """
    from mlx_video.models.wan.model import WanModel

    model = WanModel(config)

    if quantization:
        from mlx_video.convert_wan import _quantize_predicate

        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=_quantize_predicate,
        )

    # strict=False: keys that do not match between file and model are tolerated.
    state = mx.load(str(model_path))
    model.load_weights(list(state.items()), strict=False)
    mx.eval(model.parameters())
    return model
|
||||
|
||||
|
||||
def load_t5_encoder(model_path: Path, config):
    """Construct the T5 text encoder and load its weights as float32.

    Every tensor read from ``model_path`` is upcast to float32 before being
    loaded. The original note motivates this with precision: the encoder runs
    only once per generation, so the extra cost is considered negligible.

    Args:
        model_path: Path to the T5 weights file.
        config: Object providing the ``t5_*`` hyper-parameters read below.

    Returns:
        The T5Encoder with its parameters evaluated.
    """
    from mlx_video.models.wan.text_encoder import T5Encoder

    hyper_params = dict(
        vocab_size=config.t5_vocab_size,
        dim=config.t5_dim,
        dim_attn=config.t5_dim_attn,
        dim_ffn=config.t5_dim_ffn,
        num_heads=config.t5_num_heads,
        num_layers=config.t5_num_layers,
        num_buckets=config.t5_num_buckets,
        shared_pos=False,
    )
    encoder = T5Encoder(**hyper_params)

    # Upcast every tensor to float32 before loading (strict loading).
    fp32_weights = [
        (name, tensor.astype(mx.float32))
        for name, tensor in mx.load(str(model_path)).items()
    ]
    encoder.load_weights(fp32_weights)
    mx.eval(encoder.parameters())
    return encoder
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: Path, config=None):
    """Load the VAE decoder, picking the architecture from ``config``.

    A config with ``vae_z_dim == 48`` selects ``Wan22VAEDecoder``; any other
    config, or no config at all, selects ``WanVAE`` with ``z_dim=16``.
    Weights are loaded with ``strict=False`` so that keys not belonging to
    the chosen module (e.g. encoder weights) are skipped.

    Args:
        model_path: Path to the VAE weights file.
        config: Optional object exposing ``vae_z_dim``.

    Returns:
        The decoder module with its parameters evaluated.
    """
    if config is not None and config.vae_z_dim == 48:
        from mlx_video.models.wan.vae22 import Wan22VAEDecoder

        vae = Wan22VAEDecoder(z_dim=48)
    else:
        from mlx_video.models.wan.vae import WanVAE

        vae = WanVAE(z_dim=16)

    # Upcast VAE weights to float32 for quality (per the original note, the
    # official Wan2.2 runs its VAE in float32).
    fp32_weights = [
        (name, tensor.astype(mx.float32))
        for name, tensor in mx.load(str(model_path)).items()
    ]
    vae.load_weights(fp32_weights, strict=False)
    mx.eval(vae.parameters())
    return vae
|
||||
|
||||
|
||||
def load_vae_encoder(model_path: Path, config=None):
    """Load VAE encoder for I2V image encoding.

    Only supports Wan2.2 (vae_z_dim=48).

    Args:
        model_path: Path to the VAE weights file.
        config: Optional object exposing ``vae_z_dim``. When omitted, the
            Wan2.2 latent size of 48 is used, so the declared default
            ``config=None`` no longer fails on attribute access.

    Returns:
        The Wan22VAEEncoder with float32 weights and evaluated parameters.
    """
    from mlx_video.models.wan.vae22 import Wan22VAEEncoder

    # Fall back to the only supported latent size when no config is given.
    z_dim = 48 if config is None else config.vae_z_dim

    encoder = Wan22VAEEncoder(z_dim=z_dim)
    weights = mx.load(str(model_path))
    # Upcast to float32, the same treatment the VAE decoder weights receive.
    weights = {k: v.astype(mx.float32) for k, v in weights.items()}
    # strict=False: keys that do not belong to the encoder are skipped.
    encoder.load_weights(list(weights.items()), strict=False)
    mx.eval(encoder.parameters())
    return encoder
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
    """Normalize a prompt before tokenization.

    Steps, in order: ``ftfy.fix_text`` when ftfy is installed (silently
    skipped otherwise), two passes of HTML entity unescaping, then collapsing
    every whitespace run to a single space and trimming both ends. The
    original note calls this critical for correct tokenization of the Chinese
    negative prompt.
    """
    import html
    import re

    # ftfy is optional: without it the text goes through unchanged.
    try:
        import ftfy as _ftfy

        text = _ftfy.fix_text(text)
    except ImportError:
        pass

    # Two passes so that doubly-escaped entities are fully decoded.
    for _ in range(2):
        text = html.unescape(text)

    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
|
||||
|
||||
|
||||
def encode_text(
    encoder,
    tokenizer,
    prompt: str,
    text_len: int = 512,
) -> mx.array:
    """Run a prompt through the tokenizer and the T5 encoder.

    Args:
        encoder: T5Encoder model
        tokenizer: HuggingFace tokenizer
        prompt: Text prompt
        text_len: Maximum text length (the prompt is padded/truncated to it)

    Returns:
        Text embeddings [L, dim], where L is the number of non-padding tokens
    """
    cleaned = _clean_text(prompt)
    encoded = tokenizer(
        cleaned,
        max_length=text_len,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    input_ids = mx.array(encoded["input_ids"])
    attention = mx.array(encoded["attention_mask"])

    hidden = encoder(input_ids, mask=attention)

    # The attention mask sums to the count of real tokens; keep only that
    # prefix of the first (only) batch element.
    valid_len = int(attention.sum().item())
    return hidden[0, :valid_len]
|
||||
@@ -15,17 +15,17 @@ def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
|
||||
|
||||
Args:
|
||||
dim: Embedding dimension (must be even).
|
||||
position: 1D tensor of positions.
|
||||
position: Tensor of positions — 1D [L] or 2D [B, L].
|
||||
|
||||
Returns:
|
||||
Embeddings of shape [len(position), dim].
|
||||
Embeddings of shape [L, dim] or [B, L, dim].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
pos = position.astype(mx.float32)
|
||||
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
|
||||
sinusoid = pos[:, None] * inv_freq[None, :]
|
||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
|
||||
sinusoid = pos[..., None] * inv_freq # [..., half]
|
||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
@@ -44,16 +44,17 @@ class Head(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
x: [B, L, dim]
|
||||
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
|
||||
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
e_f32 = e.astype(mx.float32)
|
||||
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
|
||||
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
|
||||
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
|
||||
# modulation [1, 2, dim] broadcasts with e [B, 1/L, dim] via unsqueeze
|
||||
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x).astype(mx.float32)
|
||||
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
|
||||
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1
|
||||
return self.head(x_mod.astype(x.dtype))
|
||||
|
||||
|
||||
@@ -261,18 +262,30 @@ class WanModel(nn.Module):
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding: compute once per sample, then broadcast to all tokens
|
||||
# Time embedding
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
||||
|
||||
model_dtype = self.patch_embedding_proj.weight.dtype
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
|
||||
e = e.astype(model_dtype)
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||
# Keep e and e0 in float32 — official asserts float32 for modulation
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
|
||||
e = e.astype(mx.float32)
|
||||
else:
|
||||
# I2V: per-token timesteps [B, L]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, L, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
|
||||
# Keep e and e0 in float32 — official asserts float32 for modulation
|
||||
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
|
||||
e = e.astype(mx.float32)
|
||||
|
||||
# Text embedding: skip MLP if context is already embedded (mx.array)
|
||||
if isinstance(context, mx.array):
|
||||
|
||||
@@ -187,7 +187,7 @@ class FlowUniPCScheduler:
|
||||
solver_order: int = 2,
|
||||
lower_order_final: bool = True,
|
||||
disable_corrector: list | None = None,
|
||||
use_corrector: bool = False,
|
||||
use_corrector: bool = True,
|
||||
):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.solver_order = solver_order
|
||||
|
||||
@@ -49,9 +49,9 @@ class WanAttentionBlock(nn.Module):
|
||||
context_lens: list | None = None,
|
||||
cross_kv_cache: tuple | None = None,
|
||||
) -> mx.array:
|
||||
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
|
||||
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
|
||||
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
|
||||
# Modulation in float32 (matching official torch.amp.autocast float32)
|
||||
e_f32 = e.astype(mx.float32)
|
||||
mod = self.modulation.astype(mx.float32) + e_f32
|
||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||
@@ -59,19 +59,19 @@ class WanAttentionBlock(nn.Module):
|
||||
e4 = mod[:, :, 4, :] # scale for ffn
|
||||
e5 = mod[:, :, 5, :] # gate for ffn
|
||||
|
||||
# Self-attention with modulation
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
# Self-attention with modulation (norm output in float32)
|
||||
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
|
||||
x = x + y * e2
|
||||
x = x.astype(mx.float32) + y.astype(mx.float32) * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||
|
||||
# FFN with modulation
|
||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
||||
# FFN with modulation (norm output in float32)
|
||||
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3
|
||||
y = self.ffn(x_mod)
|
||||
x = x + y * e5
|
||||
x = x + y.astype(mx.float32) * e5
|
||||
|
||||
return x
|
||||
|
||||
@@ -86,4 +86,6 @@ class WanFFN(nn.Module):
|
||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
x_w = x.astype(self.fc1.weight.dtype)
|
||||
return self.fc2(self.act(self.fc1(x_w)))
|
||||
|
||||
@@ -53,7 +53,9 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self._causal_pad_t = 2 * padding[0]
|
||||
# Causal temporal padding: always kernel_size-1 on the left.
|
||||
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...).
|
||||
self._causal_pad_t = kernel_size[0] - 1
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
@@ -250,6 +252,46 @@ class DupUp3D(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class AvgDown3D(nn.Module):
    """Parameter-free downsampling: fold space/time factors into channels, then average.

    Described by the original author as the inverse of DupUp3D.
    Input: [B, T, H, W, C_in] → Output: [B, T//ft, H//fs, W//fs, C_out]
    """

    def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        # Number of input positions folded into each output position.
        self.factor = factor_t * factor_s * factor_s
        folded_channels = in_channels * self.factor
        assert folded_channels % out_channels == 0
        # How many folded channels are averaged into one output channel.
        self.group_size = folded_channels // out_channels

    def __call__(self, x):
        # x: [B, T, H, W, C]
        batch, frames, height, width, channels = x.shape
        ft, fs = self.factor_t, self.factor_s

        # Left-pad the temporal axis with zeros up to a multiple of ft.
        missing = -frames % ft
        if missing > 0:
            x = mx.pad(x, [(0, 0), (missing, 0), (0, 0), (0, 0), (0, 0)])
            frames += missing

        t_out, h_out, w_out = frames // ft, height // fs, width // fs

        # Split each downsampled axis into (coarse, factor) ...
        x = x.reshape(batch, t_out, ft, h_out, fs, w_out, fs, channels)
        # ... and move the factors behind the channels:
        # [B, T', H', W', C, ft, fs, fs]
        x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6)
        # Flatten (C, ft, fs, fs) and regroup as (C_out, group_size) in one step.
        x = x.reshape(batch, t_out, h_out, w_out, self.out_channels, self.group_size)
        return x.mean(axis=-1)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
"""Spatial up/downsampling with optional temporal up/downsampling."""
|
||||
|
||||
@@ -267,6 +309,15 @@ class Resample(nn.Module):
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
# time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
elif mode == "downsample2d":
|
||||
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
|
||||
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
elif mode == "downsample3d":
|
||||
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
|
||||
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
else:
|
||||
raise ValueError(f"Unsupported mode: {mode}")
|
||||
|
||||
@@ -283,6 +334,12 @@ class Resample(nn.Module):
|
||||
x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
|
||||
return mx.conv_general(x, self.resample_weight) + self.resample_bias
|
||||
|
||||
def _downsample_conv2d(self, x):
    """Downsample with a stride-2 Conv2d. x: [N, H, W, C]."""
    # Equivalent of ZeroPad2d((0,1,0,1)): one zero column on the right and
    # one zero row at the bottom; batch and channel axes are untouched.
    padded = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
    convolved = mx.conv_general(padded, self.resample_weight, stride=(2, 2))
    return convolved + self.resample_bias
|
||||
|
||||
def __call__(self, x, first_chunk=False):
|
||||
# x: [B, T, H, W, C]
|
||||
B, T, H, W, C = x.shape
|
||||
@@ -320,20 +377,37 @@ class Resample(nn.Module):
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
# Spatial upsample in temporal chunks to limit peak memory
|
||||
chunk_size = 8
|
||||
chunks = []
|
||||
for t_start in range(0, T, chunk_size):
|
||||
t_end = min(t_start + chunk_size, T)
|
||||
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
|
||||
x_chunk = self._upsample2x(x_chunk)
|
||||
x_chunk = self._conv2d(x_chunk)
|
||||
mx.eval(x_chunk)
|
||||
chunks.append(x_chunk)
|
||||
if self.mode == "downsample3d" and T > 1:
|
||||
# Temporal downsample via strided CausalConv3d
|
||||
# Skip for T=1 (single frame) — matches official chunked encoding
|
||||
# where first chunk stores cache but doesn't apply time_conv
|
||||
x = self.time_conv(x)
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
if self.mode in ("upsample2d", "upsample3d"):
|
||||
# Spatial upsample in temporal chunks to limit peak memory
|
||||
chunk_size = 8
|
||||
chunks = []
|
||||
for t_start in range(0, T, chunk_size):
|
||||
t_end = min(t_start + chunk_size, T)
|
||||
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
|
||||
x_chunk = self._upsample2x(x_chunk)
|
||||
x_chunk = self._conv2d(x_chunk)
|
||||
mx.eval(x_chunk)
|
||||
chunks.append(x_chunk)
|
||||
|
||||
x = mx.concatenate(chunks, axis=0)
|
||||
H2, W2 = x.shape[1], x.shape[2]
|
||||
x = x.reshape(B, T, H2, W2, C)
|
||||
elif self.mode in ("downsample2d", "downsample3d"):
|
||||
# Spatial downsample: per-frame strided Conv2d
|
||||
x_flat = x.reshape(B * T, H, W, C)
|
||||
x_flat = self._downsample_conv2d(x_flat)
|
||||
mx.eval(x_flat)
|
||||
H2, W2 = x_flat.shape[1], x_flat.shape[2]
|
||||
x = x_flat.reshape(B, T, H2, W2, C)
|
||||
|
||||
x = mx.concatenate(chunks, axis=0)
|
||||
H2, W2 = x.shape[1], x.shape[2]
|
||||
x = x.reshape(B, T, H2, W2, C)
|
||||
return x
|
||||
|
||||
|
||||
@@ -383,6 +457,44 @@ class Up_ResidualBlock(nn.Module):
|
||||
return x_main
|
||||
|
||||
|
||||
class Down_ResidualBlock(nn.Module):
    """Residual downsampling stage whose skip path is an AvgDown3D."""

    def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
        super().__init__()
        self.down_flag = down_flag

        # Skip path: parameter-free and always present. It shrinks time by 2
        # only when temperal_downsample is set and space by 2 only when
        # down_flag is set.
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main path: the first ResidualBlock maps in_dim -> out_dim, the
        # remaining ones keep out_dim.
        stages = [
            ResidualBlock(in_dim if idx == 0 else out_dim, out_dim)
            for idx in range(num_res_blocks)
        ]
        # Optional trailing Resample, 3D when time is downsampled as well.
        if down_flag:
            resample_mode = "downsample3d" if temperal_downsample else "downsample2d"
            stages.append(Resample(out_dim, mode=resample_mode))

        self.downsamples = stages

    def __call__(self, x):
        # Evaluate the skip path before running the main path.
        residual = self.avg_shortcut(x)
        mx.eval(residual)

        for stage in self.downsamples:
            x = stage(x)
            mx.eval(x)

        return x + residual
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
"""Wan2.2 3D VAE Decoder."""
|
||||
|
||||
@@ -439,6 +551,63 @@ class Decoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
    """Wan2.2 3D VAE Encoder: the downsampling counterpart of Decoder3d."""

    def __init__(
        self,
        dim=160,
        z_dim=96,
        dim_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        temperal_downsample=(False, True, True),
    ):
        super().__init__()
        # Channel widths per stage; with the defaults: [160, 160, 320, 640, 640]
        dims = [dim, *(dim * mult for mult in dim_mult)]

        # Stem: patchified input (12 channels) -> first width
        self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)

        # One Down_ResidualBlock per entry of dim_mult. Every stage except the
        # last one downsamples; a stage without a matching temperal_downsample
        # entry does no temporal downsampling.
        final_stage = len(dim_mult) - 1
        stages = []
        for idx, (in_d, out_d) in enumerate(zip(dims[:-1], dims[1:])):
            t_down = temperal_downsample[idx] if idx < len(temperal_downsample) else False
            stages.append(
                Down_ResidualBlock(
                    in_dim=in_d,
                    out_dim=out_d,
                    num_res_blocks=num_res_blocks,
                    temperal_downsample=t_down,
                    down_flag=(idx < final_stage),
                )
            )
        self.downsamples = stages

        # Bottleneck: residual -> attention -> residual at the final width
        bottleneck_dim = dims[-1]
        self.middle = [
            ResidualBlock(bottleneck_dim, bottleneck_dim),
            AttentionBlock(bottleneck_dim),
            ResidualBlock(bottleneck_dim, bottleneck_dim),
        ]

        # Output head producing z_dim channels
        self.head = Head22(bottleneck_dim, out_channels=z_dim)

    def __call__(self, x):
        # x: [B, T, H, W, 12] (patchified)
        x = self.conv1(x)

        for stage in self.downsamples:
            x = stage(x)

        # Evaluate after each bottleneck layer.
        for block in self.middle:
            x = block(x)
            mx.eval(x)

        return self.head(x)
|
||||
|
||||
|
||||
class Head22(nn.Module):
|
||||
"""Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).
|
||||
|
||||
@@ -460,6 +629,46 @@ class Head22(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class Wan22VAEEncoder(nn.Module):
    """Wan2.2 VAE encoder: patchify → Encoder3d → 1x1x1 conv → normalized mean."""

    def __init__(self, z_dim=48, dim=160):
        super().__init__()
        self.z_dim = z_dim
        # Top-level pointwise conv applied after the encoder (z_dim*2 → z_dim*2).
        doubled = z_dim * 2
        self.conv1 = CausalConv3d(doubled, doubled, 1)
        # The encoder emits z_dim*2 channels, later split into mu + log_var.
        self.encoder = Encoder3d(
            dim=dim,
            z_dim=doubled,
            dim_mult=(1, 2, 4, 4),
            num_res_blocks=2,
            temperal_downsample=(False, True, True),
        )

    def __call__(self, img):
        """Encode image/video to latent space.

        Args:
            img: [B, T, H, W, 3] image/video in [-1, 1]

        Returns:
            mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
        """
        # [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
        patches = _patchify(img, patch_size=2)

        # [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2], then pointwise conv
        moments = self.conv1(self.encoder(patches))

        # Keep only the first z_dim channels (mu); log_var is discarded.
        mu = moments[:, :, :, :, : self.z_dim]
        return normalize_latents(mu)
|
||||
|
||||
|
||||
class Wan22VAEDecoder(nn.Module):
|
||||
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
|
||||
|
||||
@@ -507,6 +716,15 @@ def denormalize_latents(z, mean=None, std=None):
|
||||
return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
|
||||
|
||||
|
||||
def normalize_latents(z, mean=None, std=None):
    """Return ``(z - mean) / std`` with the statistics broadcast over the last axis.

    ``mean`` and ``std`` default to VAE22_MEAN and VAE22_STD. Counterpart of
    denormalize_latents.
    """
    mean = VAE22_MEAN if mean is None else mean
    std = VAE22_STD if std is None else std
    # Per-channel statistics line up with the trailing (channel) axis of z.
    per_channel = (1, 1, 1, 1, -1)
    centered = z - mean.reshape(*per_channel)
    return centered / std.reshape(*per_channel)
|
||||
|
||||
|
||||
def _unpatchify(x, patch_size=2):
|
||||
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C//(p*p)]
|
||||
Actually: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
|
||||
@@ -527,10 +745,30 @@ def _unpatchify(x, patch_size=2):
|
||||
return x
|
||||
|
||||
|
||||
def sanitize_wan22_vae_weights(weights: dict) -> dict:
|
||||
def _patchify(x, patch_size=2):
    """Pack spatial patches into channels: [B, T, H*p, W*p, C] → [B, T, H, W, C*p*p]

    Inverse of _unpatchify.
    PyTorch: b c f (h q) (w r) -> b (c r q) f h w
    In channels-last: [B, T, H*q, W*r, C] → [B, T, H, W, C*r*q]
    """
    # A patch size of 1 leaves the input untouched (same object is returned).
    if patch_size == 1:
        return x

    p = patch_size
    batch, frames, full_h, full_w, channels = x.shape
    rows, cols = full_h // p, full_w // p

    # Split both spatial axes: [B, T, H, q, W, r, C]
    split = x.reshape(batch, frames, rows, p, cols, p, channels)
    # Move the intra-patch offsets behind the channels, r before q:
    # [B, T, H, W, C, r, q]
    packed = split.transpose(0, 1, 2, 4, 6, 5, 3)
    return packed.reshape(batch, frames, rows, cols, channels * p * p)
|
||||
|
||||
|
||||
def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> dict:
|
||||
"""Convert PyTorch Wan2.2 VAE weights to MLX format.
|
||||
|
||||
Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
|
||||
By default keeps decoder + conv2 weights only. Set include_encoder=True
|
||||
to also keep encoder + conv1 weights (needed for I2V encoding).
|
||||
Transposes conv weights from channels-first to channels-last.
|
||||
Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
|
||||
Maps PyTorch nn.Sequential indices to our named layers.
|
||||
@@ -538,9 +776,10 @@ def sanitize_wan22_vae_weights(weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
# Skip encoder and conv1 (encoder-only)
|
||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||
continue
|
||||
# Skip encoder and conv1 unless requested
|
||||
if not include_encoder:
|
||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||
continue
|
||||
|
||||
new_key = key
|
||||
|
||||
|
||||
Reference in New Issue
Block a user