feat(wan): Add Wan2.1/2.2 T2V with quantization support
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
|
||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||
from mlx_video.models.wan import WanModel, WanModelConfig
|
||||
|
||||
2
mlx_video/models/wan/__init__.py
Normal file
2
mlx_video/models/wan/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
201
mlx_video/models/wan/attention.py
Normal file
201
mlx_video/models/wan/attention.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .rope import rope_apply
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
"""RMS normalization with learnable scale."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class WanLayerNorm(nn.Module):
|
||||
"""LayerNorm computed in float32, with optional affine."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if elementwise_affine:
|
||||
self.weight = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.elementwise_affine:
|
||||
return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
|
||||
else:
|
||||
return mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
|
||||
"""Self-attention with QK normalization and 3-way factorized RoPE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: tuple = (-1, -1),
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
) -> mx.array:
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
|
||||
q = q.reshape(b, s, n, d)
|
||||
k = k.reshape(b, s, n, d)
|
||||
v = self.v(x).reshape(b, s, n, d)
|
||||
|
||||
# Apply RoPE
|
||||
q = rope_apply(q, grid_sizes, freqs)
|
||||
k = rope_apply(k, grid_sizes, freqs)
|
||||
|
||||
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
|
||||
q = q.transpose(0, 2, 1, 3)
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Build attention mask from seq_lens
|
||||
max_len = s
|
||||
mask = None
|
||||
if any(sl < max_len for sl in seq_lens):
|
||||
mask = mx.zeros((b, 1, 1, max_len), dtype=q.dtype)
|
||||
for i, sl in enumerate(seq_lens):
|
||||
mask[i, :, :, sl:] = -1e9
|
||||
|
||||
# Use memory-efficient scaled dot-product attention
|
||||
# mx.fast.scaled_dot_product_attention expects [B, N, L, D]
|
||||
if mask is not None:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale
|
||||
)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class WanCrossAttention(nn.Module):
|
||||
"""Cross-attention: Q from hidden states, K/V from text context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
|
||||
def prepare_kv(self, context: mx.array) -> tuple:
|
||||
"""Pre-compute K and V projections for caching.
|
||||
|
||||
Args:
|
||||
context: [B, L_ctx, dim]
|
||||
|
||||
Returns:
|
||||
(k, v) each [B, N, L_ctx, D] ready for attention
|
||||
"""
|
||||
b = context.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
k = self.k(context)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
return k, v
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
kv_cache: tuple | None = None,
|
||||
) -> mx.array:
|
||||
b = x.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x)
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache
|
||||
else:
|
||||
k = self.k(context)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
# Optional context masking
|
||||
mask = None
|
||||
if context_lens is not None:
|
||||
ctx_len = k.shape[2]
|
||||
mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
|
||||
for i, cl in enumerate(context_lens):
|
||||
mask[i, :, :, cl:] = -1e9
|
||||
|
||||
if mask is not None:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale
|
||||
)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
|
||||
return self.o(out)
|
||||
86
mlx_video/models/wan/config.py
Normal file
86
mlx_video/models/wan/config.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from mlx_video.models.ltx.config import BaseModelConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class WanModelConfig(BaseModelConfig):
|
||||
"""Configuration for Wan T2V models (supports both 2.1 and 2.2)."""
|
||||
|
||||
model_type: str = "t2v"
|
||||
model_version: str = "2.2"
|
||||
patch_size: Tuple[int, int, int] = (1, 2, 2)
|
||||
text_len: int = 512
|
||||
in_dim: int = 16
|
||||
dim: int = 5120
|
||||
ffn_dim: int = 13824
|
||||
freq_dim: int = 256
|
||||
text_dim: int = 4096
|
||||
out_dim: int = 16
|
||||
num_heads: int = 40
|
||||
num_layers: int = 40
|
||||
window_size: Tuple[int, int] = (-1, -1)
|
||||
qk_norm: bool = True
|
||||
cross_attn_norm: bool = True
|
||||
eps: float = 1e-6
|
||||
|
||||
# VAE
|
||||
vae_stride: Tuple[int, int, int] = (4, 8, 8)
|
||||
vae_z_dim: int = 16
|
||||
|
||||
# Inference
|
||||
dual_model: bool = True
|
||||
boundary: float = 0.875
|
||||
sample_shift: float = 12.0
|
||||
sample_steps: int = 40
|
||||
sample_guide_scale: Union[float, Tuple[float, float]] = (3.0, 4.0)
|
||||
num_train_timesteps: int = 1000
|
||||
sample_fps: int = 16
|
||||
frame_num: int = 81
|
||||
|
||||
# T5
|
||||
t5_vocab_size: int = 256384
|
||||
t5_dim: int = 4096
|
||||
t5_dim_attn: int = 4096
|
||||
t5_dim_ffn: int = 10240
|
||||
t5_num_heads: int = 64
|
||||
t5_num_layers: int = 24
|
||||
t5_num_buckets: int = 32
|
||||
|
||||
@property
|
||||
def head_dim(self) -> int:
|
||||
return self.dim // self.num_heads
|
||||
|
||||
@classmethod
|
||||
def wan21_t2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.1 T2V 14B: single model, 40 layers, dim=5120."""
|
||||
return cls(
|
||||
model_version="2.1",
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wan21_t2v_1_3b(cls) -> "WanModelConfig":
|
||||
"""Wan2.1 T2V 1.3B: single model, 30 layers, dim=1536."""
|
||||
return cls(
|
||||
model_version="2.1",
|
||||
dim=1536,
|
||||
ffn_dim=8960,
|
||||
num_heads=12,
|
||||
num_layers=30,
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def wan22_t2v_14b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
|
||||
return cls()
|
||||
307
mlx_video/models/wan/model.py
Normal file
307
mlx_video/models/wan/model.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .attention import WanLayerNorm
|
||||
from .config import WanModelConfig
|
||||
from .rope import rope_params
|
||||
from .transformer import WanAttentionBlock
|
||||
|
||||
|
||||
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
|
||||
"""Compute sinusoidal positional embeddings.
|
||||
|
||||
Args:
|
||||
dim: Embedding dimension (must be even).
|
||||
position: 1D tensor of positions.
|
||||
|
||||
Returns:
|
||||
Embeddings of shape [len(position), dim].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
pos = position.astype(mx.float32)
|
||||
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
|
||||
sinusoid = pos[:, None] * inv_freq[None, :]
|
||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
"""Output projection head with learned modulation."""
|
||||
|
||||
def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
proj_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, proj_dim)
|
||||
self.modulation = mx.random.normal((1, 2, dim)) * (dim**-0.5)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
x: [B, L, dim]
|
||||
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
e_f32 = e.astype(mx.float32)
|
||||
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
|
||||
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
|
||||
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
|
||||
x_norm = self.norm(x).astype(mx.float32)
|
||||
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
|
||||
return self.head(x_mod.astype(x.dtype))
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
"""Wan2.2 diffusion backbone for text-to-video generation."""
|
||||
|
||||
def __init__(self, config: WanModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
dim = config.dim
|
||||
self.dim = dim
|
||||
self.num_heads = config.num_heads
|
||||
self.out_dim = config.out_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.text_len = config.text_len
|
||||
self.freq_dim = config.freq_dim
|
||||
|
||||
# Patch embedding: Conv3d implemented as a reshaped linear
|
||||
# For kernel (1,2,2) and stride (1,2,2): reshape input then linear
|
||||
patch_dim = config.in_dim * math.prod(config.patch_size)
|
||||
self.patch_embedding_proj = nn.Linear(patch_dim, dim)
|
||||
self._patch_size = config.patch_size
|
||||
|
||||
# Text embedding MLP
|
||||
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
|
||||
self.text_embedding_act = nn.GELU(approx="precise")
|
||||
self.text_embedding_1 = nn.Linear(dim, dim)
|
||||
|
||||
# Time embedding MLP
|
||||
self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
|
||||
self.time_embedding_act = nn.SiLU()
|
||||
self.time_embedding_1 = nn.Linear(dim, dim)
|
||||
|
||||
# Time projection for modulation (6x dim)
|
||||
self.time_projection_act = nn.SiLU()
|
||||
self.time_projection = nn.Linear(dim, dim * 6)
|
||||
|
||||
# Transformer blocks
|
||||
self.blocks = [
|
||||
WanAttentionBlock(
|
||||
dim=dim,
|
||||
ffn_dim=config.ffn_dim,
|
||||
num_heads=config.num_heads,
|
||||
window_size=config.window_size,
|
||||
qk_norm=config.qk_norm,
|
||||
cross_attn_norm=config.cross_attn_norm,
|
||||
eps=config.eps,
|
||||
)
|
||||
for _ in range(config.num_layers)
|
||||
]
|
||||
|
||||
# Output head
|
||||
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
|
||||
|
||||
# Precompute RoPE frequencies
|
||||
d = dim // config.num_heads
|
||||
d_t = d - 4 * (d // 6)
|
||||
d_h = 2 * (d // 6)
|
||||
d_w = 2 * (d // 6)
|
||||
# Each rope_params returns [1024, d_x//2, 2]
|
||||
freqs_t = rope_params(1024, d_t)
|
||||
freqs_h = rope_params(1024, d_h)
|
||||
freqs_w = rope_params(1024, d_w)
|
||||
# Concatenate along the frequency dimension: [1024, d//2, 2]
|
||||
self.freqs = mx.concatenate([freqs_t, freqs_h, freqs_w], axis=1)
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
"""Convert video tensor to patch embeddings.
|
||||
|
||||
Args:
|
||||
x: Video latent [C, F, H, W]
|
||||
|
||||
Returns:
|
||||
(patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
|
||||
"""
|
||||
c, f, h, w = x.shape
|
||||
pt, ph, pw = self._patch_size
|
||||
|
||||
f_out = f // pt
|
||||
h_out = h // ph
|
||||
w_out = w // pw
|
||||
|
||||
# Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
|
||||
# Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
|
||||
x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
|
||||
x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw]
|
||||
x = x.reshape(f_out * h_out * w_out, -1) # [L, C*pt*ph*pw]
|
||||
|
||||
# Project and cast to model dtype to prevent float32 cascade from input latents
|
||||
patches = self.patch_embedding_proj(x) # [L, dim]
|
||||
patches = patches.astype(self.patch_embedding_proj.weight.dtype)
|
||||
patches = patches[None, :, :] # [1, L, dim]
|
||||
|
||||
return patches, (f_out, h_out, w_out)
|
||||
|
||||
def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
|
||||
"""Reconstruct video from patch embeddings.
|
||||
|
||||
Args:
|
||||
x: [B, L, out_dim * prod(patch_size)]
|
||||
grid_sizes: List of (F', H', W') per batch element
|
||||
|
||||
Returns:
|
||||
List of tensors [C, F, H, W]
|
||||
"""
|
||||
c = self.out_dim
|
||||
pt, ph, pw = self.patch_size
|
||||
out = []
|
||||
for i, (f, h, w) in enumerate(grid_sizes):
|
||||
seq_len = f * h * w
|
||||
u = x[i, :seq_len] # [L, out_dim * pt * ph * pw]
|
||||
u = u.reshape(f, h, w, pt, ph, pw, c)
|
||||
# Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
|
||||
u = u.transpose(6, 0, 3, 1, 4, 2, 5) # [C, F', pt, H', ph, W', pw]
|
||||
u = u.reshape(c, f * pt, h * ph, w * pw)
|
||||
out.append(u)
|
||||
return out
|
||||
|
||||
def embed_text(self, context: list) -> mx.array:
|
||||
"""Precompute text embeddings (call once, reuse across steps).
|
||||
|
||||
Args:
|
||||
context: List of text embeddings [L_text, text_dim]
|
||||
|
||||
Returns:
|
||||
Embedded context [B, text_len, dim] in model dtype
|
||||
"""
|
||||
model_dtype = self.patch_embedding_proj.weight.dtype
|
||||
context_padded = []
|
||||
for ctx in context:
|
||||
pad_len = self.text_len - ctx.shape[0]
|
||||
if pad_len > 0:
|
||||
ctx = mx.concatenate(
|
||||
[ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
|
||||
axis=0,
|
||||
)
|
||||
context_padded.append(ctx)
|
||||
context_batch = mx.stack(context_padded) # [B, text_len, text_dim]
|
||||
context_batch = self.text_embedding_1(
|
||||
self.text_embedding_act(self.text_embedding_0(context_batch))
|
||||
)
|
||||
return context_batch.astype(model_dtype)
|
||||
|
||||
def prepare_cross_kv(self, context: mx.array) -> list:
|
||||
"""Pre-compute cross-attention K/V for all blocks.
|
||||
|
||||
Call once before the diffusion loop to cache K/V projections,
|
||||
eliminating redundant computation at each denoising step.
|
||||
|
||||
Args:
|
||||
context: Pre-embedded text [B, text_len, dim]
|
||||
|
||||
Returns:
|
||||
List of (k, v) tuples, one per block
|
||||
"""
|
||||
kv_caches = []
|
||||
for block in self.blocks:
|
||||
kv_caches.append(block.cross_attn.prepare_kv(context))
|
||||
return kv_caches
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x_list: list,
|
||||
t: mx.array,
|
||||
context: list | mx.array,
|
||||
seq_len: int,
|
||||
cross_kv_caches: list | None = None,
|
||||
) -> list:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x_list: List of video latent tensors [C, F, H, W]
|
||||
t: Timestep tensor [B]
|
||||
context: List of raw text embeddings, OR pre-embedded tensor
|
||||
from embed_text() [B, text_len, dim]
|
||||
seq_len: Maximum sequence length for padding
|
||||
cross_kv_caches: Optional list of (k, v) tuples from
|
||||
prepare_cross_kv(), one per block.
|
||||
|
||||
Returns:
|
||||
List of denoised tensors [C, F, H, W]
|
||||
"""
|
||||
# Patchify each video
|
||||
patches = []
|
||||
grid_sizes = []
|
||||
seq_lens_list = []
|
||||
for vid in x_list:
|
||||
p, gs = self._patchify(vid) # [1, L, dim]
|
||||
patches.append(p)
|
||||
grid_sizes.append(gs)
|
||||
seq_lens_list.append(p.shape[1])
|
||||
|
||||
# Pad and batch
|
||||
batch_size = len(patches)
|
||||
x = mx.concatenate(
|
||||
[
|
||||
mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding: compute once per sample, then broadcast to all tokens
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
||||
|
||||
model_dtype = self.patch_embedding_proj.weight.dtype
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
|
||||
e = e.astype(model_dtype)
|
||||
|
||||
# Text embedding: skip MLP if context is already embedded (mx.array)
|
||||
if isinstance(context, mx.array):
|
||||
# Pre-embedded: expand to batch size if needed
|
||||
context_batch = context
|
||||
if context_batch.shape[0] == 1 and batch_size > 1:
|
||||
context_batch = mx.broadcast_to(
|
||||
context_batch, (batch_size,) + context_batch.shape[1:]
|
||||
)
|
||||
else:
|
||||
context_batch = self.embed_text(context)
|
||||
|
||||
# Run transformer blocks
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens_list,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=self.freqs,
|
||||
context=context_batch,
|
||||
context_lens=None,
|
||||
)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
|
||||
x = block(x, cross_kv_cache=kv, **kwargs)
|
||||
|
||||
# Output head
|
||||
x = self.head(x, e)
|
||||
|
||||
# Unpatchify
|
||||
outputs = self.unpatchify(x, grid_sizes)
|
||||
return [u.astype(mx.float32) for u in outputs]
|
||||
100
mlx_video/models/wan/rope.py
Normal file
100
mlx_video/models/wan/rope.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
|
||||
"""Precompute RoPE frequency parameters as complex numbers.
|
||||
|
||||
Returns:
|
||||
Complex frequency tensor of shape [max_seq_len, dim // 2].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
|
||||
1.0
|
||||
/ np.power(
|
||||
theta,
|
||||
np.arange(0, dim, 2, dtype=np.float64) / dim,
|
||||
)
|
||||
)[None, :]
|
||||
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
|
||||
cos_freqs = np.cos(freqs).astype(np.float32)
|
||||
sin_freqs = np.sin(freqs).astype(np.float32)
|
||||
return mx.array(np.stack([cos_freqs, sin_freqs], axis=-1))
|
||||
|
||||
|
||||
def rope_apply(
|
||||
x: mx.array,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply 3-way factorized RoPE to Q or K tensor.
|
||||
|
||||
Args:
|
||||
x: Shape [B, L, num_heads, head_dim]
|
||||
grid_sizes: List of (F, H, W) tuples per batch element
|
||||
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
|
||||
"""
|
||||
b, s, n, d = x.shape
|
||||
half_d = d // 2
|
||||
|
||||
# Cast freqs to input dtype to prevent float32 promotion cascade
|
||||
if freqs.dtype != x.dtype:
|
||||
freqs = freqs.astype(x.dtype)
|
||||
|
||||
# Split frequency dimensions: temporal gets more capacity
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
|
||||
# Split freqs along dim axis
|
||||
freqs_t = freqs[:, :d_t] # [1024, d_t, 2]
|
||||
freqs_h = freqs[:, d_t : d_t + d_h] # [1024, d_h, 2]
|
||||
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w] # [1024, d_w, 2]
|
||||
|
||||
outputs = []
|
||||
for i in range(b):
|
||||
f, h, w = grid_sizes[i]
|
||||
seq_len = f * h * w
|
||||
|
||||
# Reshape x to pairs for rotation: [seq_len, n, half_d, 2]
|
||||
x_i = x[i, :seq_len].reshape(seq_len, n, half_d, 2)
|
||||
|
||||
# Build per-position frequencies by expanding along grid dims
|
||||
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
|
||||
ft = mx.broadcast_to(
|
||||
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
|
||||
)
|
||||
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
|
||||
fh = mx.broadcast_to(
|
||||
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
|
||||
)
|
||||
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
|
||||
fw = mx.broadcast_to(
|
||||
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
|
||||
)
|
||||
|
||||
# Concatenate: [f*h*w, half_d, 2]
|
||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||
|
||||
# Apply rotation: (a + bi) * (cos + sin*i) = (a*cos - b*sin) + (a*sin + b*cos)i
|
||||
cos_f = freqs_i[..., 0] # [seq_len, 1, half_d]
|
||||
sin_f = freqs_i[..., 1] # [seq_len, 1, half_d]
|
||||
|
||||
x_real = x_i[..., 0] # [seq_len, n, half_d]
|
||||
x_imag = x_i[..., 1] # [seq_len, n, half_d]
|
||||
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
|
||||
# Interleave back: [seq_len, n, half_d, 2] -> [seq_len, n, d]
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(seq_len, n, d)
|
||||
|
||||
# Handle padding: keep non-rotated tokens after seq_len
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[i, seq_len:]], axis=0)
|
||||
|
||||
outputs.append(x_rotated)
|
||||
|
||||
return mx.stack(outputs)
|
||||
76
mlx_video/models/wan/scheduler.py
Normal file
76
mlx_video/models/wan/scheduler.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Flow matching scheduler for Wan2.2 inference."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
class FlowMatchEulerScheduler:
|
||||
"""Simple Euler scheduler for flow matching diffusion.
|
||||
|
||||
Implements the flow matching formulation where the model predicts
|
||||
velocity (flow) and we use Euler steps to denoise.
|
||||
"""
|
||||
|
||||
def __init__(self, num_train_timesteps: int = 1000):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps = None
|
||||
self.sigmas = None
|
||||
|
||||
def set_timesteps(self, num_steps: int, shift: float = 1.0):
|
||||
"""Compute sigma schedule with shift.
|
||||
|
||||
Args:
|
||||
num_steps: Number of inference steps.
|
||||
shift: Noise schedule shift factor.
|
||||
"""
|
||||
# Linear spacing from sigma_max to sigma_min
|
||||
sigmas = np.linspace(1.0, 1.0 / self.num_train_timesteps, self.num_train_timesteps)[::-1]
|
||||
sigmas = 1.0 - sigmas
|
||||
|
||||
# Select evenly spaced subset
|
||||
indices = np.linspace(0, len(sigmas) - 1, num_steps + 1).astype(int)
|
||||
sigmas = sigmas[indices[:-1]]
|
||||
|
||||
# Apply shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma)
|
||||
sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
|
||||
|
||||
# Convert to timesteps
|
||||
timesteps = sigmas * self.num_train_timesteps
|
||||
self.timesteps = mx.array(timesteps.astype(np.float32))
|
||||
|
||||
# Append terminal sigma=0
|
||||
sigmas = np.append(sigmas, 0.0)
|
||||
self.sigmas = mx.array(sigmas.astype(np.float32))
|
||||
self._step_index = 0
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: mx.array,
|
||||
timestep,
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""Euler step for flow matching.
|
||||
|
||||
In flow matching, model predicts velocity v, and:
|
||||
x_{t-1} = sample + (sigma_{t-1} - sigma_t) * v
|
||||
|
||||
Args:
|
||||
model_output: Predicted velocity [B, C, T, H, W]
|
||||
timestep: Current timestep (unused, step index is tracked internally)
|
||||
sample: Current noisy sample [B, C, T, H, W]
|
||||
|
||||
Returns:
|
||||
Updated sample
|
||||
"""
|
||||
# Use Python floats to avoid creating mx.array scalars that
|
||||
# could trigger type promotion (per fast-mlx guide)
|
||||
dt = float(self.sigmas[self._step_index + 1].item()) - float(self.sigmas[self._step_index].item())
|
||||
x_next = sample + dt * model_output
|
||||
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
def reset(self):
|
||||
"""Reset step counter for new generation."""
|
||||
self._step_index = 0
|
||||
234
mlx_video/models/wan/text_encoder.py
Normal file
234
mlx_video/models/wan/text_encoder.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
"""RMS-based layer normalization (T5 style)."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
"""T5-style relative position bias with bucketing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_buckets: int,
|
||||
num_heads: int,
|
||||
bidirectional: bool = True,
|
||||
max_dist: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
|
||||
rel_pos = mx.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
|
||||
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
|
||||
|
||||
max_exact = num_buckets // 2
|
||||
is_small = rel_pos < max_exact
|
||||
|
||||
rel_pos_f = rel_pos.astype(mx.float32)
|
||||
rel_pos_large = (
|
||||
max_exact
|
||||
+ (
|
||||
mx.log(rel_pos_f / max_exact)
|
||||
/ math.log(self.max_dist / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).astype(mx.int32)
|
||||
)
|
||||
rel_pos_large = mx.minimum(
|
||||
rel_pos_large,
|
||||
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
|
||||
)
|
||||
|
||||
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
|
||||
return rel_buckets
|
||||
|
||||
def __call__(self, lq: int, lk: int) -> mx.array:
|
||||
positions_k = mx.arange(lk)[None, :] # [1, lk]
|
||||
positions_q = mx.arange(lq)[:, None] # [lq, 1]
|
||||
rel_pos = positions_k - positions_q # [lq, lk]
|
||||
|
||||
buckets = self._relative_position_bucket(rel_pos)
|
||||
embeds = self.embedding(buckets) # [lq, lk, num_heads]
|
||||
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
|
||||
return embeds
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
"""T5-style multi-head attention (no scaling)."""
|
||||
|
||||
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
assert dim_attn % num_heads == 0
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: mx.array | None = None,
|
||||
mask: mx.array | None = None,
|
||||
pos_bias: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
context = x if context is None else context
|
||||
b, n, c = x.shape[0], self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
|
||||
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
|
||||
v = self.v(context).reshape(b, -1, n, c)
|
||||
|
||||
# T5 does not use scaling
|
||||
# attn = einsum('binc,bjnc->bnij', q, k)
|
||||
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Combine position bias and attention mask for SDPA
|
||||
attn_mask = None
|
||||
if pos_bias is not None:
|
||||
attn_mask = pos_bias.astype(q.dtype)
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
|
||||
elif mask.ndim == 3:
|
||||
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
|
||||
additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype)
|
||||
attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask
|
||||
|
||||
# T5 uses no scaling (scale=1.0)
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=1.0, mask=attn_mask
|
||||
)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
"""Gated feed-forward: gate(x) * fc1(x) -> fc2."""
|
||||
|
||||
def __init__(self, dim: int, dim_ffn: int):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.gate_act = nn.GELU(approx="precise")
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
|
||||
|
||||
|
||||
class T5SelfAttentionBlock(nn.Module):
|
||||
"""T5 encoder block: self-attention + FFN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_attn: int,
|
||||
dim_ffn: int,
|
||||
num_heads: int,
|
||||
num_buckets: int,
|
||||
shared_pos: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.shared_pos = shared_pos
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn)
|
||||
self.pos_embedding = (
|
||||
None
|
||||
if shared_pos
|
||||
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
pos_bias: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
|
||||
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
|
||||
x = x + self.ffn(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
"""T5 Encoder (UMT5-XXL configuration)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 256384,
|
||||
dim: int = 4096,
|
||||
dim_attn: int = 4096,
|
||||
dim_ffn: int = 10240,
|
||||
num_heads: int = 64,
|
||||
num_layers: int = 24,
|
||||
num_buckets: int = 32,
|
||||
shared_pos: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
self.token_embedding = nn.Embedding(vocab_size, dim)
|
||||
self.pos_embedding = (
|
||||
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
|
||||
if shared_pos
|
||||
else None
|
||||
)
|
||||
self.blocks = [
|
||||
T5SelfAttentionBlock(
|
||||
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
"""
|
||||
Args:
|
||||
ids: Token IDs [B, L]
|
||||
mask: Attention mask [B, L]
|
||||
|
||||
Returns:
|
||||
Hidden states [B, L, dim]
|
||||
"""
|
||||
x = self.token_embedding(ids)
|
||||
|
||||
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask=mask, pos_bias=e)
|
||||
|
||||
x = self.norm(x)
|
||||
return x
|
||||
89
mlx_video/models/wan/transformer.py
Normal file
89
mlx_video/models/wan/transformer.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
"""Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
window_size: tuple = (-1, -1),
|
||||
qk_norm: bool = True,
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Self-attention
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
||||
|
||||
# Cross-attention (with optional norm on context)
|
||||
self.norm3 = (
|
||||
WanLayerNorm(dim, eps, elementwise_affine=True)
|
||||
if cross_attn_norm
|
||||
else None
|
||||
)
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
|
||||
|
||||
# Feed-forward
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = WanFFN(dim, ffn_dim)
|
||||
|
||||
# Learned modulation: 6 vectors for scale/shift/gate
|
||||
self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
e: mx.array,
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
cross_kv_cache: tuple | None = None,
|
||||
) -> mx.array:
|
||||
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
|
||||
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
|
||||
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
|
||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||
e3 = mod[:, :, 3, :] # shift for ffn
|
||||
e4 = mod[:, :, 4, :] # scale for ffn
|
||||
e5 = mod[:, :, 5, :] # gate for ffn
|
||||
|
||||
# Self-attention with modulation
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
|
||||
x = x + y * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||
|
||||
# FFN with modulation
|
||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
||||
y = self.ffn(x_mod)
|
||||
x = x + y * e5
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanFFN(nn.Module):
|
||||
"""Gated feed-forward network with GELU(tanh) activation."""
|
||||
|
||||
def __init__(self, dim: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, ffn_dim)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
315
mlx_video/models/wan/vae.py
Normal file
315
mlx_video/models/wan/vae.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""3D VAE Decoder for Wan2.1/2.2 (compression 4×8×8).
|
||||
|
||||
Module structure mirrors original PyTorch checkpoint key hierarchy
|
||||
so weights load directly without key sanitization.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization statistics for z_dim=16
|
||||
VAE_MEAN = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
|
||||
]
|
||||
VAE_STD = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
|
||||
]
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
"""3D convolution with causal temporal padding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple,
|
||||
stride: int | tuple = 1,
|
||||
padding: int | tuple = 0,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self._causal_pad_t = 2 * padding[0]
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
# MLX Conv3d: weight shape [O, D, H, W, I]
|
||||
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""x: [B, C, T, H, W] (channel-first)"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
if self._causal_pad_t > 0:
|
||||
pad_t = mx.zeros((b, c, self._causal_pad_t, h, w), dtype=x.dtype)
|
||||
x = mx.concatenate([pad_t, x], axis=2)
|
||||
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
|
||||
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
|
||||
|
||||
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
|
||||
out = self._conv3d(x)
|
||||
return out.transpose(0, 4, 1, 2, 3) # [B, O, T', H', W']
|
||||
|
||||
def _conv3d(self, x: mx.array) -> mx.array:
|
||||
"""3D conv via sliding window + 2D conv per time step.
|
||||
x: [B, T, H, W, C_in] -> [B, T_out, H_out, W_out, C_out]
|
||||
"""
|
||||
b, t, h, w, c_in = x.shape
|
||||
kt, kh, kw = self.kernel_size
|
||||
st, sh, sw = self.stride
|
||||
t_out = (t - kt) // st + 1
|
||||
|
||||
# Pre-reshape weight: [O, D, H, W, I] -> [O, H, W, D*I]
|
||||
w_2d = self.weight.transpose(0, 2, 3, 1, 4).reshape(
|
||||
self.weight.shape[0], kh, kw, kt * c_in
|
||||
)
|
||||
outputs = []
|
||||
for t_i in range(t_out):
|
||||
t_start = t_i * st
|
||||
window = x[:, t_start : t_start + kt]
|
||||
window = window.transpose(0, 2, 3, 1, 4).reshape(b, h, w, kt * c_in)
|
||||
out_2d = mx.conv2d(window, w_2d, stride=(sh, sw)) + self.bias
|
||||
outputs.append(out_2d)
|
||||
return mx.stack(outputs, axis=1)
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
"""Channel-first L2 normalization matching original Wan VAE.
|
||||
|
||||
Uses F.normalize (L2 norm) with learned scale, equivalent to RMS norm.
|
||||
images=True: gamma shape (dim, 1, 1) for 4D (per-frame) input.
|
||||
images=False: gamma shape (dim, 1, 1, 1) for 5D video input.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True):
|
||||
super().__init__()
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
if channel_first:
|
||||
broadcastable = (1, 1) if images else (1, 1, 1)
|
||||
self.gamma = mx.ones((dim, *broadcastable))
|
||||
else:
|
||||
self.gamma = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
norm_dim = 1 if self.channel_first else -1
|
||||
# L2 normalize along channel dim (matches F.normalize)
|
||||
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
|
||||
return (x / norm) * self.scale * self.gamma
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
"""Residual block with causal 3D convolutions.
|
||||
|
||||
Uses `residual` list with None gaps to match original PyTorch
|
||||
nn.Sequential indices: [0]=norm, [1]=SiLU, [2]=conv, [3]=norm,
|
||||
[4]=SiLU, [5]=Dropout, [6]=conv. Only indices 0,2,3,6 have params.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.residual = [
|
||||
RMS_norm(in_dim, images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
|
||||
RMS_norm(out_dim, images=False), # [3]
|
||||
None, # [4] SiLU
|
||||
None, # [5] Dropout
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
|
||||
]
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
h = x if self.shortcut is None else self.shortcut(x)
|
||||
x = nn.silu(self.residual[0](x))
|
||||
x = self.residual[2](x)
|
||||
x = nn.silu(self.residual[3](x))
|
||||
x = self.residual[6](x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""Single-head spatial self-attention."""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.norm = RMS_norm(dim, images=True)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
identity = x
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
# [B,C,T,H,W] -> [B,T,C,H,W] -> [BT,C,H,W] -> norm -> [BT,H,W,C]
|
||||
x = x.transpose(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(0, 2, 3, 1) # [BT, H, W, C]
|
||||
|
||||
qkv = self.to_qkv(x) # [BT, H, W, 3C]
|
||||
qkv = qkv.reshape(b * t, h * w, 3, c).transpose(2, 0, 1, 3)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q[:, None, :, :] # [BT, 1, HW, C]
|
||||
k = k[:, None, :, :]
|
||||
v = v[:, None, :, :]
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=c**-0.5)
|
||||
out = out.squeeze(1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
|
||||
out = self.proj(out) # [BT, H, W, C]
|
||||
out = out.reshape(b, t, h, w, c).transpose(0, 4, 1, 2, 3) # [B, C, T, H, W]
|
||||
return out + identity
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
"""Upsample block matching original Wan VAE structure.
|
||||
|
||||
Uses `resample` list with [None, Conv2d] to match original
|
||||
nn.Sequential(Upsample, Conv2d) where index 1 has the conv params.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str):
|
||||
super().__init__()
|
||||
assert mode in ("upsample2d", "upsample3d")
|
||||
self.mode = mode
|
||||
self.dim = dim
|
||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
|
||||
if mode == "upsample3d":
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
if self.mode == "upsample3d":
|
||||
# Temporal upsample via learned conv
|
||||
x_t = self.time_conv(x) # [B, 2C, T, H, W]
|
||||
x_t = x_t.reshape(b, 2, c, t, h, w)
|
||||
# Interleave along time: [B, C, 2T, H, W]
|
||||
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
|
||||
t = t * 2
|
||||
|
||||
# Per-frame spatial upsample: nearest 2x + Conv2d
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.repeat(x, 2, axis=1)
|
||||
x = mx.repeat(x, 2, axis=2)
|
||||
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
|
||||
c_out = x.shape[-1]
|
||||
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
"""3D VAE Decoder matching Wan2.1 architecture.
|
||||
|
||||
Uses flat `middle` and `upsamples` lists to match original
|
||||
PyTorch nn.Sequential weight key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: list = None,
|
||||
num_res_blocks: int = 2,
|
||||
temporal_upsample: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
if dim_mult is None:
|
||||
dim_mult = [1, 2, 4, 4]
|
||||
if temporal_upsample is None:
|
||||
temporal_upsample = [True, True, False]
|
||||
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# Middle: [ResBlock, AttentionBlock, ResBlock]
|
||||
self.middle = [
|
||||
ResidualBlock(dims[0], dims[0]),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0]),
|
||||
]
|
||||
|
||||
# Flat upsample list matching original nn.Sequential indexing
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
if i in (1, 2, 3):
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim))
|
||||
in_dim = out_dim
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "upsample3d" if temporal_upsample[i] else "upsample2d"
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
self.upsamples = upsamples
|
||||
|
||||
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
|
||||
self.head = [
|
||||
RMS_norm(dims[-1], images=False), # [0]
|
||||
None, # [1] SiLU
|
||||
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
|
||||
]
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""x: [B, z_dim, T, H, W] -> [B, 3, T_out, H_out, W_out]"""
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.middle:
|
||||
x = layer(x)
|
||||
|
||||
for layer in self.upsamples:
|
||||
x = layer(x)
|
||||
|
||||
x = nn.silu(self.head[0](x))
|
||||
x = self.head[2](x)
|
||||
return x
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
"""Wan2.1 VAE wrapper with per-channel normalization."""
|
||||
|
||||
def __init__(self, z_dim: int = 16):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.mean = mx.array(VAE_MEAN)
|
||||
self.std = mx.array(VAE_STD)
|
||||
self.inv_std = 1.0 / self.std
|
||||
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim=96, z_dim=z_dim)
|
||||
|
||||
def decode(self, z: mx.array) -> mx.array:
|
||||
"""Decode latent to video.
|
||||
|
||||
Args:
|
||||
z: Normalized latent [B, z_dim, T, H, W]
|
||||
|
||||
Returns:
|
||||
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
|
||||
"""
|
||||
mean = self.mean.reshape(1, -1, 1, 1, 1)
|
||||
inv_std = self.inv_std.reshape(1, -1, 1, 1, 1)
|
||||
z = z / inv_std + mean
|
||||
|
||||
x = self.conv2(z)
|
||||
out = self.decoder(x)
|
||||
return mx.clip(out, -1, 1)
|
||||
584
mlx_video/models/wan/vae22.py
Normal file
584
mlx_video/models/wan/vae22.py
Normal file
@@ -0,0 +1,584 @@
|
||||
"""Wan2.2 VAE Decoder (compression 4×16×16, z_dim=48).
|
||||
|
||||
Architecture differs from Wan2.1 VAE: uses RMS_norm, DupUp3D shortcuts,
|
||||
spatial patchify (2×2), and different temporal upsampling pattern.
|
||||
|
||||
Weight keys mirror the PyTorch checkpoint hierarchy so only tensor format
|
||||
conversion (channels-first → channels-last) is needed.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
# Per-channel normalization for z_dim=48 latent space
|
||||
VAE22_MEAN = mx.array([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
|
||||
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
|
||||
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
|
||||
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
|
||||
])
|
||||
|
||||
VAE22_STD = mx.array([
|
||||
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
|
||||
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
|
||||
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
|
||||
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
|
||||
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
|
||||
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
|
||||
])
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
"""3D causal convolution. Input/output: [B, T, H, W, C] (channels-last).
|
||||
|
||||
Decomposes the 3D conv into per-frame 2D convolutions to avoid
|
||||
excessive memory usage from MLX's conv3d implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self._causal_pad_t = 2 * padding[0]
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
# Weight: [O, D, H, W, I] for MLX
|
||||
self.weight = mx.zeros((
|
||||
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
|
||||
))
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x):
|
||||
# x: [B, T, H, W, C]
|
||||
B, T, H, W, C = x.shape
|
||||
kd, kh, kw = self.kernel_size
|
||||
|
||||
# For 1x1x1 kernel or kernel_d==1, use direct conv
|
||||
if kd == 1 and kh == 1 and kw == 1:
|
||||
# Simple pointwise: reshape to [B*T, 1, 1, C] → conv2d
|
||||
x_flat = x.reshape(B * T, H, W, C)
|
||||
w2d = self.weight[:, 0, :, :, :] # [O, kH, kW, I]
|
||||
y = mx.conv_general(x_flat, w2d) + self.bias
|
||||
return y.reshape(B, T, y.shape[1], y.shape[2], -1)
|
||||
|
||||
# Causal temporal padding (left only)
|
||||
if self._causal_pad_t > 0:
|
||||
pad_t = mx.zeros((B, self._causal_pad_t, H, W, C))
|
||||
x = mx.concatenate([pad_t, x], axis=1)
|
||||
|
||||
# Spatial padding
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
|
||||
(self._pad_w, self._pad_w), (0, 0)])
|
||||
|
||||
T_padded = x.shape[1]
|
||||
H_padded, W_padded = x.shape[2], x.shape[3]
|
||||
T_out = (T_padded - kd) // self.stride[0] + 1
|
||||
|
||||
# Decompose 3D conv into sum of 2D convolutions over temporal kernel
|
||||
# weight shape: [O, kd, kh, kw, I] → split into kd 2D kernels [O, kh, kw, I]
|
||||
outputs = []
|
||||
for t in range(T_out):
|
||||
t_start = t * self.stride[0]
|
||||
# Sum 2D convs for each temporal kernel position
|
||||
accum = None
|
||||
for d in range(kd):
|
||||
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
|
||||
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
|
||||
conv_out = mx.conv_general(frame, w2d,
|
||||
stride=(self.stride[1], self.stride[2]))
|
||||
accum = conv_out if accum is None else accum + conv_out
|
||||
outputs.append(accum + self.bias)
|
||||
|
||||
return mx.stack(outputs, axis=1) # [B, T_out, H_out, W_out, O]
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
"""RMS normalization along channel dimension."""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
|
||||
self.gamma = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x):
|
||||
# x: [..., C] (channels-last)
|
||||
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
|
||||
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
|
||||
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
"""Residual block: RMS_norm → SiLU → CausalConv3d × 2 + shortcut."""
|
||||
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
|
||||
# We store as named layers matching PyTorch's indices
|
||||
self.residual = ResidualBlockLayers(in_dim, out_dim)
|
||||
self.shortcut = (
|
||||
CausalConv3d(in_dim, out_dim, 1)
|
||||
if in_dim != out_dim
|
||||
else None
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
h = self.shortcut(x) if self.shortcut is not None else x
|
||||
return self.residual(x) + h
|
||||
|
||||
|
||||
class ResidualBlockLayers(nn.Module):
|
||||
"""The sequential layers inside a ResidualBlock.
|
||||
|
||||
PyTorch stores these as nn.Sequential with indices 0-6:
|
||||
[0] RMS_norm, [1] SiLU, [2] CausalConv3d, [3] RMS_norm, [4] SiLU, [5] Dropout, [6] CausalConv3d
|
||||
We use matching attribute names for weight compatibility.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
# Indices match PyTorch nn.Sequential indices for weight key compat
|
||||
# Index 0: RMS_norm
|
||||
self.layer_0 = RMS_norm(in_dim)
|
||||
# Index 2: CausalConv3d
|
||||
self.layer_2 = CausalConv3d(in_dim, out_dim, 3, padding=1)
|
||||
# Index 3: RMS_norm
|
||||
self.layer_3 = RMS_norm(out_dim)
|
||||
# Index 6: CausalConv3d
|
||||
self.layer_6 = CausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.layer_0(x)
|
||||
x = nn.silu(x)
|
||||
x = self.layer_2(x)
|
||||
mx.eval(x) # Eval between convolutions to limit graph size
|
||||
x = self.layer_3(x)
|
||||
x = nn.silu(x)
|
||||
x = self.layer_6(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""2D self-attention applied per frame. Input: [B, T, H, W, C]."""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.norm = RMS_norm(dim)
|
||||
# Conv2d as linear per spatial position — weight [O, H, W, I] for MLX
|
||||
# to_qkv: dim -> 3*dim, proj: dim -> dim (1x1 conv2d)
|
||||
self.to_qkv_weight = mx.zeros((3 * dim, 1, 1, dim))
|
||||
self.to_qkv_bias = mx.zeros((3 * dim,))
|
||||
self.proj_weight = mx.zeros((dim, 1, 1, dim))
|
||||
self.proj_bias = mx.zeros((dim,))
|
||||
|
||||
def __call__(self, x):
|
||||
# x: [B, T, H, W, C]
|
||||
identity = x
|
||||
B, T, H, W, C = x.shape
|
||||
|
||||
# Apply per frame: merge B and T
|
||||
x = x.reshape(B * T, H, W, C)
|
||||
x = self.norm(x)
|
||||
|
||||
# QKV via 1x1 conv2d (equivalent to linear on last dim)
|
||||
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C]
|
||||
qkv = qkv.reshape(B * T, H * W, 3 * C)
|
||||
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
|
||||
|
||||
# Single-head attention
|
||||
q = q[:, None, :, :] # [BT, 1, HW, C]
|
||||
k = k[:, None, :, :]
|
||||
v = v[:, None, :, :]
|
||||
|
||||
scale = C ** -0.5
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C]
|
||||
out = out.squeeze(1).reshape(B * T, H, W, C)
|
||||
|
||||
# Project output
|
||||
out = mx.conv_general(out, self.proj_weight) + self.proj_bias # [BT, H, W, C]
|
||||
out = out.reshape(B, T, H, W, C)
|
||||
return out + identity
|
||||
|
||||
|
||||
class DupUp3D(nn.Module):
|
||||
"""Upsample by duplicating channels and reshaping. No learnable parameters."""
|
||||
|
||||
def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.factor_t = factor_t
|
||||
self.factor_s = factor_s
|
||||
self.factor = factor_t * factor_s * factor_s
|
||||
self.repeats = out_channels * self.factor // in_channels
|
||||
|
||||
def __call__(self, x, first_chunk=False):
|
||||
# x: [B, T, H, W, C]
|
||||
B, T, H, W, C = x.shape
|
||||
|
||||
# Repeat channels
|
||||
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
|
||||
|
||||
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
|
||||
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
|
||||
|
||||
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
|
||||
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
|
||||
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
|
||||
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
|
||||
|
||||
if first_chunk:
|
||||
x = x[:, self.factor_t - 1:, :, :, :]
|
||||
return x
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
"""Spatial up/downsampling with optional temporal up/downsampling."""
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
if mode == "upsample2d":
|
||||
# resample.0 = Upsample (no params), resample.1 = Conv2d
|
||||
self.resample_weight = mx.zeros((dim, 3, 3, dim)) # Conv2d [O, H, W, I]
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
elif mode == "upsample3d":
|
||||
self.resample_weight = mx.zeros((dim, 3, 3, dim))
|
||||
self.resample_bias = mx.zeros((dim,))
|
||||
# time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
|
||||
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
else:
|
||||
raise ValueError(f"Unsupported mode: {mode}")
|
||||
|
||||
def _upsample2x(self, x):
|
||||
"""Nearest-neighbor 2x spatial upsample. x: [N, H, W, C]."""
|
||||
N, H, W, C = x.shape
|
||||
# Repeat along H and W axes separately
|
||||
x = mx.repeat(x, repeats=2, axis=1) # [N, 2H, W, C]
|
||||
x = mx.repeat(x, repeats=2, axis=2) # [N, 2H, 2W, C]
|
||||
return x
|
||||
|
||||
def _conv2d(self, x):
|
||||
"""Apply the Conv2d with padding=1. x: [N, H, W, C]."""
|
||||
x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
|
||||
return mx.conv_general(x, self.resample_weight) + self.resample_bias
|
||||
|
||||
def __call__(self, x, first_chunk=False):
|
||||
# x: [B, T, H, W, C]
|
||||
B, T, H, W, C = x.shape
|
||||
|
||||
if self.mode == "upsample3d":
|
||||
# Temporal upsample via time_conv
|
||||
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
|
||||
# Split into two interleaved temporal streams
|
||||
tc_out = tc_out.reshape(B, T, H, W, 2, C)
|
||||
# Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C]
|
||||
stream0 = tc_out[:, :, :, :, 0, :] # [B, T, H, W, C]
|
||||
stream1 = tc_out[:, :, :, :, 1, :] # [B, T, H, W, C]
|
||||
x = mx.stack([stream0, stream1], axis=2) # [B, T, 2, H, W, C]
|
||||
x = x.reshape(B, T * 2, H, W, C)
|
||||
|
||||
if first_chunk:
|
||||
# PyTorch skips time_conv for first chunk entirely. In all-at-once
|
||||
# mode, we trim the first frame to match (the first interleaved
|
||||
# frame is from zero-padded causal context and shouldn't be kept).
|
||||
x = x[:, 1:, :, :, :]
|
||||
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
# Spatial upsample in temporal chunks to limit peak memory
|
||||
chunk_size = 8
|
||||
chunks = []
|
||||
for t_start in range(0, T, chunk_size):
|
||||
t_end = min(t_start + chunk_size, T)
|
||||
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
|
||||
x_chunk = self._upsample2x(x_chunk)
|
||||
x_chunk = self._conv2d(x_chunk)
|
||||
mx.eval(x_chunk)
|
||||
chunks.append(x_chunk)
|
||||
|
||||
x = mx.concatenate(chunks, axis=0)
|
||||
H2, W2 = x.shape[1], x.shape[2]
|
||||
x = x.reshape(B, T, H2, W2, C)
|
||||
return x
|
||||
|
||||
|
||||
class Up_ResidualBlock(nn.Module):
|
||||
"""Upsampling residual block with optional DupUp3D shortcut."""
|
||||
|
||||
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
|
||||
super().__init__()
|
||||
self.up_flag = up_flag
|
||||
|
||||
# DupUp3D shortcut (no learnable params)
|
||||
if up_flag:
|
||||
self.avg_shortcut = DupUp3D(
|
||||
in_dim, out_dim,
|
||||
factor_t=2 if temperal_upsample else 1,
|
||||
factor_s=2 if up_flag else 1,
|
||||
)
|
||||
else:
|
||||
self.avg_shortcut = None
|
||||
|
||||
# Main path: ResidualBlocks + optional Resample
|
||||
blocks = []
|
||||
dim_in = in_dim
|
||||
for _ in range(num_res_blocks):
|
||||
blocks.append(ResidualBlock(dim_in, out_dim))
|
||||
dim_in = out_dim
|
||||
|
||||
if up_flag:
|
||||
mode = "upsample3d" if temperal_upsample else "upsample2d"
|
||||
blocks.append(Resample(out_dim, mode=mode))
|
||||
|
||||
self.upsamples = blocks
|
||||
|
||||
def __call__(self, x, first_chunk=False):
|
||||
x_main = x
|
||||
for module in self.upsamples:
|
||||
if isinstance(module, Resample):
|
||||
x_main = module(x_main, first_chunk)
|
||||
else:
|
||||
x_main = module(x_main)
|
||||
mx.eval(x_main) # Limit graph size per sub-block
|
||||
|
||||
if self.avg_shortcut is not None:
|
||||
x_shortcut = self.avg_shortcut(x, first_chunk)
|
||||
mx.eval(x_shortcut)
|
||||
return x_main + x_shortcut
|
||||
return x_main
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
"""Wan2.2 3D VAE Decoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=256,
|
||||
z_dim=48,
|
||||
dim_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
temperal_upsample=(True, True, False),
|
||||
):
|
||||
super().__init__()
|
||||
# Compute layer dimensions
|
||||
dims = [dim * dim_mult[-1]] + [dim * m for m in reversed(dim_mult)]
|
||||
|
||||
# Initial conv
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# Middle blocks
|
||||
self.middle = [
|
||||
ResidualBlock(dims[0], dims[0]),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0]),
|
||||
]
|
||||
|
||||
# Upsample blocks
|
||||
self.upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
|
||||
self.upsamples.append(Up_ResidualBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks + 1,
|
||||
temperal_upsample=t_up,
|
||||
up_flag=(i != len(dim_mult) - 1),
|
||||
))
|
||||
|
||||
# Output head: [RMS_norm, SiLU, CausalConv3d]
|
||||
self.head = Head22(dims[-1])
|
||||
|
||||
def __call__(self, x, first_chunk=False):
|
||||
# x: [B, T, H, W, C=z_dim]
|
||||
x = self.conv1(x)
|
||||
|
||||
for layer in self.middle:
|
||||
x = layer(x)
|
||||
mx.eval(x) # Evaluate to limit graph size
|
||||
|
||||
for i, layer in enumerate(self.upsamples):
|
||||
x = layer(x, first_chunk)
|
||||
mx.eval(x) # Evaluate after each upsample block
|
||||
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
class Head22(nn.Module):
|
||||
"""Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).
|
||||
|
||||
PyTorch key mapping: head.0 = RMS_norm, head.2 = CausalConv3d
|
||||
(index 1 = SiLU has no params)
|
||||
"""
|
||||
|
||||
def __init__(self, dim, out_channels=12):
|
||||
super().__init__()
|
||||
# Index 0: RMS_norm
|
||||
self.layer_0 = RMS_norm(dim)
|
||||
# Index 2: CausalConv3d
|
||||
self.layer_2 = CausalConv3d(dim, out_channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.layer_0(x)
|
||||
x = nn.silu(x)
|
||||
x = self.layer_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class Wan22VAEDecoder(nn.Module):
|
||||
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
|
||||
|
||||
def __init__(self, z_dim=48, dim=160, dec_dim=256):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
# conv2: 1x1x1 conv before decoder
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(
|
||||
dim=dec_dim,
|
||||
z_dim=z_dim,
|
||||
dim_mult=(1, 2, 4, 4),
|
||||
num_res_blocks=2,
|
||||
temperal_upsample=(True, True, False),
|
||||
)
|
||||
|
||||
def __call__(self, z):
|
||||
"""Decode latents to video.
|
||||
|
||||
Args:
|
||||
z: [B, T, H, W, C=48] latent tensor (already denormalized)
|
||||
|
||||
Returns:
|
||||
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
|
||||
"""
|
||||
x = self.conv2(z)
|
||||
|
||||
# All-at-once decode with first_chunk=True to trim extra temporal
|
||||
# frames from causal padding (matches PyTorch's chunked behavior)
|
||||
out = self.decoder(x, first_chunk=True)
|
||||
|
||||
# Unpatchify: 12 channels → 3 RGB (spatial 2×2)
|
||||
out = _unpatchify(out, patch_size=2)
|
||||
|
||||
return mx.clip(out, -1.0, 1.0)
|
||||
|
||||
|
||||
def denormalize_latents(z, mean=None, std=None):
|
||||
"""Denormalize latents: z = z / (1/std) + mean."""
|
||||
if mean is None:
|
||||
mean = VAE22_MEAN
|
||||
if std is None:
|
||||
std = VAE22_STD
|
||||
inv_scale = std # scale was 1/std, so divide by scale = multiply by std
|
||||
return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
|
||||
|
||||
|
||||
def _unpatchify(x, patch_size=2):
|
||||
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C//(p*p)]
|
||||
Actually: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
|
||||
PyTorch: b (c r q) f h w -> b c f (h q) (w r) with q=p, r=p
|
||||
In channels-last: [B, T, H, W, C*r*q] -> [B, T, H*q, W*r, C]
|
||||
"""
|
||||
if patch_size == 1:
|
||||
return x
|
||||
B, T, H, W, Cpacked = x.shape
|
||||
C = Cpacked // (patch_size * patch_size)
|
||||
# Reshape: [B, T, H, W, r, q, C] then rearrange to [B, T, H*q, W*r, C]
|
||||
# PyTorch patchify: "b c f (h q) (w r) -> b (c r q) f h w" — so c is packed as (c, r, q)
|
||||
# Unpatchify reverses: [B, T, H, W, (C, r, q)] -> [B, T, H, q, W, r, C]
|
||||
x = x.reshape(B, T, H, W, C, patch_size, patch_size)
|
||||
# Rearrange: put q next to H, r next to W
|
||||
x = x.transpose(0, 1, 2, 6, 3, 5, 4) # [B, T, H, q, W, r, C]
|
||||
x = x.reshape(B, T, H * patch_size, W * patch_size, C)
|
||||
return x
|
||||
|
||||
|
||||
def sanitize_wan22_vae_weights(weights: dict) -> dict:
|
||||
"""Convert PyTorch Wan2.2 VAE weights to MLX format.
|
||||
|
||||
Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
|
||||
Transposes conv weights from channels-first to channels-last.
|
||||
Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
|
||||
Maps PyTorch nn.Sequential indices to our named layers.
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
# Skip encoder and conv1 (encoder-only)
|
||||
if key.startswith("encoder.") or key.startswith("conv1."):
|
||||
continue
|
||||
|
||||
new_key = key
|
||||
|
||||
# Map nn.Sequential indexed layers to our named attributes
|
||||
# ResidualBlockLayers: indices 0, 2, 3, 6 → _layer_0, _layer_2, _layer_3, _layer_6
|
||||
# Head22: indices 0, 2 → _layer_0, _layer_2
|
||||
for idx in ["0", "2", "3", "6"]:
|
||||
# Match patterns like "residual.0.gamma" → "residual.layer_0.gamma"
|
||||
# or "head.0.gamma" → "head.layer_0.gamma"
|
||||
old_pattern = f".residual.{idx}."
|
||||
new_pattern = f".residual.layer_{idx}."
|
||||
new_key = new_key.replace(old_pattern, new_pattern)
|
||||
|
||||
# Head layer mapping: head.0.gamma → head.layer_0.gamma, head.2.weight → head.layer_2.weight
|
||||
for idx in ["0", "2"]:
|
||||
old_pattern = f".head.{idx}."
|
||||
new_pattern = f".head.layer_{idx}."
|
||||
new_key = new_key.replace(old_pattern, new_pattern)
|
||||
|
||||
# Map Resample Conv2d: resample.1.weight → resample_weight, resample.1.bias → resample_bias
|
||||
if ".resample.1.weight" in new_key:
|
||||
new_key = new_key.replace(".resample.1.weight", ".resample_weight")
|
||||
elif ".resample.1.bias" in new_key:
|
||||
new_key = new_key.replace(".resample.1.bias", ".resample_bias")
|
||||
|
||||
# Map AttentionBlock Conv2d weights
|
||||
if ".to_qkv.weight" in new_key:
|
||||
new_key = new_key.replace(".to_qkv.weight", ".to_qkv_weight")
|
||||
elif ".to_qkv.bias" in new_key:
|
||||
new_key = new_key.replace(".to_qkv.bias", ".to_qkv_bias")
|
||||
elif ".proj.weight" in new_key and "time_projection" not in new_key:
|
||||
new_key = new_key.replace(".proj.weight", ".proj_weight")
|
||||
elif ".proj.bias" in new_key and "time_projection" not in new_key:
|
||||
new_key = new_key.replace(".proj.bias", ".proj_bias")
|
||||
|
||||
# Transpose conv weights to channels-last
|
||||
is_weight = new_key.endswith(".weight") or new_key.endswith("_weight")
|
||||
if is_weight:
|
||||
if value.ndim == 5:
|
||||
# Conv3d: [O, I, D, H, W] → [O, D, H, W, I]
|
||||
value = np.transpose(np.array(value), (0, 2, 3, 4, 1))
|
||||
value = mx.array(value)
|
||||
elif value.ndim == 4:
|
||||
# Conv2d: [O, I, H, W] → [O, H, W, I]
|
||||
value = np.transpose(np.array(value), (0, 2, 3, 1))
|
||||
value = mx.array(value)
|
||||
|
||||
# Squeeze RMS_norm gamma: (dim, 1, 1, 1) or (dim, 1, 1) → (dim,)
|
||||
if "gamma" in new_key:
|
||||
value = mx.array(np.array(value).squeeze())
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
Reference in New Issue
Block a user