feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -67,6 +67,8 @@ class WanSelfAttention(nn.Module):
         seq_lens: list,
         grid_sizes: list,
         freqs: mx.array,
+        rope_cos_sin: tuple | None = None,
+        attn_mask: mx.array | None = None,
     ) -> mx.array:
         b, s, _ = x.shape
         n, d = self.num_heads, self.head_dim
@@ -87,19 +89,18 @@ class WanSelfAttention(nn.Module):
         v = self.v(x_w).reshape(b, s, n, d)
 
         # RoPE in float32 for precision (official uses float64)
-        q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
-        k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
+        q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
+        k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
 
         # Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
         q = q.astype(w_dtype).transpose(0, 2, 1, 3)
         k = k.astype(w_dtype).transpose(0, 2, 1, 3)
         v = v.transpose(0, 2, 1, 3)
 
-        # Build attention mask from seq_lens
-        max_len = s
-        mask = None
-        if any(sl < max_len for sl in seq_lens):
-            mask = mx.zeros((b, 1, 1, max_len), dtype=q.dtype)
+        # Use precomputed mask or build from seq_lens
+        mask = attn_mask
+        if mask is None and any(sl < s for sl in seq_lens):
+            mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
             for i, sl in enumerate(seq_lens):
                 mask[i, :, :, sl:] = -1e9
 
Reference in New Issue
Block a user