Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -0,0 +1,221 @@
import mlx.core as mx
import mlx.nn as nn
from .rope import rope_apply
def _linear_dtype(layer) -> mx.Dtype:
"""Get the compute dtype of a linear layer, handling QuantizedLinear and LoRA wrappers."""
# Unwrap LoRA wrapper to get the underlying linear layer
inner = getattr(layer, "linear", layer)
if isinstance(inner, nn.QuantizedLinear):
return inner.scales.dtype
return inner.weight.dtype
class WanRMSNorm(nn.Module):
"""RMS normalization with learnable scale."""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = mx.ones((dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.fast.rms_norm(x, self.weight, self.eps)
class WanLayerNorm(nn.Module):
"""LayerNorm computed in float32, with optional affine."""
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if elementwise_affine:
self.weight = mx.ones((dim,))
self.bias = mx.zeros((dim,))
def __call__(self, x: mx.array) -> mx.array:
if self.elementwise_affine:
return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
else:
return mx.fast.layer_norm(x, None, None, self.eps)
class WanSelfAttention(nn.Module):
"""Self-attention with QK normalization and 3-way factorized RoPE."""
def __init__(
self,
dim: int,
num_heads: int,
window_size: tuple = (-1, -1),
qk_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
assert dim % num_heads == 0
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.window_size = window_size
self.scale = self.head_dim**-0.5
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
def __call__(
self,
x: mx.array,
seq_lens: list,
grid_sizes: list,
freqs: mx.array,
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array:
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = _linear_dtype(self.q)
x_w = x.astype(w_dtype)
q = self.q(x_w)
k = self.k(x_w)
if self.norm_q is not None:
q = self.norm_q(q)
if self.norm_k is not None:
k = self.norm_k(k)
q = q.reshape(b, s, n, d)
k = k.reshape(b, s, n, d)
v = self.v(x_w).reshape(b, s, n, d)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
k = rope_apply(
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Use precomputed mask or build from seq_lens
mask = attn_mask
if mask is None and any(sl < s for sl in seq_lens):
mask = mx.zeros((b, 1, 1, s), dtype=q.dtype)
for i, sl in enumerate(seq_lens):
mask[i, :, :, sl:] = -1e9
# Use memory-efficient scaled dot-product attention
# mx.fast.scaled_dot_product_attention expects [B, N, L, D]
if mask is not None:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
return self.o(out)
class WanCrossAttention(nn.Module):
"""Cross-attention: Q from hidden states, K/V from text context."""
def __init__(
self,
dim: int,
num_heads: int,
qk_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None
def prepare_kv(self, context: mx.array) -> tuple:
"""Pre-compute K and V projections for caching.
Args:
context: [B, L_ctx, dim]
Returns:
(k, v) each [B, N, L_ctx, D] ready for attention
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
# Cast to compute dtype for efficient matmul
w_dtype = _linear_dtype(self.k)
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
return k, v
def __call__(
self,
x: mx.array,
context: mx.array,
context_lens: list | None = None,
kv_cache: tuple | None = None,
) -> mx.array:
b = x.shape[0]
n, d = self.num_heads, self.head_dim
# Cast to compute dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = _linear_dtype(self.q)
q = self.q(x.astype(w_dtype))
if self.norm_q is not None:
q = self.norm_q(q)
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
if kv_cache is not None:
k, v = kv_cache
else:
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
# Optional context masking
mask = None
if context_lens is not None:
ctx_len = k.shape[2]
mask = mx.zeros((b, 1, 1, ctx_len), dtype=q.dtype)
for i, cl in enumerate(context_lens):
mask[i, :, :, cl:] = -1e9
if mask is not None:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
return self.o(out)