feat(wan): Add Wan2.2 I2V support

This commit is contained in:
Daniel
2026-02-27 13:46:23 +01:00
parent 93da550f65
commit 2bb95c61ed
26 changed files with 4401 additions and 2968 deletions

View File

@@ -15,17 +15,17 @@ def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
Args:
dim: Embedding dimension (must be even).
position: 1D tensor of positions.
position: Tensor of positions — 1D [L] or 2D [B, L].
Returns:
Embeddings of shape [len(position), dim].
Embeddings of shape [L, dim] or [B, L, dim].
"""
assert dim % 2 == 0
half = dim // 2
pos = position.astype(mx.float32)
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
sinusoid = pos[:, None] * inv_freq[None, :]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
sinusoid = pos[..., None] * inv_freq # [..., half]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
class Head(nn.Module):
@@ -44,16 +44,17 @@ class Head(nn.Module):
"""
Args:
x: [B, L, dim]
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
e_f32 = e.astype(mx.float32)
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
# Insert singleton axes so modulation [1, 2, dim] -> [1, 1, 2, dim] and e [B, L_e, dim] -> [B, L_e, 1, dim] broadcast together (L_e is 1 or L)
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x).astype(mx.float32)
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1
return self.head(x_mod.astype(x.dtype))
@@ -261,18 +262,30 @@ class WanModel(nn.Module):
axis=0,
) # [B, seq_len, dim]
# Time embedding: compute once per sample, then broadcast to all tokens
# Time embedding
if t.ndim == 0:
t = t[None]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
model_dtype = self.patch_embedding_proj.weight.dtype
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
e = e.astype(model_dtype)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
# Keep e and e0 in float32 — the official reference implementation asserts float32 for modulation
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
e = e.astype(mx.float32)
else:
# I2V: per-token timesteps [B, L]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, L, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
# Keep e and e0 in float32 — the official reference implementation asserts float32 for modulation
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
e = e.astype(mx.float32)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):