feat(wan): Add Wan2.2 I2V support

This commit is contained in:
Daniel
2026-02-27 13:46:23 +01:00
parent 93da550f65
commit 2bb95c61ed
26 changed files with 4401 additions and 2968 deletions

View File

@@ -71,8 +71,12 @@ class WanSelfAttention(nn.Module):
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
q = self.q(x)
k = self.k(x)
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
x_w = x.astype(w_dtype)
q = self.q(x_w)
k = self.k(x_w)
if self.norm_q is not None:
q = self.norm_q(q)
if self.norm_k is not None:
@@ -80,15 +84,15 @@ class WanSelfAttention(nn.Module):
q = q.reshape(b, s, n, d)
k = k.reshape(b, s, n, d)
v = self.v(x).reshape(b, s, n, d)
v = self.v(x_w).reshape(b, s, n, d)
# Apply RoPE
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
q = q.transpose(0, 2, 1, 3)
k = k.transpose(0, 2, 1, 3)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Build attention mask from seq_lens
@@ -149,11 +153,14 @@ class WanCrossAttention(nn.Module):
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
k = self.k(context)
# Cast to weight dtype for efficient matmul
w_dtype = self.k.weight.dtype
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
return k, v
def __call__(
@@ -166,7 +173,9 @@ class WanCrossAttention(nn.Module):
b = x.shape[0]
n, d = self.num_heads, self.head_dim
q = self.q(x)
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
q = self.q(x.astype(w_dtype))
if self.norm_q is not None:
q = self.norm_q(q)
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
@@ -174,11 +183,12 @@ class WanCrossAttention(nn.Module):
if kv_cache is not None:
k, v = kv_cache
else:
k = self.k(context)
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
# Optional context masking
mask = None