Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
221
mlx_video/models/wan_2/attention.py
Normal file
221
mlx_video/models/wan_2/attention.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .rope import rope_apply
|
||||
|
||||
|
||||
def _linear_dtype(layer) -> mx.Dtype:
|
||||
"""Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
|
||||
# Unwrap LoRA wrapper to get the underlying linear layer
|
||||
inner = getattr(layer, "linear", layer)
|
||||
if isinstance(inner, nn.QuantizedLinear):
|
||||
return inner.scales.dtype
|
||||
return inner.weight.dtype
|
||||
|
||||
|
||||
class WanRMSNorm(nn.Module):
|
||||
"""RMS normalization with learnable scale."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.fast.rms_norm(x, self.weight, self.eps)
|
||||
|
||||
|
||||
class WanLayerNorm(nn.Module):
|
||||
"""LayerNorm computed in float32, with optional affine."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if elementwise_affine:
|
||||
self.weight = mx.ones((dim,))
|
||||
self.bias = mx.zeros((dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.elementwise_affine:
|
||||
return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
|
||||
else:
|
||||
return mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class WanSelfAttention(nn.Module):
|
||||
"""Self-attention with QK normalization and 3-way factorized RoPE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
window_size: tuple = (-1, -1),
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.window_size = window_size
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
seq_lens: list,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
b, s, _ = x.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
x_w = x.astype(w_dtype)
|
||||
|
||||
q = self.q(x_w)
|
||||
k = self.k(x_w)
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
|
||||
q = q.reshape(b, s, n, d)
|
||||
k = k.reshape(b, s, n, d)
|
||||
v = self.v(x_w).reshape(b, s, n, d)
|
||||
|
||||
# RoPE in float32 for precision (official uses float64)
|
||||
q = rope_apply(
|
||||
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
k = rope_apply(
|
||||
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
|
||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Use precomputed mask or build from seq_lens
|
||||
mask = attn_mask
|
||||
if mask is None and any(sl < s for sl in seq_lens):
|
||||
mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
|
||||
for i, sl in enumerate(seq_lens):
|
||||
mask[i, :, :, sl:] = -1e9
|
||||
|
||||
# Use memory-efficient scaled dot-product attention
|
||||
# mx.fast.scaled_dot_product_attention expects [B, N, L, D]
|
||||
if mask is not None:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class WanCrossAttention(nn.Module):
|
||||
"""Cross-attention: Q from hidden states, K/V from text context."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.q = nn.Linear(dim, dim)
|
||||
self.k = nn.Linear(dim, dim)
|
||||
self.v = nn.Linear(dim, dim)
|
||||
self.o = nn.Linear(dim, dim)
|
||||
|
||||
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
|
||||
|
||||
def prepare_kv(self, context: mx.array) -> tuple:
|
||||
"""Pre-compute K and V projections for caching.
|
||||
|
||||
Args:
|
||||
context: [B, L_ctx, dim]
|
||||
|
||||
Returns:
|
||||
(k, v) each [B, N, L_ctx, D] ready for attention
|
||||
"""
|
||||
b = context.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
# Cast to compute dtype for efficient matmul
|
||||
w_dtype = _linear_dtype(self.k)
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
return k, v
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: mx.array,
|
||||
context_lens: list | None = None,
|
||||
kv_cache: tuple | None = None,
|
||||
) -> mx.array:
|
||||
b = x.shape[0]
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
|
||||
w_dtype = _linear_dtype(self.q)
|
||||
q = self.q(x.astype(w_dtype))
|
||||
if self.norm_q is not None:
|
||||
q = self.norm_q(q)
|
||||
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache
|
||||
else:
|
||||
ctx = context.astype(w_dtype)
|
||||
k = self.k(ctx)
|
||||
if self.norm_k is not None:
|
||||
k = self.norm_k(k)
|
||||
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
|
||||
|
||||
# Optional context masking
|
||||
mask = None
|
||||
if context_lens is not None:
|
||||
ctx_len = k.shape[2]
|
||||
mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
|
||||
for i, cl in enumerate(context_lens):
|
||||
mask[i, :, :, cl:] = -1e9
|
||||
|
||||
if mask is not None:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
|
||||
return self.o(out)
|
||||
Reference in New Issue
Block a user