import mlx.core as mx
import mlx.nn as nn

from .rope import rope_apply

# Additive score written at padded key positions so softmax assigns them
# (almost) no weight.
# NOTE(review): -1e9 lies outside float16's finite range, so for float16
# inputs the cast mask presumably holds -inf instead; confirm that every
# sample keeps at least one valid key in that case.
_MASKED_SCORE = -1e9


def _key_padding_mask(lengths, max_len: int, dtype) -> "mx.array | None":
    """Build an additive attention mask that hides padded key positions.

    Args:
        lengths: Valid key count per batch element (one non-negative int per
            sample).
        max_len: Padded key length, i.e. the size of the mask's last axis.
        dtype: dtype of the returned mask (callers pass the query dtype).

    Returns:
        ``None`` when no entry of ``lengths`` is shorter than ``max_len``
        (nothing to mask). Otherwise an array of shape
        ``[len(lengths), 1, 1, max_len]`` holding ``0`` at valid positions
        and ``-1e9`` at positions ``>= lengths[i]``; the singleton axes
        broadcast over heads and query positions.
    """
    if not any(n < max_len for n in lengths):
        return None
    positions = mx.arange(max_len)[None, :]      # [1, max_len]
    valid = mx.array(lengths)[:, None]           # [B, 1]
    additive = mx.where(positions >= valid, _MASKED_SCORE, 0.0).astype(dtype)
    return additive[:, None, None, :]


class WanRMSNorm(nn.Module):
    """RMS normalization with learnable scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the norm with a scale vector of ``dim`` ones."""
        super().__init__()
        self.eps = eps
        self.weight = mx.ones((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``mx.fast.rms_norm(x, weight, eps)``."""
        return mx.fast.rms_norm(x, self.weight, self.eps)


class WanLayerNorm(nn.Module):
    """LayerNorm with optional affine, delegated to ``mx.fast.layer_norm``.

    NOTE(review): no explicit float32 cast happens in this class; any
    higher-precision accumulation is whatever ``mx.fast.layer_norm`` does
    internally - confirm against the MLX documentation.
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
        """Create the norm; ``weight``/``bias`` exist only when affine is on."""
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = mx.ones((dim,))
            self.bias = mx.zeros((dim,))

    def __call__(self, x: mx.array) -> mx.array:
        """Normalize ``x``, applying weight and bias only in affine mode."""
        if self.elementwise_affine:
            return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
        return mx.fast.layer_norm(x, None, None, self.eps)


class WanSelfAttention(nn.Module):
    """Self-attention with QK normalization and 3-way factorized RoPE."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: tuple = (-1, -1),
        qk_norm: bool = True,
        eps: float = 1e-6,
    ):
        """Set up the Q/K/V/output projections and optional Q/K norms.

        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            window_size: Stored on the module; not read by the code in
                this class.
            qk_norm: When true, RMS-normalize Q and K over the full ``dim``
                (before the split into heads).
            eps: Epsilon for the Q/K norms.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None

    def __call__(
        self,
        x: mx.array,
        seq_lens: list,
        grid_sizes: list,
        freqs: mx.array,
    ) -> mx.array:
        """Attend over ``x`` with RoPE and key-padding masking.

        Args:
            x: Rank-3 input ``[B, L, dim]``.
            seq_lens: Valid length of each sample; keys at positions
                ``>= seq_lens[i]`` are masked out. Expected to hold one
                entry per batch element.
            grid_sizes: Forwarded unchanged to ``rope_apply``.
            freqs: Forwarded unchanged to ``rope_apply``.

        Returns:
            Output of the ``o`` projection, ``[B, L, dim]``.
        """
        b, s, _ = x.shape
        n, d = self.num_heads, self.head_dim

        q = self.q(x)
        k = self.k(x)
        # QK norm runs on the full projection, before splitting into heads.
        if self.norm_q is not None:
            q = self.norm_q(q)
        if self.norm_k is not None:
            k = self.norm_k(k)

        q = q.reshape(b, s, n, d)
        k = k.reshape(b, s, n, d)
        v = self.v(x).reshape(b, s, n, d)

        # Apply RoPE
        q = rope_apply(q, grid_sizes, freqs)
        k = rope_apply(k, grid_sizes, freqs)

        # Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        # Mask is None when no sample is padded, so the unmasked path is used.
        mask = _key_padding_mask(seq_lens, s, q.dtype)

        # mx.fast.scaled_dot_product_attention expects [B, N, L, D];
        # its mask argument is optional, so None can be passed straight in.
        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=mask
        )

        out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
        return self.o(out)


class WanCrossAttention(nn.Module):
    """Cross-attention: Q from hidden states, K/V from text context."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        qk_norm: bool = True,
        eps: float = 1e-6,
    ):
        """Set up the Q/K/V/output projections and optional Q/K norms.

        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            qk_norm: When true, RMS-normalize Q and K over the full ``dim``.
            eps: Epsilon for the Q/K norms.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else None
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else None

    def prepare_kv(self, context: mx.array) -> tuple:
        """Pre-compute K and V projections for caching.

        Args:
            context: [B, L_ctx, dim]

        Returns:
            (k, v) each [B, N, L_ctx, D] ready for attention
        """
        b = context.shape[0]
        n, d = self.num_heads, self.head_dim

        k = self.k(context)
        if self.norm_k is not None:
            k = self.norm_k(k)
        k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
        v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
        return k, v

    def __call__(
        self,
        x: mx.array,
        context: mx.array,
        context_lens: list | None = None,
        kv_cache: tuple | None = None,
    ) -> mx.array:
        """Attend from ``x`` to the text context.

        Args:
            x: Hidden states ``[B, L, dim]`` used for the queries.
            context: Context ``[B, L_ctx, dim]``; only read when
                ``kv_cache`` is ``None``.
            context_lens: Optional valid context length per sample; keys at
                positions ``>= context_lens[i]`` are masked out.
            kv_cache: Optional ``(k, v)`` pair as returned by
                ``prepare_kv``; when given, the K/V projections are skipped.

        Returns:
            Output of the ``o`` projection, ``[B, L, dim]``.
        """
        b = x.shape[0]
        n, d = self.num_heads, self.head_dim

        q = self.q(x)
        if self.norm_q is not None:
            q = self.norm_q(q)
        q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)

        # Reuse cached projections when available; otherwise compute them
        # through the same code path that produces the cache.
        if kv_cache is not None:
            k, v = kv_cache
        else:
            k, v = self.prepare_kv(context)

        # Optional context masking (skipped when nothing is padded).
        mask = None
        if context_lens is not None:
            mask = _key_padding_mask(context_lens, k.shape[2], q.dtype)

        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=mask
        )

        out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
        return self.o(out)