feat(wan): Add Wan2.2 I2V support

This commit is contained in:
Daniel
2026-02-27 13:46:23 +01:00
parent 93da550f65
commit 2bb95c61ed
26 changed files with 4401 additions and 2968 deletions

View File

@@ -49,9 +49,9 @@ class WanAttentionBlock(nn.Module):
context_lens: list | None = None,
cross_kv_cache: tuple | None = None,
) -> mx.array:
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
# Modulation in float32 (matching official torch.amp.autocast float32)
e_f32 = e.astype(mx.float32)
mod = self.modulation.astype(mx.float32) + e_f32
e0 = mod[:, :, 0, :] # shift for self-attn
e1 = mod[:, :, 1, :] # scale for self-attn
e2 = mod[:, :, 2, :] # gate for self-attn
@@ -59,19 +59,19 @@ class WanAttentionBlock(nn.Module):
e4 = mod[:, :, 4, :] # scale for ffn
e5 = mod[:, :, 5, :] # gate for ffn
# Self-attention with modulation
x_mod = self.norm1(x) * (1 + e1) + e0
# Self-attention with modulation (norm output in float32)
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
x = x + y * e2
x = x.astype(mx.float32) + y.astype(mx.float32) * e2
# Cross-attention (no modulation, just norm)
x_cross = self.norm3(x) if self.norm3 is not None else x
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
# FFN with modulation
x_mod = self.norm2(x) * (1 + e4) + e3
# FFN with modulation (norm output in float32)
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3
y = self.ffn(x_mod)
x = x + y * e5
x = x + y.astype(mx.float32) * e5
return x
@@ -86,4 +86,6 @@ class WanFFN(nn.Module):
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(self.fc1.weight.dtype)
return self.fc2(self.act(self.fc1(x_w)))