feat(wan): Add Wan2.1/2.2 T2V with quantization support
This commit is contained in:
89
mlx_video/models/wan/transformer.py
Normal file
89
mlx_video/models/wan/transformer.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import WanCrossAttention, WanLayerNorm, WanSelfAttention
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
"""Wan transformer block with learned modulation, self-attn, cross-attn, and FFN."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
window_size: tuple = (-1, -1),
|
||||
qk_norm: bool = True,
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Self-attention
|
||||
self.norm1 = WanLayerNorm(dim, eps)
|
||||
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
|
||||
|
||||
# Cross-attention (with optional norm on context)
|
||||
self.norm3 = (
|
||||
WanLayerNorm(dim, eps, elementwise_affine=True)
|
||||
if cross_attn_norm
|
||||
else None
|
||||
)
|
||||
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
|
||||
|
||||
# Feed-forward
|
||||
self.norm2 = WanLayerNorm(dim, eps)
|
||||
self.ffn = WanFFN(dim, ffn_dim)
|
||||
|
||||
# Learned modulation: 6 vectors for scale/shift/gate
|
||||
self.modulation = mx.random.normal((1, 6, dim)) * (dim**-0.5)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
e: mx.array,
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
cross_kv_cache: tuple | None = None,
|
||||
) -> mx.array:
|
||||
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
|
||||
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
|
||||
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
|
||||
e0 = mod[:, :, 0, :] # shift for self-attn
|
||||
e1 = mod[:, :, 1, :] # scale for self-attn
|
||||
e2 = mod[:, :, 2, :] # gate for self-attn
|
||||
e3 = mod[:, :, 3, :] # shift for ffn
|
||||
e4 = mod[:, :, 4, :] # scale for ffn
|
||||
e5 = mod[:, :, 5, :] # gate for ffn
|
||||
|
||||
# Self-attention with modulation
|
||||
x_mod = self.norm1(x) * (1 + e1) + e0
|
||||
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
|
||||
x = x + y * e2
|
||||
|
||||
# Cross-attention (no modulation, just norm)
|
||||
x_cross = self.norm3(x) if self.norm3 is not None else x
|
||||
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
|
||||
|
||||
# FFN with modulation
|
||||
x_mod = self.norm2(x) * (1 + e4) + e3
|
||||
y = self.ffn(x_mod)
|
||||
x = x + y * e5
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanFFN(nn.Module):
|
||||
"""Gated feed-forward network with GELU(tanh) activation."""
|
||||
|
||||
def __init__(self, dim: int, ffn_dim: int):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(dim, ffn_dim)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
self.fc2 = nn.Linear(ffn_dim, dim)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
Reference in New Issue
Block a user