feat(wan): Add Wan2.2 I2V support
This commit is contained in:
@@ -49,9 +49,9 @@ class WanAttentionBlock(nn.Module):
|
||||
context_lens: list | None = None,
|
||||
cross_kv_cache: tuple | None = None,
|
||||
) -> mx.array:
|
||||
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
|
||||
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
|
||||
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
|
||||
# Modulation in float32 (matching official torch.amp.autocast float32)
|
||||
e_f32 = e.astype(mx.float32)
|
||||
mod = self.modulation.astype(mx.float32) + e_f32
|
||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||
@@ -59,19 +59,19 @@ class WanAttentionBlock(nn.Module):
|
||||
e4 = mod[:, :, 4, :] # scale for ffn
|
||||
e5 = mod[:, :, 5, :] # gate for ffn
|
||||
|
||||
# Self-attention with modulation
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
# Self-attention with modulation (norm output in float32)
|
||||
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
|
||||
x = x + y * e2
|
||||
x = x.astype(mx.float32) + y.astype(mx.float32) * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||
|
||||
# FFN with modulation
|
||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
||||
# FFN with modulation (norm output in float32)
|
||||
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3
|
||||
y = self.ffn(x_mod)
|
||||
x = x + y * e5
|
||||
x = x + y.astype(mx.float32) * e5
|
||||
|
||||
return x
|
||||
|
||||
@@ -86,4 +86,6 @@ class WanFFN(nn.Module):
|
||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
x_w = x.astype(self.fc1.weight.dtype)
|
||||
return self.fc2(self.act(self.fc1(x_w)))
|
||||
|
||||
Reference in New Issue
Block a user