feat(wan): Add Wan2.2 I2V support
This commit is contained in:
@@ -15,17 +15,17 @@ def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
|
||||
|
||||
Args:
|
||||
dim: Embedding dimension (must be even).
|
||||
position: 1D tensor of positions.
|
||||
position: Tensor of positions — 1D [L] or 2D [B, L].
|
||||
|
||||
Returns:
|
||||
Embeddings of shape [len(position), dim].
|
||||
Embeddings of shape [L, dim] or [B, L, dim].
|
||||
"""
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
pos = position.astype(mx.float32)
|
||||
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
|
||||
sinusoid = pos[:, None] * inv_freq[None, :]
|
||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
|
||||
sinusoid = pos[..., None] * inv_freq # [..., half]
|
||||
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
|
||||
class Head(nn.Module):
|
||||
@@ -44,16 +44,17 @@ class Head(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
x: [B, L, dim]
|
||||
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
|
||||
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
e_f32 = e.astype(mx.float32)
|
||||
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
|
||||
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
|
||||
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
|
||||
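# Insert singleton axes so the shapes line up: modulation [1, 2, dim] -> [1, 1, 2, dim], e [B, L_e, dim] -> [B, L_e, 1, dim] (L_e is 1 or L); the sum broadcasts to [B, L_e, 2, dim]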
|
||||
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x).astype(mx.float32)
|
||||
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
|
||||
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1
|
||||
return self.head(x_mod.astype(x.dtype))
|
||||
|
||||
|
||||
@@ -261,18 +262,30 @@ class WanModel(nn.Module):
|
||||
axis=0,
|
||||
) # [B, seq_len, dim]
|
||||
|
||||
# Time embedding: compute once per sample, then broadcast to all tokens
|
||||
# Time embedding
|
||||
if t.ndim == 0:
|
||||
t = t[None]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
||||
|
||||
model_dtype = self.patch_embedding_proj.weight.dtype
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
|
||||
e = e.astype(model_dtype)
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
|
||||
# Keep e and e0 in float32 — official asserts float32 for modulation
|
||||
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
|
||||
e = e.astype(mx.float32)
|
||||
else:
|
||||
# I2V: per-token timesteps [B, L]
|
||||
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
|
||||
e = self.time_embedding_1(
|
||||
self.time_embedding_act(self.time_embedding_0(sin_emb))
|
||||
) # [B, L, dim]
|
||||
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
|
||||
# Keep e and e0 in float32 — official asserts float32 for modulation
|
||||
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
|
||||
e = e.astype(mx.float32)
|
||||
|
||||
# Text embedding: skip MLP if context is already embedded (mx.array)
|
||||
if isinstance(context, mx.array):
|
||||
|
||||
Reference in New Issue
Block a user