202 lines
5.9 KiB
Python
202 lines
5.9 KiB
Python
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from .rope import rope_apply
|
|
|
|
|
|
class WanRMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned per-channel scale.

    Args:
        dim: Size of the last (normalized) axis; also the length of ``weight``.
        eps: Small constant added for numerical stability inside the RMS.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learned scale, initialised to one so a fresh layer is a plain RMS norm.
        self.weight = mx.ones((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        """Normalize ``x`` along its last axis and multiply by ``weight``."""
        scale, epsilon = self.weight, self.eps
        return mx.fast.rms_norm(x, scale, epsilon)
|
|
|
|
|
|
class WanLayerNorm(nn.Module):
    """Layer normalization over the last axis, optionally with a learned scale and shift.

    The normalization itself is delegated to ``mx.fast.layer_norm``. The
    original description says it is computed in float32; presumably the fused
    kernel accumulates in float32 for low-precision inputs — confirm against
    the MLX documentation.

    Args:
        dim: Size of the last (normalized) axis.
        eps: Stability constant passed to ``mx.fast.layer_norm``.
        elementwise_affine: When True, ``weight`` (ones) and ``bias`` (zeros)
            are created and applied after normalization. When False, no
            parameters are created and the plain normalized input is returned.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not elementwise_affine:
            return
        self.weight = mx.ones((dim,))
        self.bias = mx.zeros((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``x`` normalized along its last axis (with affine if enabled)."""
        gamma, beta = (self.weight, self.bias) if self.elementwise_affine else (None, None)
        return mx.fast.layer_norm(x, gamma, beta, self.eps)
|
|
|
|
|
|
class WanSelfAttention(nn.Module):
    """Multi-head self-attention with optional QK normalization and RoPE.

    The projected queries and keys are optionally RMS-normalized across the
    full model width, before being split into heads. They are then rotated
    with the factorized RoPE implemented by ``rope_apply``. Key positions past
    each sample's valid length are masked out.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (``head_dim = dim // num_heads``).
        window_size: Stored as ``self.window_size`` but not read by
            ``__call__``; attention here always spans the full sequence.
        qk_norm: If True, apply ``WanRMSNorm`` to the q and k projections.
        eps: Epsilon for the q/k RMS norms.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: tuple = (-1, -1),
        qk_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None

    def __call__(
        self,
        x: mx.array,
        seq_lens: list,
        grid_sizes: list,
        freqs: mx.array,
    ) -> mx.array:
        """Attend over ``x`` and project back to model width.

        Args:
            x: Hidden states of shape [B, S, dim].
            seq_lens: Per-sample valid lengths, one entry per batch element.
                Key positions ``>= seq_lens[i]`` are masked for sample ``i``.
            grid_sizes: Passed through unchanged to ``rope_apply``.
            freqs: Passed through unchanged to ``rope_apply``.

        Returns:
            Array of shape [B, S, dim].
        """
        b, s, _ = x.shape
        n, d = self.num_heads, self.head_dim

        q = self.q(x)
        k = self.k(x)
        if self.norm_q is not None:
            q = self.norm_q(q)
        if self.norm_k is not None:
            k = self.norm_k(k)

        q = q.reshape(b, s, n, d)
        k = k.reshape(b, s, n, d)
        v = self.v(x).reshape(b, s, n, d)

        # RoPE is applied in [B, S, N, D] layout, before the head transpose.
        q = rope_apply(q, grid_sizes, freqs)
        k = rope_apply(k, grid_sizes, freqs)

        # mx.fast.scaled_dot_product_attention expects [B, N, S, D].
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        mask = None
        if any(sl < s for sl in seq_lens):
            # Additive key-padding mask [B, 1, 1, S]: 0 for valid keys, -1e9 for padding.
            # Built as one vectorized op instead of B in-place slice writes.
            valid = mx.arange(s)[None, :] < mx.array(seq_lens)[:, None]
            mask = mx.where(valid, 0.0, -1e9).astype(q.dtype)[:, None, None, :]

        # mask=None is the API default, so a single call covers both cases.
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)

        out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
        return self.o(out)
|
|
|
|
|
|
class WanCrossAttention(nn.Module):
    """Cross-attention: queries from hidden states, keys/values from text context.

    Args:
        dim: Model width (shared by hidden states and context); must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads (``head_dim = dim // num_heads``).
        qk_norm: If True, apply ``WanRMSNorm`` to the q and k projections.
        eps: Epsilon for the q/k RMS norms.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        qk_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None

    def prepare_kv(self, context: mx.array) -> tuple:
        """Pre-compute K and V projections for caching.

        Args:
            context: [B, L_ctx, dim]

        Returns:
            (k, v) each [B, N, L_ctx, D] ready for attention
        """
        b = context.shape[0]
        n, d = self.num_heads, self.head_dim
        k = self.k(context)
        if self.norm_k is not None:
            k = self.norm_k(k)
        k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
        v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
        return k, v

    def __call__(
        self,
        x: mx.array,
        context: mx.array,
        context_lens: list | None = None,
        kv_cache: tuple | None = None,
    ) -> mx.array:
        """Attend from ``x`` to ``context`` and project back to model width.

        Args:
            x: Hidden states [B, L_q, dim].
            context: Context [B, L_ctx, dim]. It is ignored when ``kv_cache``
                is given.
            context_lens: Optional per-sample valid context lengths. Key
                positions ``>= context_lens[i]`` are masked for sample ``i``.
            kv_cache: Optional ``(k, v)`` pair as returned by ``prepare_kv``,
                each [B, N, L_ctx, D]. It lets repeated calls skip the K/V
                projections.

        Returns:
            Array of shape [B, L_q, dim].
        """
        b = x.shape[0]
        n, d = self.num_heads, self.head_dim

        q = self.q(x)
        if self.norm_q is not None:
            q = self.norm_q(q)
        q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)

        # Reuse prepare_kv so the cached and uncached paths share one K/V implementation.
        k, v = kv_cache if kv_cache is not None else self.prepare_kv(context)

        mask = None
        ctx_len = k.shape[2]
        if context_lens is not None and any(cl < ctx_len for cl in context_lens):
            # Additive key-padding mask [B, 1, 1, L_ctx]: 0 for valid keys, -1e9 for padding.
            # An all-zero mask would be a no-op, so it is skipped when nothing is padded.
            valid = mx.arange(ctx_len)[None, :] < mx.array(context_lens)[:, None]
            mask = mx.where(valid, 0.0, -1e9).astype(q.dtype)[:, None, None, :]

        # mask=None is the API default, so a single call covers both cases.
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)

        out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
        return self.o(out)
|