Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.

This commit is contained in:
Prince Canuma
2026-03-18 17:59:43 +01:00
parent b029668cd2
commit 996a542011
37 changed files with 354 additions and 354 deletions

View File

@@ -0,0 +1,388 @@
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .attention import WanLayerNorm, _linear_dtype
from .config import WanModelConfig
from .rope import rope_params, rope_precompute_cos_sin
from .transformer import WanAttentionBlock
def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
"""Compute sinusoidal positional embeddings.
Args:
dim: Embedding dimension (must be even).
position: Tensor of positions — 1D [L] or 2D [B, L].
Returns:
Embeddings of shape [L, dim] or [B, L, dim].
"""
assert dim % 2 == 0
half = dim // 2
pos = position.astype(mx.float32)
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
sinusoid = pos[..., None] * inv_freq # [..., half]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
class Head(nn.Module):
"""Output projection head with learned modulation."""
def __init__(self, dim: int, out_dim: int, patch_size: tuple, eps: float = 1e-6):
super().__init__()
self.out_dim = out_dim
self.patch_size = patch_size
proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
"""
Args:
x: [B, L, dim]
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
# Compute modulation in float32 (matching reference's autocast(float32))
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x)
x_mod = x_norm * (1 + e1) + e0
return self.head(x_mod)
class WanModel(nn.Module):
"""Wan2.2 diffusion backbone for text-to-video generation."""
def __init__(self, config: WanModelConfig):
super().__init__()
self.config = config
dim = config.dim
self.dim = dim
self.num_heads = config.num_heads
self.out_dim = config.out_dim
self.patch_size = config.patch_size
self.text_len = config.text_len
self.freq_dim = config.freq_dim
# Patch embedding: Conv3d implemented as a reshaped linear
# For kernel (1,2,2) and stride (1,2,2): reshape input then linear
patch_dim = config.in_dim * math.prod(config.patch_size)
self.patch_embedding_proj = nn.Linear(patch_dim, dim)
self._patch_size = config.patch_size
# Text embedding MLP
self.text_embedding_0 = nn.Linear(config.text_dim, dim)
self.text_embedding_act = nn.GELU(approx="tanh")
self.text_embedding_1 = nn.Linear(dim, dim)
# Time embedding MLP
self.time_embedding_0 = nn.Linear(config.freq_dim, dim)
self.time_embedding_act = nn.SiLU()
self.time_embedding_1 = nn.Linear(dim, dim)
# Time projection for modulation (6x dim)
self.time_projection_act = nn.SiLU()
self.time_projection = nn.Linear(dim, dim * 6)
# Transformer blocks
self.blocks = [
WanAttentionBlock(
dim=dim,
ffn_dim=config.ffn_dim,
num_heads=config.num_heads,
window_size=config.window_size,
qk_norm=config.qk_norm,
cross_attn_norm=config.cross_attn_norm,
eps=config.eps,
)
for _ in range(config.num_layers)
]
# Output head
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
# Precompute RoPE frequencies — three separate tables concatenated.
# Reference computes three rope_params with different dim normalizations
# so each axis (temporal/height/width) gets its own full frequency range.
d = dim // config.num_heads
self.freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Precompute sinusoidal inv_freq for time embedding.
half = config.freq_dim // 2
self._inv_freq = mx.array(
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
np.float32
)
)
def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings.
Args:
x: Video latent [C, F, H, W]
Returns:
(patches, grid_size): patches [1, L, dim], grid_size (F', H', W')
"""
c, f, h, w = x.shape
pt, ph, pw = self._patch_size
f_out = f // pt
h_out = h // ph
w_out = w // pw
# Reshape: [C, F, H, W] -> [F', H', W', C, pt, ph, pw] -> [F'*H'*W', C*pt*ph*pw]
# Order must be [C, pt, ph, pw] (C slowest) to match Conv3d weight layout
x = x.reshape(c, f_out, pt, h_out, ph, w_out, pw)
x = x.transpose(1, 3, 5, 0, 2, 4, 6) # [F', H', W', C, pt, ph, pw]
x = x.reshape(f_out * h_out * w_out, -1) # [L, C*pt*ph*pw]
# Project and cast to model dtype to prevent float32 cascade from input latents
patches = self.patch_embedding_proj(x) # [L, dim]
patches = patches.astype(_linear_dtype(self.patch_embedding_proj))
patches = patches[None, :, :] # [1, L, dim]
return patches, (f_out, h_out, w_out)
def unpatchify(self, x: mx.array, grid_sizes: list) -> list:
"""Reconstruct video from patch embeddings.
Args:
x: [B, L, out_dim * prod(patch_size)]
grid_sizes: List of (F', H', W') per batch element
Returns:
List of tensors [C, F, H, W]
"""
c = self.out_dim
pt, ph, pw = self.patch_size
out = []
for i, (f, h, w) in enumerate(grid_sizes):
seq_len = f * h * w
u = x[i, :seq_len] # [L, out_dim * pt * ph * pw]
u = u.reshape(f, h, w, pt, ph, pw, c)
# Rearrange: [F', H', W', pt, ph, pw, C] -> [C, F'*pt, H'*ph, W'*pw]
u = u.transpose(6, 0, 3, 1, 4, 2, 5) # [C, F', pt, H', ph, W', pw]
u = u.reshape(c, f * pt, h * ph, w * pw)
out.append(u)
return out
def embed_text(self, context: list) -> mx.array:
"""Precompute text embeddings (call once, reuse across steps).
Args:
context: List of text embeddings [L_text, text_dim]
Returns:
Embedded context [B, text_len, dim] in model dtype
"""
model_dtype = _linear_dtype(self.patch_embedding_proj)
context_padded = []
for ctx in context:
pad_len = self.text_len - ctx.shape[0]
if pad_len > 0:
ctx = mx.concatenate(
[ctx, mx.zeros((pad_len, ctx.shape[1]), dtype=ctx.dtype)],
axis=0,
)
context_padded.append(ctx)
context_batch = mx.stack(context_padded) # [B, text_len, text_dim]
context_batch = self.text_embedding_1(
self.text_embedding_act(self.text_embedding_0(context_batch))
)
return context_batch.astype(model_dtype)
def prepare_cross_kv(self, context: mx.array) -> list:
"""Pre-compute cross-attention K/V for all blocks.
Call once before the diffusion loop to cache K/V projections,
eliminating redundant computation at each denoising step.
Args:
context: Pre-embedded text [B, text_len, dim]
Returns:
List of (k, v) tuples, one per block
"""
kv_caches = []
for block in self.blocks:
kv_caches.append(block.cross_attn.prepare_kv(context))
return kv_caches
def prepare_rope(self, grid_sizes: list) -> tuple:
"""Pre-compute RoPE cos/sin for constant grid sizes.
Call once before the diffusion loop when grid sizes don't change
across steps. Eliminates per-step broadcast/concat overhead.
Args:
grid_sizes: List of (F, H, W) tuples per batch element
Returns:
(cos_f, sin_f) precomputed frequency tensors
"""
w_dtype = _linear_dtype(self.patch_embedding_proj)
return rope_precompute_cos_sin(grid_sizes, self.freqs, dtype=w_dtype)
def __call__(
self,
x_list: list,
t: mx.array,
context: list | mx.array,
seq_len: int,
cross_kv_caches: list | None = None,
y: list | None = None,
rope_cos_sin: tuple | None = None,
) -> list:
"""Forward pass.
Args:
x_list: List of video latent tensors [C, F, H, W]
t: Timestep tensor [B]
context: List of raw text embeddings, OR pre-embedded tensor
from embed_text() [B, text_len, dim]
seq_len: Maximum sequence length for padding
cross_kv_caches: Optional list of (k, v) tuples from
prepare_cross_kv(), one per block.
y: Optional list of conditioning tensors for I2V [C_y, F, H, W].
Channel-concatenated with x before patchify.
rope_cos_sin: Optional precomputed (cos, sin) from prepare_rope().
Returns:
List of denoised tensors [C, F, H, W]
"""
# Detect identical inputs (CFG B=2) to avoid duplicate patchify work.
# Check BEFORE I2V concat since concat creates new array objects.
batch_size = len(x_list)
all_same = batch_size > 1 and all(
x_list[i] is x_list[0] for i in range(1, batch_size)
)
if all_same and y is not None:
all_same = all(y[i] is y[0] for i in range(1, len(y)))
# I2V: channel-concatenate conditioning y with noise x
if y is not None:
x_list = [mx.concatenate([u, v], axis=0) for u, v in zip(x_list, y)]
if all_same:
# Patchify once and broadcast — saves a Linear projection per step
p, gs = self._patchify(x_list[0]) # [1, L, dim]
grid_sizes = [gs] * batch_size
seq_lens_list = [p.shape[1]] * batch_size
# Pad and broadcast
if p.shape[1] < seq_len:
p = mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
)
x = mx.broadcast_to(p, (batch_size,) + p.shape[1:])
else:
patches = []
grid_sizes = []
seq_lens_list = []
for vid in x_list:
p, gs = self._patchify(vid) # [1, L, dim]
patches.append(p)
grid_sizes.append(gs)
seq_lens_list.append(p.shape[1])
x = mx.concatenate(
[
(
mx.concatenate(
[
p,
mx.zeros(
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
),
],
axis=1,
)
if p.shape[1] < seq_len
else p
)
for p in patches
],
axis=0,
) # [B, seq_len, dim]
# Time embedding: sinusoidal from precomputed inv_freq.
# inv_freq was computed in float64 for precision, stored as float32.
# With integer timesteps (matching reference), float32 sin/cos is fine.
if t.ndim == 0:
t = t[None]
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
e0 = e0.reshape(batch_size, 1, 6, self.dim)
else:
# I2V: per-token timesteps [B, L]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, L, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
e0 = e0.reshape(batch_size, -1, 6, self.dim)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):
# Pre-embedded: expand to batch size if needed
context_batch = context
if context_batch.shape[0] == 1 and batch_size > 1:
context_batch = mx.broadcast_to(
context_batch, (batch_size,) + context_batch.shape[1:]
)
else:
context_batch = self.embed_text(context)
# Pre-compute attention mask from seq_lens (constant across all blocks)
attn_mask = None
w_dtype = _linear_dtype(self.patch_embedding_proj)
if any(sl < seq_len for sl in seq_lens_list):
attn_mask = mx.zeros((batch_size, 1, 1, seq_len), dtype=w_dtype)
for i, sl in enumerate(seq_lens_list):
attn_mask[i, :, :, sl:] = -1e9
kwargs = dict(
e=e0,
seq_lens=seq_lens_list,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context_batch,
context_lens=None,
rope_cos_sin=rope_cos_sin,
attn_mask=attn_mask,
)
# Run transformer blocks
for i, block in enumerate(self.blocks):
kv = cross_kv_caches[i] if cross_kv_caches is not None else None
x = block(x, cross_kv_cache=kv, **kwargs)
# Output head
x = self.head(x, e)
# Unpatchify
outputs = self.unpatchify(x, grid_sizes)
return [u.astype(mx.float32) for u in outputs]