feat(wan): Add Wan2.2 I2V support
This commit is contained in:
@@ -71,8 +71,12 @@ class WanSelfAttention(nn.Module):
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = self.q.weight.dtype
|
||||
x_w = x.astype(w_dtype)
|
||||
|
||||
q = self.q(x_w)
|
||||
k = self.k(x_w)
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
if self.norm_k is not None:
|
||||
@@ -80,15 +84,15 @@ class WanSelfAttention(nn.Module):
|
||||
|
||||
q = q.reshape(b, s, n, d)
|
||||
k = k.reshape(b, s, n, d)
|
||||
v = self.v(x).reshape(b, s, n, d)
|
||||
v = self.v(x_w).reshape(b, s, n, d)
|
||||
|
||||
# Apply RoPE
|
||||
q = rope_apply(q, grid_sizes, freqs)
|
||||
k = rope_apply(k, grid_sizes, freqs)
|
||||
# RoPE in float32 for precision (official uses float64)
|
||||
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
|
||||
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
|
||||
|
||||
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
|
||||
q = q.transpose(0, 2, 1, 3)
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Build attention mask from seq_lens
|
||||
@@ -149,11 +153,14 @@ class WanCrossAttention(nn.Module):
|
||||
"""
|
||||
b = context.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
k = self.k(context)
|
||||
# Cast to weight dtype for efficient matmul
|
||||
w_dtype = self.k.weight.dtype
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
return k, v
|
||||
|
||||
def __call__(
|
||||
@@ -166,7 +173,9 @@ class WanCrossAttention(nn.Module):
|
||||
b = x.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(x)
|
||||
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = self.q.weight.dtype
|
||||
q = self.q(x.astype(w_dtype))
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
@@ -174,11 +183,12 @@ class WanCrossAttention(nn.Module):
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache
|
||||
else:
|
||||
k = self.k(context)
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
# Optional context masking
|
||||
mask = None
|
||||
|
||||
Reference in New Issue
Block a user