feat(wan): Add Wan2.1/2.2 T2V with quantization support

This commit is contained in:
Daniel
2026-02-26 16:16:07 +01:00
parent 7a74946c57
commit e64483a66a
21 changed files with 5309 additions and 35 deletions

View File

@@ -0,0 +1,234 @@
"""T5 Text Encoder (UMT5-XXL) for Wan2.2 text conditioning."""
import math
import mlx.core as mx
import mlx.nn as nn
class T5LayerNorm(nn.Module):
"""RMS-based layer normalization (T5 style)."""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.fast.rms_norm(x, self.weight, self.eps)
class T5RelativeEmbedding(nn.Module):
"""T5-style relative position bias with bucketing."""
def __init__(
self,
num_buckets: int,
num_heads: int,
bidirectional: bool = True,
max_dist: int = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.num_heads = num_heads
self.bidirectional = bidirectional
self.max_dist = max_dist
self.embedding = nn.Embedding(num_buckets, num_heads)
def _relative_position_bucket(self, rel_pos: mx.array) -> mx.array:
if self.bidirectional:
num_buckets = self.num_buckets // 2
rel_buckets = (rel_pos > 0).astype(mx.int32) * num_buckets
rel_pos = mx.abs(rel_pos)
else:
num_buckets = self.num_buckets
rel_buckets = mx.zeros_like(rel_pos, dtype=mx.int32)
rel_pos = mx.maximum(-rel_pos, mx.zeros_like(rel_pos))
max_exact = num_buckets // 2
is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = (
max_exact
+ (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = mx.minimum(
rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
)
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
return rel_buckets
def __call__(self, lq: int, lk: int) -> mx.array:
positions_k = mx.arange(lk)[None, :] # [1, lk]
positions_q = mx.arange(lq)[:, None] # [lq, 1]
rel_pos = positions_k - positions_q # [lq, lk]
buckets = self._relative_position_bucket(rel_pos)
embeds = self.embedding(buckets) # [lq, lk, num_heads]
embeds = embeds.transpose(2, 0, 1)[None, :, :, :] # [1, N, lq, lk]
return embeds
class T5Attention(nn.Module):
"""T5-style multi-head attention (no scaling)."""
def __init__(self, dim: int, dim_attn: int, num_heads: int, dropout: float = 0.0):
super().__init__()
assert dim_attn % num_heads == 0
self.dim = dim
self.dim_attn = dim_attn
self.num_heads = num_heads
self.head_dim = dim_attn // num_heads
self.q = nn.Linear(dim, dim_attn, bias=False)
self.k = nn.Linear(dim, dim_attn, bias=False)
self.v = nn.Linear(dim, dim_attn, bias=False)
self.o = nn.Linear(dim_attn, dim, bias=False)
def __call__(
self,
x: mx.array,
context: mx.array | None = None,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
context = x if context is None else context
b, n, c = x.shape[0], self.num_heads, self.head_dim
q = self.q(x).reshape(b, -1, n, c) # [B, Lq, N, C]
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
v = self.v(context).reshape(b, -1, n, c)
# T5 does not use scaling
# attn = einsum('binc,bjnc->bnij', q, k)
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Combine position bias and attention mask for SDPA
attn_mask = None
if pos_bias is not None:
attn_mask = pos_bias.astype(q.dtype)
if mask is not None:
if mask.ndim == 2:
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
elif mask.ndim == 3:
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype)
attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask
# T5 uses no scaling (scale=1.0)
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=1.0, mask=attn_mask
)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c)
return self.o(out)
class T5FeedForward(nn.Module):
"""Gated feed-forward: gate(x) * fc1(x) -> fc2."""
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.dim = dim
self.dim_ffn = dim_ffn
self.gate_proj = nn.Linear(dim, dim_ffn, bias=False)
self.gate_act = nn.GELU(approx="precise")
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.fc1(x) * self.gate_act(self.gate_proj(x)))
class T5SelfAttentionBlock(nn.Module):
"""T5 encoder block: self-attention + FFN."""
def __init__(
self,
dim: int,
dim_attn: int,
dim_ffn: int,
num_heads: int,
num_buckets: int,
shared_pos: bool = True,
):
super().__init__()
self.shared_pos = shared_pos
self.norm1 = T5LayerNorm(dim)
self.attn = T5Attention(dim, dim_attn, num_heads)
self.norm2 = T5LayerNorm(dim)
self.ffn = T5FeedForward(dim, dim_ffn)
self.pos_embedding = (
None
if shared_pos
else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
pos_bias: mx.array | None = None,
) -> mx.array:
e = pos_bias if self.shared_pos else self.pos_embedding(x.shape[1], x.shape[1])
x = x + self.attn(self.norm1(x), mask=mask, pos_bias=e)
x = x + self.ffn(self.norm2(x))
return x
class T5Encoder(nn.Module):
"""T5 Encoder (UMT5-XXL configuration)."""
def __init__(
self,
vocab_size: int = 256384,
dim: int = 4096,
dim_attn: int = 4096,
dim_ffn: int = 10240,
num_heads: int = 64,
num_layers: int = 24,
num_buckets: int = 32,
shared_pos: bool = False,
):
super().__init__()
self.dim = dim
self.token_embedding = nn.Embedding(vocab_size, dim)
self.pos_embedding = (
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
if shared_pos
else None
)
self.blocks = [
T5SelfAttentionBlock(
dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos
)
for _ in range(num_layers)
]
self.norm = T5LayerNorm(dim)
def __call__(self, ids: mx.array, mask: mx.array | None = None) -> mx.array:
"""
Args:
ids: Token IDs [B, L]
mask: Attention mask [B, L]
Returns:
Hidden states [B, L, dim]
"""
x = self.token_embedding(ids)
e = self.pos_embedding(x.shape[1], x.shape[1]) if self.pos_embedding else None
for block in self.blocks:
x = block(x, mask=mask, pos_bias=e)
x = self.norm(x)
return x