feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -67,6 +67,8 @@ class WanSelfAttention(nn.Module):
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
@@ -87,19 +89,18 @@ class WanSelfAttention(nn.Module):
|
||||
v = self.v(x_w).reshape(b, s, n, d)
|
||||
|
||||
# RoPE in float32 for precision (official uses float64)
|
||||
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
|
||||
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
|
||||
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
|
||||
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
|
||||
|
||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Build attention mask from seq_lens
|
||||
max_len = s
|
||||
mask = None
|
||||
if any(sl < max_len for sl in seq_lens):
|
||||
mask = mx.zeros((b, 1, 1, max_len), dtype=q.dtype)
|
||||
# Use precomputed mask or build from seq_lens
|
||||
mask = attn_mask
|
||||
if mask is None and any(sl < s for sl in seq_lens):
|
||||
mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
|
||||
for i, sl in enumerate(seq_lens):
|
||||
mask[i, :, :, sl:] = -1e9
|
||||
|
||||
|
||||
@@ -91,6 +91,19 @@ class WanModelConfig(BaseModelConfig):
|
||||
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
|
||||
return cls()
|
||||
|
||||
@classmethod
def wan22_i2v_14b(cls) -> "WanModelConfig":
    """Wan2.2 I2V 14B: dual model, image-to-video, 40 layers, dim=5120.

    Only the fields listed below are overridden; every other field keeps
    the class default.
    """
    overrides = {
        "model_type": "i2v",
        "in_dim": 36,
        "out_dim": 16,
        "dual_model": True,
        "boundary": 0.900,
        "sample_shift": 5.0,
        "sample_guide_scale": (3.5, 3.5),
    }
    return cls(**overrides)
|
||||
|
||||
@classmethod
|
||||
def wan22_ti2v_5b(cls) -> "WanModelConfig":
|
||||
"""Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""
|
||||
|
||||
@@ -87,16 +87,23 @@ def load_vae_decoder(model_path: Path, config=None):
|
||||
def load_vae_encoder(model_path: Path, config=None):
    """Load VAE encoder for I2V image encoding.

    For Wan2.2 TI2V (vae_z_dim=48), uses Wan22VAEEncoder.
    For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.

    Args:
        model_path: Path to the weights file; passed to ``mx.load`` as a string.
        config: Optional object exposing ``vae_z_dim``. When ``None`` the
            Wan2.2 encoder is built with ``z_dim=48``.

    Returns:
        The constructed module with weights loaded (non-strict) and evaluated.
    """
    # Imports are kept local so only the VAE variant actually needed is imported.
    if config is not None and config.vae_z_dim == 16:
        from mlx_video.models.wan.vae import WanVAE

        vae = WanVAE(z_dim=16, encoder=True)
    else:
        from mlx_video.models.wan.vae22 import Wan22VAEEncoder

        vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)

    weights = mx.load(str(model_path))
    # All weights are cast to float32 before being loaded into the module.
    weights = {k: v.astype(mx.float32) for k, v in weights.items()}
    vae.load_weights(list(weights.items()), strict=False)
    mx.eval(vae.parameters())
    return vae
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from .attention import WanLayerNorm
|
||||
from .config import WanModelConfig
|
||||
from .rope import rope_params
|
||||
from .rope import rope_params, rope_precompute_cos_sin
|
||||
from .transformer import WanAttentionBlock
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class Head(nn.Module):
|
||||
proj_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, proj_dim)
|
||||
self.modulation = mx.random.normal((1, 2, dim)) * (dim**-0.5)
|
||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -48,14 +48,13 @@ class Head(nn.Module):
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
e_f32 = e.astype(mx.float32)
|
||||
# modulation [1, 2, dim] broadcasts with e [B, 1/L, dim] via unsqueeze
|
||||
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
|
||||
# modulation already float32; e already float32 from model forward
|
||||
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # [B, L_e, 2, dim]
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x).astype(mx.float32)
|
||||
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1
|
||||
return self.head(x_mod.astype(x.dtype))
|
||||
x_norm = self.norm(x)
|
||||
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
|
||||
return self.head(x_mod.astype(self.head.weight.dtype))
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
@@ -109,17 +108,16 @@ class WanModel(nn.Module):
|
||||
# Output head
|
||||
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
|
||||
|
||||
# Precompute RoPE frequencies
|
||||
d = dim // config.num_heads
|
||||
d_t = d - 4 * (d // 6)
|
||||
d_h = 2 * (d // 6)
|
||||
d_w = 2 * (d // 6)
|
||||
# Each rope_params returns [1024, d_x//2, 2]
|
||||
freqs_t = rope_params(1024, d_t)
|
||||
freqs_h = rope_params(1024, d_h)
|
||||
freqs_w = rope_params(1024, d_w)
|
||||
# Concatenate along the frequency dimension: [1024, d//2, 2]
|
||||
self.freqs = mx.concatenate([freqs_t, freqs_h, freqs_w], axis=1)
|
||||
# Precompute RoPE frequencies — single table, split by rope_apply
|
||||
# Reference computes one rope_params(head_dim) and splits into t/h/w.
|
||||
self.freqs = rope_params(1024, dim // config.num_heads)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding
|
||||
half = config.freq_dim // 2
|
||||
self._inv_freq = mx.power(
|
||||
10000.0, -mx.arange(half).astype(mx.float32) / half
|
||||
)
|
||||
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
"""Convert video tensor to patch embeddings.
|
||||
@@ -215,6 +213,21 @@ class WanModel(nn.Module):
|
||||
kv_caches.append(block.cross_attn.prepare_kv(context))
|
||||
return kv_caches
|
||||
|
||||
def prepare_rope(self, grid_sizes: list) -> tuple:
    """Pre-compute RoPE cos/sin for constant grid sizes.

    Call once before the diffusion loop when grid sizes don't change
    across steps. Eliminates per-step broadcast/concat overhead.

    The tables are built in float32: self-attention casts q/k to float32
    before calling rope_apply, and the non-precomputed path casts the
    frequency table to that input dtype. Building the tables in a lower
    precision weight dtype would round cos/sin before the rotation and make
    the precomputed path disagree with the fallback path.

    Args:
        grid_sizes: List of (F, H, W) tuples per batch element

    Returns:
        (cos_f, sin_f) precomputed frequency tensors
    """
    return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=mx.float32)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x_list: list,
