import math import mlx.core as mx import mlx.nn as nn import numpy as np from .attention import WanLayerNorm, _linear_dtype from .config import WanModelConfig from .rope import rope_params, rope_precompute_cos_sin from .transformer import WanAttentionBlock def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array: """Compute sinusoidal positional embeddings. Args: dim: Embedding dimension (must be even). position: Tensor of positions — 1D [L] or 2D [B, L]. Returns: Embeddings of shape [L, dim] or [B, L, dim]. """ assert dim % 2 == 0 half = dim // 2 pos = position.astype(mx.float32) inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half) sinusoid = pos[..., None] * inv_freq # [..., half] return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1) class Head(nn.Module): """Output projection head with learned modulation.""" def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6): super().__init__() self.out_dim = out_dim self.patch_size = patch_size proj_dim = math.prod(patch_size) * out_dim self.norm = WanLayerNorm(dim, eps) self.head = nn.Linear(dim, proj_dim) self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32) def __call__(self, x: mx.array, e: mx.array) -> mx.array: """ Args: x: [B, L, dim] e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token) """ if e.ndim == 2: e = e[:, None, :] # [B, 1, dim] # Compute modulation in float32 (matching reference's autocast(float32)) mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32 e0 = mod[:, :, 0, :] # [B, L_e, dim] shift e1 = mod[:, :, 1, :] # [B, L_e, dim] scale x_norm = self.norm(x) x_mod = x_norm * (1 + e1) + e0 return self.head(x_mod) class WanModel(nn.Module): """Wan2.2 diffusion backbone for text-to-video generation.""" def __init__(self, config: WanModelConfig): super().__init__() self.config = config dim = config.dim self.dim = dim self.num_heads = config.num_heads self.out_dim = config.out_dim self.patch_size = config.patch_size self.text_len = config.text_len self.freq_dim = config.freq_dim # Patch embedding: Conv3d implemented as a reshaped linear # For kernel (1,2,2) and stride (1,2,2): reshape input then linear patch_dim = config.in_dim * math.prod(config.patch_size) self.patch_embedding_proj = nn.Linear(patch_dim, dim) self._patch_size = config.patch_size # Text embedding MLP self.text_embedding_0 = nn.Linear(config.text_dim, dim) self.text_embedding_act = nn.GELU(approx="tanh") self.text_embedding_1 = nn.Linear(dim, dim) # Time embedding MLP self.time_embedding_0 = nn.Linear(config.freq_dim, dim) self.time_embedding_act = nn.SiLU() self.time_embedding_1 = nn.Linear(dim, dim) # Time projection for modulation (6x dim) self.time_projection_act = nn.SiLU() self.time_projection = nn.Linear(dim, dim * 6) # Transformer blocks self.blocks = [ WanAttentionBlock( dim=dim, ffn_dim=config.ffn_dim, num_heads=config.num_heads, window_size=config.window_size, qk_norm=config.qk_norm, cross_attn_norm=config.cross_attn_norm, eps=config.eps, ) for _ in range(config.num_layers) ] # Output head self.head = Head(dim, config.out_dim, config.patch_size, config.eps) # Precompute RoPE frequencies — three separate tables concatenated. # Reference computes three rope_params with different dim normalizations # so each axis (temporal/height/width) gets its own full frequency range. d = dim // config.num_heads self.freqs = mx.concatenate([ rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6)), ], axis=1) # Precompute sinusoidal inv_freq for time embedding. half = config.freq_dim // 2 self._inv_freq = mx.array( np.power(10000.0, -np.arange(half, dtype=np.float64) / half ).astype(np.float32) ) def _patchify(self, x: mx.array) -> tuple: """Convert video tensor to patch embeddings. Args: x: Video latent [C, F, H, W] Returns: (patches, grid_size): patches [1, L, dim], grid_size (F', H', W') """ c, f, h, w = x.shape pt, ph, pw = self._patch_size f_out = f // pt h_out = h // ph w_out = w // pw # Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw] # Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw) x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw] x = x.reshape(f_out * h_out * w_out, -1) # [L, C*pt*ph*pw] # Project and cast to model dtype to prevent float32 cascade from input latents patches = self.patch_embedding_proj(x) # [L, dim] patches = patches.astype(_linear_dtype(self.patch_embedding_proj)) patches = patches[None, :, :] # [1, L, dim] return patches, (f_out, h_out, w_out) def unpatchify(self, x: mx.array, grid_sizes: list) -> list: """Reconstruct video from patch embeddings. Args: x: [B, L, out_dim * prod(patch_size)] grid_sizes: List of (F', H', W') per batch element Returns: List of tensors [C, F, H, W] """ c = self.out_dim pt, ph, pw = self.patch_size out = [] for i, (f, h, w) in enumerate(grid_sizes): seq_len = f * h * w u = x[i, :seq_len] # [L, out_dim * pt * ph * pw] u = u.reshape(f, h, w, pt, ph, pw, c) # Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw] u = u.transpose(6, 0, 3, 1, 4, 2, 5) # [C, F', pt, H', ph, W', pw] u = u.reshape(c, f * pt, h * ph, w * pw) out.append(u) return out def embed_text(self, context: list) -> mx.array: """Precompute text embeddings (call once, reuse across steps). Args: context: List of text embeddings [L_text, text_dim] Returns: Embedded context [B, text_len, dim] in model dtype """ model_dtype = _linear_dtype(self.patch_embedding_proj) context_padded = [] for ctx in context: pad_len = self.text_len - ctx.shape[0] if pad_len > 0: ctx = mx.concatenate( [ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)], axis=0, ) context_padded.append(ctx) context_batch = mx.stack(context_padded) # [B, text_len, text_dim] context_batch = self.text_embedding_1( self.text_embedding_act(self.text_embedding_0(context_batch)) ) return context_batch.astype(model_dtype) def prepare_cross_kv(self, context: mx.array) -> list: """Pre-compute cross-attention K/V for all blocks. Call once before the diffusion loop to cache K/V projections, eliminating redundant computation at each denoising step. Args: context: Pre-embedded text [B, text_len, dim] Returns: List of (k, v) tuples, one per block """ kv_caches = [] for block in self.blocks: kv_caches.append(block.cross_attn.prepare_kv(context)) return kv_caches def prepare_rope(self, grid_sizes: list) -> tuple: """Pre-compute RoPE cos/sin for constant grid sizes. Call once before the diffusion loop when grid sizes don't change across steps. Eliminates per-step broadcast/concat overhead. Args: grid_sizes: List of (F, H, W) tuples per batch element Returns: (cos_f, sin_f) precomputed frequency tensors """ w_dtype = _linear_dtype(self.patch_embedding_proj) return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype) def __call__( self, x_list: list, t: mx.array, context: list | mx.array, seq_len: int, cross_kv_caches: list | None = None, y: list | None = None, rope_cos_sin: tuple | None = None, ) -> list: """Forward pass. Args: x_list: List of video latent tensors [C, F, H, W] t: Timestep tensor [B] context: List of raw text embeddings, OR pre-embedded tensor from embed_text() [B, text_len, dim] seq_len: Maximum sequence length for padding cross_kv_caches: Optional list of (k, v) tuples from prepare_cross_kv(), one per block. y: Optional list of conditioning tensors for I2V [C_y, F, H, W]. Channel-concatenated with x before patchify. rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope(). Returns: List of denoised tensors [C, F, H, W] """ # Detect identical inputs (CFG B=2) to avoid duplicate patchify work. # Check BEFORE I2V concat since concat creates new array objects. batch_size = len(x_list) all_same = batch_size > 1 and all( x_list[i] is x_list[0] for i in range(1, batch_size) ) if all_same and y is not None: all_same = all(y[i] is y[0] for i in range(1, len(y))) # I2V: channel-concatenate conditioning y with noise x if y is not None: x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)] if all_same: # Patchify once and broadcast — saves a Linear projection per step p, gs = self._patchify(x_list[0]) # [1, L, dim] grid_sizes = [gs] * batch_size seq_lens_list = [p.shape[1]] * batch_size # Pad and broadcast if p.shape[1] < seq_len: p = mx.concatenate( [p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)], axis=1, ) x = mx.broadcast_to(p, (batch_size,) + p.shape[1:]) else: patches = [] grid_sizes = [] seq_lens_list = [] for vid in x_list: p, gs = self._patchify(vid) # [1, L, dim] patches.append(p) grid_sizes.append(gs) seq_lens_list.append(p.shape[1]) x = mx.concatenate( [ mx.concatenate( [p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)], axis=1, ) if p.shape[1] < seq_len else p for p in patches ], axis=0, ) # [B, seq_len, dim] # Time embedding: sinusoidal from precomputed inv_freq. # inv_freq was computed in float64 for precision, stored as float32. # With integer timesteps (matching reference), float32 sin/cos is fine. if t.ndim == 0: t = t[None] sinusoid = t[..., None].astype(mx.float32) * self._inv_freq sin_emb = mx.concatenate( [mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1 ) if t.ndim == 1: # Standard T2V: scalar timestep per batch element [B] e = self.time_embedding_1( self.time_embedding_act(self.time_embedding_0(sin_emb)) ) # [B, dim] e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6] e0 = e0.reshape(batch_size, 1, 6, self.dim) else: # I2V: per-token timesteps [B, L] e = self.time_embedding_1( self.time_embedding_act(self.time_embedding_0(sin_emb)) ) # [B, L, dim] e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6] e0 = e0.reshape(batch_size, -1, 6, self.dim) # Text embedding: skip MLP if context is already embedded (mx.array) if isinstance(context, mx.array): # Pre-embedded: expand to batch size if needed context_batch = context if context_batch.shape[0] == 1 and batch_size > 1: context_batch = mx.broadcast_to( context_batch, (batch_size,) + context_batch.shape[1:] ) else: context_batch = self.embed_text(context) # Pre-compute attention mask from seq_lens (constant across all blocks) attn_mask = None w_dtype = _linear_dtype(self.patch_embedding_proj) if any(sl < seq_len for sl in seq_lens_list): attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype) for i, sl in enumerate(seq_lens_list): attn_mask[i, :, :, sl:] = -1e9 kwargs = dict( e=e0, seq_lens=seq_lens_list, grid_sizes=grid_sizes, freqs=self.freqs, context=context_batch, context_lens=None, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask, ) # Run transformer blocks for i, block in enumerate(self.blocks): kv = cross_kv_caches[i] if cross_kv_caches is not None else None x = block(x, cross_kv_cache=kv, **kwargs) # Output head x = self.head(x, e) # Unpatchify outputs = self.unpatchify(x, grid_sizes) return [u.astype(mx.float32) for u in outputs]