feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -6,7 +6,7 @@ import numpy as np
|
||||
|
||||
from .attention import WanLayerNorm
|
||||
from .config import WanModelConfig
|
||||
from .rope import rope_params
|
||||
from .rope import rope_params, rope_precompute_cos_sin
|
||||
from .transformer import WanAttentionBlock
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class Head(nn.Module):
|
||||
proj_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, proj_dim)
|
||||
self.modulation = mx.random.normal((1, 2, dim)) * (dim**-0.5)
|
||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -48,14 +48,13 @@ class Head(nn.Module):
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
e_f32 = e.astype(mx.float32)
|
||||
# modulation [1, 2, dim] broadcasts with e [B, 1/L, dim] via unsqueeze
|
||||
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
|
||||
# modulation already float32; e already float32 from model forward
|
||||
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # [B, L_e, 2, dim]
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x).astype(mx.float32)
|
||||
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1
|
||||
return self.head(x_mod.astype(x.dtype))
|
||||
x_norm = self.norm(x)
|
||||
x_mod = x_norm * (1 + e1) + e0 # type promotion handles bf16→f32
|
||||
return self.head(x_mod.astype(self.head.weight.dtype))
|
||||
|
||||
|
||||
class WanModel(nn.Module):
|
||||
@@ -109,17 +108,16 @@ class WanModel(nn.Module):
|
||||
# Output head
|
||||
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
|
||||
|
||||
# Precompute RoPE frequencies
|
||||
d = dim // config.num_heads
|
||||
d_t = d - 4 * (d // 6)
|
||||
d_h = 2 * (d // 6)
|
||||
d_w = 2 * (d // 6)
|
||||
# Each rope_params returns [1024, d_x//2, 2]
|
||||
freqs_t = rope_params(1024, d_t)
|
||||
freqs_h = rope_params(1024, d_h)
|
||||
freqs_w = rope_params(1024, d_w)
|
||||
# Concatenate along the frequency dimension: [1024, d//2, 2]
|
||||
self.freqs = mx.concatenate([freqs_t, freqs_h, freqs_w], axis=1)
|
||||
# Precompute RoPE frequencies — single table, split by rope_apply
|
||||
# Reference computes one rope_params(head_dim) and splits into t/h/w.
|
||||
self.freqs = rope_params(1024, dim // config.num_heads)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding
|
||||
half = config.freq_dim // 2
|
||||
self._inv_freq = mx.power(
|
||||
10000.0, -mx.arange(half).astype(mx.float32) / half
|
||||
)
|
||||
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
"""Convert video tensor to patch embeddings.
|
||||
@@ -215,6 +213,21 @@ class WanModel(nn.Module):
|
||||
kv_caches.append(block.cross_attn.prepare_kv(context))
|
||||
return kv_caches
|
||||
|
||||
def prepare_rope(self, grid_sizes: list) -> tuple:
    """Build the RoPE cos/sin tables once for a fixed set of grid sizes.

    Meant to be called a single time ahead of the diffusion loop, when the
    grid sizes stay the same from step to step, so the per-step
    broadcast/concat work does not have to be repeated.

    Args:
        grid_sizes: List of (F, H, W) tuples, one per batch element.

    Returns:
        The (cos_f, sin_f) pair returned by ``rope_precompute_cos_sin``,
        requested in the dtype of the patch-embedding projection weights.
    """
    # Forward the model's frequency table and the weight dtype directly;
    # the helper does all of the actual precomputation.
    return rope_precompute_cos_sin(
        grid_sizes,
        self.freqs,
        dtype=self.patch_embedding_proj.weight.dtype,
    )
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x_list: list,
|
||||
@@ -222,6 +235,8 @@ class WanModel(nn.Module):
|
||||
context: list | mx.array,
|
||||
seq_len: int,
|
||||
cross_kv_caches: list | None = None,
|
||||
y: list | None = None,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
) -> list:
|
||||
"""Forward pass.
|
||||
|
||||
@@ -233,42 +248,70 @@ class WanModel(nn.Module):
|
||||
seq_len: Maximum sequence length for padding
|
||||
cross_kv_caches: Optional list of (k, v) tuples from
|
||||
prepare_cross_kv(), one per block.
|
||||
y: Optional list of conditioning tensors for I2V [C_y, F, H, W].
|
||||
Channel-concatenated with x before patchify.
|
||||
rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope().
|
||||
|
||||
Returns:
|
||||
List of denoised tensors [C, F, H, W]
|
||||
"""
|
||||
# Patchify each video
|
||||
patches = []
|
||||
grid_sizes = []
|
||||
seq_lens_list = []
|
||||
for vid in x_list:
|
||||
p, gs = self._patchify(vid) # [1, L, dim]
|
||||
patches.append(p)
|
||||
grid_sizes.append(gs)
|
||||
seq_lens_list.append(p.shape[1])
|
||||
# Detect identical inputs (CFG B=2) to avoid duplicate patchify work.
|
||||
# Check BEFORE I2V concat since concat creates new array objects.
|
||||
batch_size = len(x_list)
|
||||
all_same = batch_size > 1 and all(
|
||||
x_list[i] is x_list[0] for i in range(1, batch_size)
|
||||
)
|
||||
if all_same and y is not None:
|
||||
all_same = all(y[i] is y[0] for i in range(1, len(y)))
|
||||
|
||||
# Pad and batch
|
||||
batch_size = len(patches)
|
||||
x = mx.concatenate(
|
||||
[
|
||||
mx.concatenate(
|
||||
# I2V: channel-concatenate conditioning y with noise x
|
||||
if y is not None:
|
||||
x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)]
|
||||
|
||||
if all_same:
|
||||
# Patchify once and broadcast — saves a Linear projection per step
|
||||
p, gs = self._patchify(x_list[0]) # [1, L, dim]
|
||||
grid_sizes = [gs] * batch_size
|
||||
seq_lens_list = [p.shape[1]] * batch_size
|
||||
# Pad and broadcast
|
||||
if p.shape[1] < seq_len:
|
||||
p = mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
x = mx.broadcast_to(p, (batch_size,) + p.shape[1:])
|
||||
else:
|
||||
patches = []
|
||||
grid_sizes = []
|
||||
seq_lens_list = []
|
||||
for vid in x_list:
|
||||
p, gs = self._patchify(vid) # [1, L, dim]
|
||||
patches.append(p)
|
||||
grid_sizes.append(gs)
|
||||
seq_lens_list.append(p.shape[1])
|
||||
x = mx.concatenate(
|
||||
[
|
||||
mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding
|
||||
# Time embedding (use cached inv_freq to avoid recomputing each step)
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
|
||||
pos = t.astype(mx.float32)
|
||||
sinusoid = pos[..., None] * self._inv_freq
|
||||
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
@@ -278,7 +321,6 @@ class WanModel(nn.Module):
|
||||
e = e.astype(mx.float32)
|
||||
else:
|
||||
# I2V: per-token timesteps [B, L]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, L, dim]
|
||||
@@ -298,7 +340,15 @@ class WanModel(nn.Module):
|
||||
else:
|
||||
context_batch = self.embed_text(context)
|
||||
|
||||
# Run transformer blocks
|
||||
# Pre-compute attention mask from seq_lens (constant across all blocks)
|
||||
attn_mask = None
|
||||
w_dtype = self.patch_embedding_proj.weight.dtype
|
||||
if any(sl < seq_len for sl in seq_lens_list):
|
||||
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
|
||||
for i, sl in enumerate(seq_lens_list):
|
||||
attn_mask[i, :, :, sl:] = -1e9
|
||||
|
||||
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens_list,
|
||||
@@ -306,8 +356,11 @@ class WanModel(nn.Module):
|
||||
freqs=self.freqs,
|
||||
context=context_batch,
|
||||
context_lens=None,
|
||||
rope_cos_sin=rope_cos_sin,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
# Run transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
|
||||
x = block(x, cross_kv_cache=kv, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user