|
||||
@@ -222,6 +235,8 @@ class WanModel(nn.Module):
|
||||
context: list | mx.array,
|
||||
seq_len: int,
|
||||
cross_kv_caches: list | None = None,
|
||||
y: list | None = None,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
) -> list:
|
||||
"""Forward pass.
|
||||
|
||||
@@ -233,42 +248,70 @@ class WanModel(nn.Module):
|
||||
seq_len: Maximum sequence length for padding
|
||||
cross_kv_caches: Optional list of (k, v) tuples from
|
||||
prepare_cross_kv(), one per block.
|
||||
y: Optional list of conditioning tensors for I2V [C_y, F, H, W].
|
||||
Channel-concatenated with x before patchify.
|
||||
rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope().
|
||||
|
||||
Returns:
|
||||
List of denoised tensors [C, F, H, W]
|
||||
"""
|
||||
# Patchify each video
|
||||
patches = []
|
||||
grid_sizes = []
|
||||
seq_lens_list = []
|
||||
for vid in x_list:
|
||||
p, gs = self._patchify(vid) # [1, L, dim]
|
||||
patches.append(p)
|
||||
grid_sizes.append(gs)
|
||||
seq_lens_list.append(p.shape[1])
|
||||
# Detect identical inputs (CFG B=2) to avoid duplicate patchify work.
|
||||
# Check BEFORE I2V concat since concat creates new array objects.
|
||||
batch_size = len(x_list)
|
||||
all_same = batch_size > 1 and all(
|
||||
x_list[i] is x_list[0] for i in range(1, batch_size)
|
||||
)
|
||||
if all_same and y is not None:
|
||||
all_same = all(y[i] is y[0] for i in range(1, len(y)))
|
||||
|
||||
# Pad and batch
|
||||
batch_size = len(patches)
|
||||
x = mx.concatenate(
|
||||
[
|
||||
mx.concatenate(
|
||||
# I2V: channel-concatenate conditioning y with noise x
|
||||
if y is not None:
|
||||
x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)]
|
||||
|
||||
if all_same:
|
||||
# Patchify once and broadcast — saves a Linear projection per step
|
||||
p, gs = self._patchify(x_list[0]) # [1, L, dim]
|
||||
grid_sizes = [gs] * batch_size
|
||||
seq_lens_list = [p.shape[1]] * batch_size
|
||||
# Pad and broadcast
|
||||
if p.shape[1] < seq_len:
|
||||
p = mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
x = mx.broadcast_to(p, (batch_size,) + p.shape[1:])
|
||||
else:
|
||||
patches = []
|
||||
grid_sizes = []
|
||||
seq_lens_list = []
|
||||
for vid in x_list:
|
||||
p, gs = self._patchify(vid) # [1, L, dim]
|
||||
patches.append(p)
|
||||
grid_sizes.append(gs)
|
||||
seq_lens_list.append(p.shape[1])
|
||||
x = mx.concatenate(
|
||||
[
|
||||
mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding
|
||||
# Time embedding (use cached inv_freq to avoid recomputing each step)
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
|
||||
pos = t.astype(mx.float32)
|
||||
sinusoid = pos[..., None] * self._inv_freq
|
||||
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
@@ -278,7 +321,6 @@ class WanModel(nn.Module):
|
||||
e = e.astype(mx.float32)
|
||||
else:
|
||||
# I2V: per-token timesteps [B, L]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, L, dim]
|
||||
@@ -298,7 +340,15 @@ class WanModel(nn.Module):
|
||||
else:
|
||||
context_batch = self.embed_text(context)
|
||||
|
||||
# Run transformer blocks
|
||||
# Pre-compute attention mask from seq_lens (constant across all blocks)
|
||||
attn_mask = None
|
||||
w_dtype = self.patch_embedding_proj.weight.dtype
|
||||
if any(sl < seq_len for sl in seq_lens_list):
|
||||
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
|
||||
for i, sl in enumerate(seq_lens_list):
|
||||
attn_mask[i, :, :, sl:] = -1e9
|
||||
|
||||
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens_list,
|
||||
@@ -306,8 +356,11 @@ class WanModel(nn.Module):
|
||||
freqs=self.freqs,
|
||||
context=context_batch,
|
||||
context_lens=None,
|
||||
rope_cos_sin=rope_cos_sin,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
# Run transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
|
||||
x = block(x, cross_kv_cache=kv, **kwargs)
|
||||
|
||||
@@ -28,6 +28,7 @@ def rope_apply(
|
||||
x: mx.array,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
precomputed_cos_sin: tuple | None = None,
|
||||
) -> mx.array:
|
||||
"""Apply 3-way factorized RoPE to Q or K tensor.
|
||||
|
||||
@@ -35,10 +36,48 @@ def rope_apply(
|
||||
x: Shape [B, L, num_heads, head_dim]
|
||||
grid_sizes: List of (F, H, W) tuples per batch element
|
||||
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
|
||||
precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin()
|
||||
"""
|
||||
b, s, n, d = x.shape
|
||||
half_d = d // 2
|
||||
|
||||
if precomputed_cos_sin is not None:
|
||||
cos_f, sin_f = precomputed_cos_sin
|
||||
# Check if all batch elements have the same grid (common for CFG B=2)
|
||||
f0, h0, w0 = grid_sizes[0]
|
||||
seq_len = f0 * h0 * w0
|
||||
all_same_grid = all(
|
||||
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
|
||||
) if b > 1 else True
|
||||
|
||||
if all_same_grid:
|
||||
# Vectorized path: apply RoPE to all batch elements at once
|
||||
x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2)
|
||||
x_real = x_seq[..., 0]
|
||||
x_imag = x_seq[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
|
||||
return x_rotated
|
||||
else:
|
||||
# Per-element path for mixed grid sizes
|
||||
outputs = []
|
||||
for i in range(b):
|
||||
f, h, w = grid_sizes[i]
|
||||
sl = f * h * w
|
||||
x_i = x[i, :sl].reshape(sl, n, half_d, 2)
|
||||
x_real = x_i[..., 0]
|
||||
x_imag = x_i[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d)
|
||||
if sl < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0)
|
||||
outputs.append(x_rotated)
|
||||
return mx.stack(outputs)
|
||||
|
||||
# Cast freqs to input dtype to prevent float32 promotion cascade
|
||||
if freqs.dtype != x.dtype:
|
||||
freqs = freqs.astype(x.dtype)
|
||||
@@ -98,3 +137,42 @@ def rope_apply(
|
||||
outputs.append(x_rotated)
|
||||
|
||||
return mx.stack(outputs)
|
||||
|
||||
|
||||
def rope_precompute_cos_sin(
    grid_sizes: list, freqs: mx.array, dtype: type = mx.float32
) -> tuple:
    """Build the per-position cos/sin RoPE tables for one fixed grid.

    Intended to be called once before the diffusion loop; the result is
    passed to rope_apply as ``precomputed_cos_sin`` so the per-step
    broadcast/concat is skipped.

    Args:
        grid_sizes: List of (F, H, W) tuples. Only the first entry is read,
            so all batch elements must share the same grid.
        freqs: Precomputed frequencies [1024, d//2, 2]
        dtype: Target dtype for the output tensors

    Returns:
        (cos_f, sin_f) each [seq_len, 1, half_d]
    """
    table = freqs.astype(dtype) if freqs.dtype != dtype else freqs

    frames, height, width = grid_sizes[0]
    grid = (frames, height, width)
    n_freq = table.shape[1]

    # The frequency axis is split into temporal / height / width bands;
    # the temporal band absorbs the remainder of the three-way division.
    third = n_freq // 3
    band_widths = (n_freq - 2 * third, third, third)

    bands = []
    offset = 0
    for axis, (extent, band_width) in enumerate(zip(grid, band_widths)):
        band = table[:extent, offset : offset + band_width]
        # Place this band's positions on its own grid axis, then broadcast
        # across the other two axes.
        view = [1, 1, 1]
        view[axis] = extent
        bands.append(
            mx.broadcast_to(
                band.reshape(*view, band_width, 2),
                grid + (band_width, 2),
            )
        )
        offset += band_width

    flat = mx.concatenate(bands, axis=3).reshape(
        frames * height * width, 1, n_freq, 2
    )
    return flat[..., 0], flat[..., 1]
|
||||
|
||||
@@ -34,6 +34,8 @@ class FlowMatchEulerScheduler:
|
||||
sigmas = _compute_sigmas(num_steps, shift)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
|
||||
# Store as Python floats to avoid .item() sync in step()
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
|
||||
def step(
|
||||
@@ -43,9 +45,7 @@ class FlowMatchEulerScheduler:
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
||||
dt = float(self.sigmas[self._step_index + 1].item()) - float(
|
||||
self.sigmas[self._step_index].item()
|
||||
)
|
||||
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
|
||||
x_next = sample + dt * model_output
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
@@ -35,8 +35,8 @@ class WanAttentionBlock(nn.Module):
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = WanFFN(dim, ffn_dim)
|
||||
|
||||
# Learned modulation: 6 vectors for scale/shift/gate
|
||||
self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5)
|
||||
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
|
||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -48,10 +48,11 @@ class WanAttentionBlock(nn.Module):
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
cross_kv_cache: tuple | None = None,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
# Modulation in float32 (matching official torch.amp.autocast float32)
|
||||
e_f32 = e.astype(mx.float32)
|
||||
mod = self.modulation.astype(mx.float32) + e_f32
|
||||
# Modulation in float32 (e is already float32 from model forward)
|
||||
mod = self.modulation + e
|
||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||
@@ -59,19 +60,20 @@ class WanAttentionBlock(nn.Module):
|
||||
e4 = mod[:, :, 4, :] # scale for ffn
|
||||
e5 = mod[:, :, 5, :] # gate for ffn
|
||||
|
||||
# Self-attention with modulation (norm output in float32)
|
||||
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
|
||||
x = x.astype(mx.float32) + y.astype(mx.float32) * e2
|
||||
# Self-attention with modulation
|
||||
# Type promotion handles bf16→f32 automatically when multiplied with f32 modulation
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
||||
x = x + y * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||
|
||||
# FFN with modulation (norm output in float32)
|
||||
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3
|
||||
# FFN with modulation
|
||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
||||
y = self.ffn(x_mod)
|
||||
x = x + y.astype(mx.float32) * e5
|
||||
x = x + y * e5
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@@ -43,7 +43,9 @@ class CausalConv3d(nn.Module):
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self._causal_pad_t = 2 * padding[0]
|
||||
# Causal padding: match reference formula dilation*(k-1) + (1-stride)
|
||||
# With dilation=1: k-stride (pads left only, no future context)
|
||||
self._causal_pad_t = kernel_size[0] - stride[0]
|
||||
self._pad_h = padding[1]
|
||||
self._pad_w = padding[2]
|
||||
|
||||
@@ -51,12 +53,17 @@ class CausalConv3d(nn.Module):
|
||||
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
|
||||
"""x: [B, C, T, H, W] (channel-first)"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
if self._causal_pad_t > 0:
|
||||
pad_t = mx.zeros((b, c, self._causal_pad_t, h, w), dtype=x.dtype)
|
||||
causal_pad = self._causal_pad_t
|
||||
if cache_x is not None and causal_pad > 0:
|
||||
x = mx.concatenate([cache_x, x], axis=2)
|
||||
causal_pad = max(0, causal_pad - cache_x.shape[2])
|
||||
|
||||
if causal_pad > 0:
|
||||
pad_t = mx.zeros((b, c, causal_pad, h, w), dtype=x.dtype)
|
||||
x = mx.concatenate([pad_t, x], axis=2)
|
||||
|
||||
if self._pad_h > 0 or self._pad_w > 0:
|
||||
@@ -136,12 +143,35 @@ class ResidualBlock(nn.Module):
|
||||
]
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
    """Apply the two norm -> SiLU -> conv stages and add the shortcut.

    Args:
        x: Input tensor; the time axis is axis 2 (it is sliced there for caching).
        feat_cache: Optional list of per-conv temporal caches, mutated in
            place. When ``None`` both convs run without cached context.
        feat_idx: Single-element list holding the next cache slot index;
            advanced by one per conv when ``feat_cache`` is given.

    Returns:
        Residual output plus the (optionally projected) shortcut.
    """
    h = x if self.shortcut is None else self.shortcut(x)

    # (norm index, conv index) for the two stages of the residual path.
    for norm_pos, conv_pos in ((0, 2), (3, 6)):
        x = nn.silu(self.residual[norm_pos](x))
        conv = self.residual[conv_pos]

        if feat_cache is None:
            x = conv(x)
            continue

        # norm -> silu -> [cache] -> conv
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:]
        if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
            # Chunk is shorter than the cache window: top it up with the
            # last frame of the previous cache.
            cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
        x = conv(x, cache_x=feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1

    return x + h
|
||||
|
||||
|
||||
@@ -180,23 +210,31 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
"""Upsample block matching original Wan VAE structure.
|
||||
"""Resample block matching original Wan VAE structure.
|
||||
|
||||
Uses `resample` list with [None, Conv2d] to match original
|
||||
nn.Sequential(Upsample, Conv2d) where index 1 has the conv params.
|
||||
Supports both upsampling (decoder) and downsampling (encoder).
|
||||
Uses list-based param storage to match original nn.Sequential key hierarchy.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str):
    """Create the spatial (and optional temporal) resampling layers.

    Args:
        dim: Input channel count.
        mode: One of "upsample2d", "upsample3d", "downsample2d",
            "downsample3d". The "3d" variants additionally create
            ``time_conv`` for temporal resampling.
    """
    super().__init__()
    assert mode in ("upsample2d", "upsample3d", "downsample2d", "downsample3d")
    self.mode = mode
    self.dim = dim

    if mode.startswith("upsample"):
        # resample.0 = Upsample (no params), resample.1 = Conv2d
        self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
        if mode == "upsample3d":
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
    else:
        # resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
        self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
        if mode == "downsample3d":
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
            )
|
||||
|
||||
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
|
||||
"""x: [B, C, T, H, W]"""
|
||||
b, c, t, h, w = x.shape
|
||||
|
||||
@@ -204,17 +242,43 @@ class Resample(nn.Module):
|
||||
# Temporal upsample via learned conv
|
||||
x_t = self.time_conv(x) # [B, 2C, T, H, W]
|
||||
x_t = x_t.reshape(b, 2, c, t, h, w)
|
||||
# Interleave along time: [B, C, 2T, H, W]
|
||||
x = mx.stack([x_t[:, 0], x_t[:, 1]], axis=3).reshape(b, c, t * 2, h, w)
|
||||
t = t * 2
|
||||
|
||||
# Per-frame spatial upsample: nearest 2x + Conv2d
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.repeat(x, 2, axis=1)
|
||||
x = mx.repeat(x, 2, axis=2)
|
||||
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
|
||||
c_out = x.shape[-1]
|
||||
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
|
||||
if self.mode.startswith("upsample"):
|
||||
# Per-frame spatial upsample: nearest 2x + Conv2d
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.repeat(x, 2, axis=1)
|
||||
x = mx.repeat(x, 2, axis=2)
|
||||
x = self.resample[1](x) # Conv2d [BT, 2H, 2W, C//2]
|
||||
c_out = x.shape[-1]
|
||||
return x.reshape(b, t, h * 2, w * 2, c_out).transpose(0, 4, 1, 2, 3)
|
||||
else:
|
||||
# Per-frame spatial downsample: ZeroPad(0,1,0,1) + Conv2d(stride=2)
|
||||
x = x.transpose(0, 2, 3, 4, 1).reshape(b * t, h, w, c) # [BT, H, W, C]
|
||||
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) # ZeroPad2d(0,1,0,1)
|
||||
x = self.resample[1](x) # Conv2d stride=2
|
||||
c_out = x.shape[-1]
|
||||
h_out, w_out = x.shape[1], x.shape[2]
|
||||
x = x.reshape(b, t, h_out, w_out, c_out).transpose(0, 4, 1, 2, 3)
|
||||
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
# First chunk: save x, skip time_conv
|
||||
feat_cache[idx] = x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
# Subsequent chunks: use cached frame as temporal context
|
||||
cache_x = x[:, :, -1:]
|
||||
x = self.time_conv(
|
||||
x, cache_x=feat_cache[idx][:, :, -1:])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.time_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
@@ -284,10 +348,108 @@ class Decoder3d(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
"""Wan2.1 VAE wrapper with per-channel normalization."""
|
||||
class Encoder3d(nn.Module):
    """3D VAE Encoder matching Wan2.1 architecture.

    Mirror of Decoder3d with downsampling instead of upsampling.
    Uses flat lists to match original PyTorch nn.Sequential weight key hierarchy.
    """

    def __init__(
        self,
        dim: int = 96,
        z_dim: int = 16,
        dim_mult: list | None = None,
        num_res_blocks: int = 2,
        temporal_downsample: list | None = None,
    ):
        """Build the stem conv, downsample stack, middle blocks and head.

        Args:
            dim: Base channel width.
            z_dim: Output channel count of the head conv.
            dim_mult: Per-stage channel multipliers; defaults to [1, 2, 4, 4].
            num_res_blocks: Residual blocks per stage.
            temporal_downsample: Per-transition flags selecting
                "downsample3d" (True) or "downsample2d" (False); defaults
                to [False, True, True].
        """
        super().__init__()
        if dim_mult is None:
            dim_mult = [1, 2, 4, 4]
        if temporal_downsample is None:
            temporal_downsample = [False, True, True]

        dims = [dim * u for u in [1] + dim_mult]

        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # Flat downsample list matching original nn.Sequential indexing
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim))
                in_dim = out_dim
            # Every stage except the last is followed by a resample layer.
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temporal_downsample[i] else "downsample2d"
                downsamples.append(Resample(out_dim, mode=mode))
        self.downsamples = downsamples

        # Middle: [ResBlock, AttentionBlock, ResBlock]
        self.middle = [
            ResidualBlock(dims[-1], dims[-1]),
            AttentionBlock(dims[-1]),
            ResidualBlock(dims[-1], dims[-1]),
        ]

        # Output head: [RMS_norm, SiLU (no params), CausalConv3d]
        self.head = [
            RMS_norm(dims[-1], images=False),
            None,  # SiLU
            CausalConv3d(dims[-1], z_dim, 3, padding=1),
        ]

    def _conv_with_cache(self, conv, x: mx.array, feat_cache, feat_idx) -> mx.array:
        """Run ``conv`` on ``x``, using and refreshing its temporal cache slot.

        With ``feat_cache`` set to ``None`` this is a plain ``conv(x)``.
        Otherwise the slot at ``feat_idx[0]`` is passed as ``cache_x``, then
        replaced with the trailing ``CACHE_T`` frames of ``x``, and the
        index is advanced.
        """
        if feat_cache is None:
            return conv(x)

        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:]
        if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
            # Chunk is shorter than the cache window: top it up with the
            # last frame of the previous cache.
            cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
        x = conv(x, cache_x=feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
        """x: [B, 3, T, H, W] -> [B, z_dim, T_lat, H_lat, W_lat]

        Args:
            x: Input video chunk.
            feat_cache: Optional list of temporal caches shared across
                chunks, mutated in place; ``None`` disables caching.
            feat_idx: Single-element list with the next cache slot index.
        """
        x = self._conv_with_cache(self.conv1, x, feat_cache, feat_idx)

        for layer in self.downsamples:
            if feat_cache is not None and isinstance(layer, (ResidualBlock, Resample)):
                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
            else:
                x = layer(x)

        for layer in self.middle:
            # Only the residual blocks take part in caching; the attention
            # block is always called without cache arguments.
            if feat_cache is not None and isinstance(layer, ResidualBlock):
                x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
            else:
                x = layer(x)

        # Head: norm -> silu -> [cache] -> conv
        x = nn.silu(self.head[0](x))
        return self._conv_with_cache(self.head[2], x, feat_cache, feat_idx)
|
||||
|
||||
|
||||
class WanVAE(nn.Module):
|
||||
"""Wan2.1 VAE wrapper with per-channel normalization.
|
||||
|
||||
Supports both encode (for I2V) and decode (for all models).
|
||||
"""
|
||||
|
||||
def __init__(self, z_dim: int = 16, encoder: bool = False):
|
||||
super().__init__()
|
||||
self.z_dim = z_dim
|
||||
self.mean = mx.array(VAE_MEAN)
|
||||
@@ -297,6 +459,65 @@ class WanVAE(nn.Module):
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim=96, z_dim=z_dim)
|
||||
|
||||
if encoder:
|
||||
self.encoder = Encoder3d(dim=96, z_dim=z_dim * 2)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
|
||||
def encode(self, x: mx.array) -> mx.array:
    """Encode a video into a normalized latent, one temporal chunk at a time.

    The first frame is encoded on its own; the remaining frames follow in
    4-frame chunks. A single feature cache list is shared by all chunks so
    later chunks see temporal context from earlier ones. Only
    ``1 + (T - 1) // 4`` chunks are processed, so frames past the last full
    4-frame chunk are not encoded.

    Args:
        x: Video [B, 3, T, H, W] in [-1, 1]

    Returns:
        Normalized latent [B, z_dim, T_lat, H_lat, W_lat]
    """
    # One cache slot per cacheable CausalConv3d in the encoder.
    feat_cache = [None] * self._count_encoder_cache_slots()

    total_frames = x.shape[2]
    latent = None
    for chunk_no in range(1 + (total_frames - 1) // 4):
        if chunk_no == 0:
            frames = x[:, :, :1]
        else:
            first = 1 + 4 * (chunk_no - 1)
            frames = x[:, :, first : first + 4]

        # The slot index restarts at 0 for every chunk.
        encoded = self.encoder(frames, feat_cache=feat_cache, feat_idx=[0])
        latent = (
            encoded
            if latent is None
            else mx.concatenate([latent, encoded], axis=2)
        )

    # conv1 output is split in two along channels; only the first half (mu) is kept.
    mu, _ = mx.split(self.conv1(latent), 2, axis=1)

    # Normalize: (mu - mean) * inv_std, broadcast per channel.
    channel_shape = (1, -1, 1, 1, 1)
    return (mu - self.mean.reshape(channel_shape)) * self.inv_std.reshape(channel_shape)
|
||||
|
||||
def _count_encoder_cache_slots(self) -> int:
|
||||
"""Count CausalConv3d that participate in chunked encoding cache."""
|
||||
count = 1 # encoder.conv1
|
||||
for layer in self.encoder.downsamples:
|
||||
if isinstance(layer, ResidualBlock):
|
||||
count += 2 # two convs in residual path
|
||||
elif isinstance(layer, Resample) and layer.mode == "downsample3d":
|
||||
count += 1 # time_conv
|
||||
for layer in self.encoder.middle:
|
||||
if isinstance(layer, ResidualBlock):
|
||||
count += 2
|
||||
count += 1 # encoder.head CausalConv3d
|
||||
return count
|
||||
|
||||
def decode(self, z: mx.array) -> mx.array:
|
||||
"""Decode latent to video.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user