feat(wan): Add I2V-14B dual-model support

This commit is contained in:
Daniel
2026-02-27 23:43:42 +01:00
parent 2bb95c61ed
commit f4195f0118
14 changed files with 1332 additions and 152 deletions

View File

@@ -67,6 +67,8 @@ class WanSelfAttention(nn.Module):
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array:
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
@@ -87,19 +89,18 @@ class WanSelfAttention(nn.Module):
v = self.v(x_w).reshape(b, s, n, d)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Build attention mask from seq_lens
max_len = s
mask = None
if any(sl < max_len for sl in seq_lens):
mask = mx.zeros((b, 1, 1, max_len), dtype=q.dtype)
# Use precomputed mask or build from seq_lens
mask = attn_mask
if mask is None and any(sl < s for sl in seq_lens):
mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
for i, sl in enumerate(seq_lens):
mask[i, :, :, sl:] = -1e9