feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -35,8 +35,8 @@ class WanAttentionBlock(nn.Module):
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = WanFFN(dim, ffn_dim)
|
||||
|
||||
# Learned modulation: 6 vectors for scale/shift/gate
|
||||
self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5)
|
||||
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
|
||||
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -48,10 +48,11 @@ class WanAttentionBlock(nn.Module):
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
cross_kv_cache: tuple | None = None,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
# Modulation in float32 (matching official torch.amp.autocast float32)
|
||||
e_f32 = e.astype(mx.float32)
|
||||
mod = self.modulation.astype(mx.float32) + e_f32
|
||||
# Modulation in float32 (e is already float32 from model forward)
|
||||
mod = self.modulation + e
|
||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||
@@ -59,19 +60,20 @@ class WanAttentionBlock(nn.Module):
|
||||
e4 = mod[:, :, 4, :] # scale for ffn
|
||||
e5 = mod[:, :, 5, :] # gate for ffn
|
||||
|
||||
# Self-attention with modulation (norm output in float32)
|
||||
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
|
||||
x = x.astype(mx.float32) + y.astype(mx.float32) * e2
|
||||
# Self-attention with modulation
|
||||
# Type promotion handles bf16→f32 automatically when multiplied with f32 modulation
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
|
||||
x = x + y * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||
|
||||
# FFN with modulation (norm output in float32)
|
||||
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3
|
||||
# FFN with modulation
|
||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
||||
y = self.ffn(x_mod)
|
||||
x = x + y.astype(mx.float32) * e5
|
||||
x = x + y * e5
|
||||
|
||||
return x
|
||||
|
||||
|
||||
Reference in New Issue
Block a user