feat(wan): Add Wan2.2 I2V support

This commit is contained in:
Daniel
2026-02-27 13:46:23 +01:00
parent 93da550f65
commit 2bb95c61ed
26 changed files with 4401 additions and 2968 deletions

View File

@@ -338,6 +338,10 @@ def convert_wan_checkpoint(
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}")
# Use preset for known TI2V 5B configuration
if src_model_type == "ti2v" and src_dim == 3072:
return WanModelConfig.wan22_ti2v_5b()
is_22 = model_version == "2.2"
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
@@ -409,7 +413,8 @@ def convert_wan_checkpoint(
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
weights = sanitize_wan22_vae_weights(weights)
include_encoder = config.model_type == "ti2v"
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
else:
weights = sanitize_wan_vae_weights(weights)
# Always save VAE in float32 — official Wan2.2 runs VAE decode in

View File

@@ -9,17 +9,7 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
# ANSI color codes
class Colors:
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
from mlx_video.utils import Colors
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel

View File

@@ -13,156 +13,27 @@ import mlx.nn as nn
import numpy as np
from tqdm import tqdm
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan.loading import (
_clean_text,
encode_text,
load_t5_encoder,
load_vae_decoder,
load_vae_encoder,
load_wan_model,
)
from mlx_video.postprocess import save_video
from mlx_video.utils import Colors
class Colors:
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
"""Load and initialize WanModel, with optional quantization support.
Args:
model_path: Path to model safetensors file
config: WanModelConfig
quantization: Optional dict with 'bits' and 'group_size' keys.
If provided, creates QuantizedLinear stubs before loading.
"""
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
if quantization:
from mlx_video.convert_wan import _quantize_predicate
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
weights = mx.load(str(model_path))
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
return model
def load_t5_encoder(model_path: Path, config):
"""Load T5 text encoder.
Weights are upcast to float32 for maximum precision — the T5 encoder
only runs once per generation, so performance impact is negligible.
This matches the official which computes softmax in float32 explicitly.
"""
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=config.t5_vocab_size,
dim=config.t5_dim,
dim_attn=config.t5_dim_attn,
dim_ffn=config.t5_dim_ffn,
num_heads=config.t5_num_heads,
num_layers=config.t5_num_layers,
num_buckets=config.t5_num_buckets,
shared_pos=False,
)
weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items()))
mx.eval(encoder.parameters())
return encoder
def load_vae_decoder(model_path: Path, config=None):
"""Load VAE decoder (skips encoder weights with strict=False).
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
For Wan2.1 (vae_z_dim=16), uses WanVAE.
"""
is_wan22 = config is not None and config.vae_z_dim == 48
if is_wan22:
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48)
else:
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path))
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
vae.load_weights(list(weights.items()), strict=False)
mx.eval(vae.parameters())
return vae
def _clean_text(text: str) -> str:
"""Clean text matching official Wan2.2 tokenizer preprocessing.
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
double HTML unescape, and whitespace normalization. Critical for
correct tokenization of the Chinese negative prompt.
"""
import html
import re
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text).strip()
return text
def encode_text(
encoder,
tokenizer,
prompt: str,
text_len: int = 512,
) -> mx.array:
"""Encode text prompt using T5 encoder.
Args:
encoder: T5Encoder model
tokenizer: HuggingFace tokenizer
prompt: Text prompt
text_len: Maximum text length
Returns:
Text embeddings [L, dim]
"""
prompt = _clean_text(prompt)
tokens = tokenizer(
prompt,
max_length=text_len,
padding="max_length",
truncation=True,
return_tensors="np",
)
ids = mx.array(tokens["input_ids"])
mask = mx.array(tokens["attention_mask"])
embeddings = encoder(ids, mask=mask)
# Return only non-padding tokens
seq_len = int(mask.sum().item())
return embeddings[0, :seq_len]
# Backward-compat alias (tests and external code may use the old name)
_build_i2v_mask = build_i2v_mask
def generate_video(
model_dir: str,
prompt: str,
negative_prompt: str | None = None,
image: str | None = None,
width: int = 1280,
height: int = 720,
num_frames: int = 81,
@@ -173,12 +44,13 @@ def generate_video(
output_path: str = "output.mp4",
scheduler: str = "unipc",
):
"""Generate video using Wan T2V pipeline (supports 2.1 and 2.2).
"""Generate video using Wan pipeline (supports T2V and I2V).
Args:
model_dir: Path to converted MLX model directory
prompt: Text prompt
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
image: Path to input image for I2V (None = T2V mode)
width: Video width
height: Video height
num_frames: Number of frames (must be 4n+1)
@@ -240,6 +112,7 @@ def generate_video(
config = WanModelConfig.wan21_t2v_14b()
is_dual = config.dual_model
is_i2v = image is not None
# Validate config against actual weights (handles mismatched config.json)
if not is_dual:
@@ -288,6 +161,7 @@ def generate_video(
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
# Resolve negative prompt: explicit user value > config default
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
# that prevents oversaturation, artifacts, and comic look. We use it by default.
@@ -297,9 +171,11 @@ def generate_video(
else:
neg_prompt_resolved = negative_prompt
print(f"{Colors.CYAN}{'='*60}")
print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
print(f"{'='*60}{Colors.RESET}")
print(f"{Colors.DIM} Prompt: {prompt}")
if is_i2v:
print(f" Image: {image}")
if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
print(f" Neg prompt: {neg_display}")
@@ -314,8 +190,22 @@ def generate_video(
np.random.seed(seed)
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
# Compute target latent shape
# Align dimensions to patch_size * vae_stride (required for patchify)
vae_stride = config.vae_stride
patch_size = config.patch_size
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
align_w = patch_size[2] * vae_stride[2]
if height % align_h != 0 or width % align_w != 0:
old_h, old_w = height, width
height = (height // align_h) * align_h
width = (width // align_w) * align_w
if height == 0:
height = align_h
if width == 0:
width = align_w
print(f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
# Compute target latent shape
z_dim = config.vae_z_dim
t_latent = (num_frames - 1) // vae_stride[0] + 1
h_latent = height // vae_stride[1]
@@ -323,7 +213,6 @@ def generate_video(
target_shape = (z_dim, t_latent, h_latent, w_latent)
# Sequence length for transformer
patch_size = config.patch_size
seq_len = math.ceil(
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
)
@@ -352,6 +241,31 @@ def generate_video(
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
# I2V: encode image to latent space
z_img = None
i2v_mask = None
i2v_mask_tokens = None
if is_i2v:
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
t_img = time.time()
img_tensor = preprocess_image(image, width, height)
mx.eval(img_tensor)
vae_path = model_dir / "vae.safetensors"
vae_enc = load_vae_encoder(vae_path, config)
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
mx.eval(z_img)
# Convert to channels-first: [z_dim, 1, H_lat, W_lat]
z_img = z_img[0].transpose(3, 0, 1, 2)
# Build I2V mask
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
del vae_enc, img_tensor
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
# Load transformer models
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
if quantization:
@@ -398,12 +312,18 @@ def generate_video(
# Generate initial noise
noise = mx.random.normal(target_shape)
# I2V: blend first-frame latent into noise
if is_i2v:
# Broadcast z_img [z_dim, 1, H, W] across T for first-frame conditioning
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
else:
latents = noise
# Boundary for model switching (dual model only)
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
# Diffusion loop
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
latents = noise
t3 = time.time()
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
@@ -424,10 +344,24 @@ def generate_video(
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
kv = cross_kv
# Build per-token timesteps for I2V (first-frame patches get t=0)
if is_i2v:
t_tokens = i2v_mask_tokens * timestep_val # [1, L]
# Pad to seq_len if needed
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
# Batch for CFG: both cond and uncond get same timesteps
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) # [2, L]
else:
t_batch = mx.array([timestep_val, timestep_val])
# CFG: batch cond + uncond into single B=2 forward pass
preds = model(
[latents, latents],
t=mx.array([timestep_val, timestep_val]),
t=t_batch,
context=context_cfg,
seq_len=seq_len,
cross_kv_caches=kv,
@@ -438,6 +372,10 @@ def generate_video(
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# I2V: re-apply mask to keep first frame frozen
if is_i2v:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
mx.eval(latents)
@@ -488,43 +426,12 @@ def generate_video(
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""Save video frames to MP4.
Args:
frames: Video frames [T, H, W, 3] uint8
output_path: Output file path
fps: Frames per second
"""
try:
import imageio
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
for frame in frames:
writer.append_data(frame)
writer.close()
except ImportError:
try:
import cv2
h, w = frames.shape[1], frames.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"avc1")
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
for frame in frames:
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
writer.release()
except (ImportError, Exception):
# Last resort: save as individual PNGs
from PIL import Image
out_dir = Path(output_path).parent / Path(output_path).stem
out_dir.mkdir(parents=True, exist_ok=True)
for i, frame in enumerate(frames):
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--image", type=str, default=None,
help="Path to input image for I2V (omit for T2V mode)")
parser.add_argument("--negative-prompt", type=str, default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true",
@@ -559,6 +466,7 @@ def main():
model_dir=args.model_dir,
prompt=args.prompt,
negative_prompt=neg_prompt,
image=args.image,
width=args.width,
height=args.height,
num_frames=args.num_frames,

View File

@@ -71,8 +71,12 @@ class WanSelfAttention(nn.Module):
b, s, _ = x.shape
n, d = self.num_heads, self.head_dim
q = self.q(x)
k = self.k(x)
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
x_w = x.astype(w_dtype)
q = self.q(x_w)
k = self.k(x_w)
if self.norm_q is not None:
q = self.norm_q(q)
if self.norm_k is not None:
@@ -80,15 +84,15 @@ class WanSelfAttention(nn.Module):
q = q.reshape(b, s, n, d)
k = k.reshape(b, s, n, d)
v = self.v(x).reshape(b, s, n, d)
v = self.v(x_w).reshape(b, s, n, d)
# Apply RoPE
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs)
# Scaled dot-product attention: [B, L, N, D] -> [B, N, L, D]
q = q.transpose(0, 2, 1, 3)
k = k.transpose(0, 2, 1, 3)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
k = k.astype(w_dtype).transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Build attention mask from seq_lens
@@ -149,11 +153,14 @@ class WanCrossAttention(nn.Module):
"""
b = context.shape[0]
n, d = self.num_heads, self.head_dim
k = self.k(context)
# Cast to weight dtype for efficient matmul
w_dtype = self.k.weight.dtype
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
return k, v
def __call__(
@@ -166,7 +173,9 @@ class WanCrossAttention(nn.Module):
b = x.shape[0]
n, d = self.num_heads, self.head_dim
q = self.q(x)
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
w_dtype = self.q.weight.dtype
q = self.q(x.astype(w_dtype))
if self.norm_q is not None:
q = self.norm_q(q)
q = q.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
@@ -174,11 +183,12 @@ class WanCrossAttention(nn.Module):
if kv_cache is not None:
k, v = kv_cache
else:
k = self.k(context)
ctx = context.astype(w_dtype)
k = self.k(ctx)
if self.norm_k is not None:
k = self.norm_k(k)
k = k.reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(context).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
v = self.v(ctx).reshape(b, -1, n, d).transpose(0, 2, 1, 3)
# Optional context masking
mask = None

View File

@@ -90,3 +90,24 @@ class WanModelConfig(BaseModelConfig):
def wan22_t2v_14b(cls) -> "WanModelConfig":
"""Wan2.2 T2V 14B: dual model, 40 layers, dim=5120 (default)."""
return cls()
@classmethod
def wan22_ti2v_5b(cls) -> "WanModelConfig":
"""Wan2.2 TI2V 5B: text+image to video, 30 layers, dim=3072."""
return cls(
model_type="ti2v",
dim=3072,
ffn_dim=14336,
in_dim=48,
out_dim=48,
num_heads=24,
num_layers=30,
vae_z_dim=48,
vae_stride=(4, 16, 16),
dual_model=False,
boundary=0.0,
sample_shift=5.0,
sample_steps=50,
sample_guide_scale=5.0,
sample_fps=24,
)

View File

@@ -0,0 +1,58 @@
"""Image-to-Video utility functions for Wan2.2."""
import mlx.core as mx
import numpy as np
def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
"""Load, resize, center-crop, and normalize an image for I2V.
Args:
image_path: Path to input image
width: Target width
height: Target height
Returns:
Image tensor [1, 1, H, W, 3] in [-1, 1] (channels-last, batch + temporal dims)
"""
from PIL import Image
img = Image.open(image_path).convert("RGB")
# Resize so that the image covers the target size (LANCZOS)
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
# Center crop
x1 = (img.width - width) // 2
y1 = (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
# To tensor: [H, W, 3] float32 in [-1, 1]
arr = np.array(img, dtype=np.float32) / 255.0
arr = arr * 2.0 - 1.0 # [0,1] → [-1,1]
return mx.array(arr[None, None]) # [1, 1, H, W, 3]
def build_i2v_mask(z_shape, patch_size):
"""Build temporal mask for I2V: first frame = 0, rest = 1.
Args:
z_shape: Latent shape (C, T, H, W) in channels-first
patch_size: (pt, ph, pw) patch size
Returns:
mask: (C, T, H, W) float32 — 0 for first frame, 1 for rest
mask_tokens: (1, L) float32 — 0 for first-frame tokens, 1 for rest
"""
C, T, H, W = z_shape
mask = mx.ones(z_shape)
# Zero out the first temporal position
mask = mx.concatenate([mx.zeros((C, 1, H, W)), mask[:, 1:]], axis=1)
# Token-level mask for per-token timesteps: subsample to patch grid
# mask shape [C, T, H, W] → take first channel, subsample by patch_size
pt, ph, pw = patch_size
mask_tokens = mask[0, ::pt, ::ph, ::pw] # [T', H', W']
mask_tokens = mask_tokens.reshape(1, -1) # [1, L]
return mask, mask_tokens

View File

@@ -0,0 +1,154 @@
"""Wan model loading utilities."""
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
"""Load and initialize WanModel, with optional quantization support.
Args:
model_path: Path to model safetensors file
config: WanModelConfig
quantization: Optional dict with 'bits' and 'group_size' keys.
If provided, creates QuantizedLinear stubs before loading.
"""
from mlx_video.models.wan.model import WanModel
model = WanModel(config)
if quantization:
from mlx_video.convert_wan import _quantize_predicate
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
class_predicate=lambda path, m: _quantize_predicate(path, m),
)
weights = mx.load(str(model_path))
model.load_weights(list(weights.items()), strict=False)
mx.eval(model.parameters())
return model
def load_t5_encoder(model_path: Path, config):
"""Load T5 text encoder.
Weights are upcast to float32 for maximum precision — the T5 encoder
only runs once per generation, so performance impact is negligible.
This matches the official which computes softmax in float32 explicitly.
"""
from mlx_video.models.wan.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=config.t5_vocab_size,
dim=config.t5_dim,
dim_attn=config.t5_dim_attn,
dim_ffn=config.t5_dim_ffn,
num_heads=config.t5_num_heads,
num_layers=config.t5_num_layers,
num_buckets=config.t5_num_buckets,
shared_pos=False,
)
weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items()))
mx.eval(encoder.parameters())
return encoder
def load_vae_decoder(model_path: Path, config=None):
"""Load VAE decoder (skips encoder weights with strict=False).
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
For Wan2.1 (vae_z_dim=16), uses WanVAE.
"""
is_wan22 = config is not None and config.vae_z_dim == 48
if is_wan22:
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48)
else:
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path))
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
vae.load_weights(list(weights.items()), strict=False)
mx.eval(vae.parameters())
return vae
def load_vae_encoder(model_path: Path, config=None):
"""Load VAE encoder for I2V image encoding.
Only supports Wan2.2 (vae_z_dim=48).
"""
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
encoder = Wan22VAEEncoder(z_dim=config.vae_z_dim)
weights = mx.load(str(model_path))
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
encoder.load_weights(list(weights.items()), strict=False)
mx.eval(encoder.parameters())
return encoder
def _clean_text(text: str) -> str:
"""Clean text matching official Wan2.2 tokenizer preprocessing.
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
double HTML unescape, and whitespace normalization. Critical for
correct tokenization of the Chinese negative prompt.
"""
import html
import re
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass
text = html.unescape(html.unescape(text))
text = re.sub(r"\s+", " ", text).strip()
return text
def encode_text(
encoder,
tokenizer,
prompt: str,
text_len: int = 512,
) -> mx.array:
"""Encode text prompt using T5 encoder.
Args:
encoder: T5Encoder model
tokenizer: HuggingFace tokenizer
prompt: Text prompt
text_len: Maximum text length
Returns:
Text embeddings [L, dim]
"""
prompt = _clean_text(prompt)
tokens = tokenizer(
prompt,
max_length=text_len,
padding="max_length",
truncation=True,
return_tensors="np",
)
ids = mx.array(tokens["input_ids"])
mask = mx.array(tokens["attention_mask"])
embeddings = encoder(ids, mask=mask)
# Return only non-padding tokens
seq_len = int(mask.sum().item())
return embeddings[0, :seq_len]

View File

@@ -15,17 +15,17 @@ def sinusoidal_embedding_1d(dim: int, position: mx.array) -> mx.array:
Args:
dim: Embedding dimension (must be even).
position: 1D tensor of positions.
position: Tensor of positions — 1D [L] or 2D [B, L].
Returns:
Embeddings of shape [len(position), dim].
Embeddings of shape [L, dim] or [B, L, dim].
"""
assert dim % 2 == 0
half = dim // 2
pos = position.astype(mx.float32)
inv_freq = mx.power(10000.0, -mx.arange(half).astype(mx.float32) / half)
sinusoid = pos[:, None] * inv_freq[None, :]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=1)
sinusoid = pos[..., None] * inv_freq # [..., half]
return mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
class Head(nn.Module):
@@ -44,16 +44,17 @@ class Head(nn.Module):
"""
Args:
x: [B, L, dim]
e: [B, dim] or [B, 1, dim] (time embedding, broadcast to all tokens)
e: [B, dim] or [B, 1, dim] (broadcast) or [B, L, dim] (per-token)
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
e_f32 = e.astype(mx.float32)
mod = (self.modulation + e_f32) # broadcasts [1, 2, dim] + [B, 1, dim] -> [B, 2, dim]
e0 = mod[:, 0:1, :] # [B, 1, dim] shift
e1 = mod[:, 1:2, :] # [B, 1, dim] scale
# modulation [1, 2, dim] broadcasts with e [B, 1/L, dim] via unsqueeze
mod = self.modulation.astype(mx.float32)[:, None, :, :] + e_f32[:, :, None, :] # [B, L_e, 2, dim]
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x).astype(mx.float32)
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L
x_mod = x_norm * (1 + e1) + e0 # broadcasts over L if L_e==1
return self.head(x_mod.astype(x.dtype))
@@ -261,18 +262,30 @@ class WanModel(nn.Module):
axis=0,
) # [B, seq_len, dim]
# Time embedding: compute once per sample, then broadcast to all tokens
# Time embedding
if t.ndim == 0:
t = t[None]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
model_dtype = self.patch_embedding_proj.weight.dtype
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(model_dtype)
e = e.astype(model_dtype)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, freq_dim]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, dim*6]
# Keep e and e0 in float32 — official asserts float32 for modulation
e0 = e0.reshape(batch_size, 1, 6, self.dim).astype(mx.float32)
e = e.astype(mx.float32)
else:
# I2V: per-token timesteps [B, L]
sin_emb = sinusoidal_embedding_1d(self.freq_dim, t) # [B, L, freq_dim]
e = self.time_embedding_1(
self.time_embedding_act(self.time_embedding_0(sin_emb))
) # [B, L, dim]
e0 = self.time_projection(self.time_projection_act(e)) # [B, L, dim*6]
# Keep e and e0 in float32 — official asserts float32 for modulation
e0 = e0.reshape(batch_size, -1, 6, self.dim).astype(mx.float32)
e = e.astype(mx.float32)
# Text embedding: skip MLP if context is already embedded (mx.array)
if isinstance(context, mx.array):

View File

@@ -187,7 +187,7 @@ class FlowUniPCScheduler:
solver_order: int = 2,
lower_order_final: bool = True,
disable_corrector: list | None = None,
use_corrector: bool = False,
use_corrector: bool = True,
):
self.num_train_timesteps = num_train_timesteps
self.solver_order = solver_order

View File

@@ -49,9 +49,9 @@ class WanAttentionBlock(nn.Module):
context_lens: list | None = None,
cross_kv_cache: tuple | None = None,
) -> mx.array:
# Compute modulation: e is [B, 1, 6, dim] (broadcasts over tokens)
mod = (self.modulation + e) # [1, 6, dim] + [B, 1, 6, dim] -> [B, 1, 6, dim]
# Split into 6 modulation vectors (each [B, 1, dim], broadcast over L)
# Modulation in float32 (matching official torch.amp.autocast float32)
e_f32 = e.astype(mx.float32)
mod = self.modulation.astype(mx.float32) + e_f32
e0 = mod[:, :, 0, :] # shift for self-attn
e1 = mod[:, :, 1, :] # scale for self-attn
e2 = mod[:, :, 2, :] # gate for self-attn
@@ -59,19 +59,19 @@ class WanAttentionBlock(nn.Module):
e4 = mod[:, :, 4, :] # scale for ffn
e5 = mod[:, :, 5, :] # gate for ffn
# Self-attention with modulation
x_mod = self.norm1(x) * (1 + e1) + e0
# Self-attention with modulation (norm output in float32)
x_mod = self.norm1(x).astype(mx.float32) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs)
x = x + y * e2
x = x.astype(mx.float32) + y.astype(mx.float32) * e2
# Cross-attention (no modulation, just norm)
x_cross = self.norm3(x) if self.norm3 is not None else x
x = x + self.cross_attn(x_cross, context, context_lens, kv_cache=cross_kv_cache)
# FFN with modulation
x_mod = self.norm2(x) * (1 + e4) + e3
# FFN with modulation (norm output in float32)
x_mod = self.norm2(x).astype(mx.float32) * (1 + e4) + e3
y = self.ffn(x_mod)
x = x + y * e5
x = x + y.astype(mx.float32) * e5
return x
@@ -86,4 +86,6 @@ class WanFFN(nn.Module):
self.fc2 = nn.Linear(ffn_dim, dim)
def __call__(self, x: mx.array) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
# Cast to weight dtype for efficient matmul (bfloat16 matching official autocast)
x_w = x.astype(self.fc1.weight.dtype)
return self.fc2(self.act(self.fc1(x_w)))

View File

@@ -53,7 +53,9 @@ class CausalConv3d(nn.Module):
self.kernel_size = kernel_size
self.stride = stride
self._causal_pad_t = 2 * padding[0]
# Causal temporal padding: always kernel_size-1 on the left.
# This matches the official CausalConv3d which pads (kernel[0]-1, 0, ...).
self._causal_pad_t = kernel_size[0] - 1
self._pad_h = padding[1]
self._pad_w = padding[2]
@@ -250,6 +252,46 @@ class DupUp3D(nn.Module):
return x
class AvgDown3D(nn.Module):
"""Downsample by grouping channels across spatial/temporal factors and averaging.
Inverse of DupUp3D. No learnable parameters.
Input: [B, T, H, W, C_in] → Output: [B, T//ft, H//fs, W//fs, C_out]
"""
def __init__(self, in_channels, out_channels, factor_t, factor_s=1):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor_t = factor_t
self.factor_s = factor_s
self.factor = factor_t * factor_s * factor_s
assert in_channels * self.factor % out_channels == 0
self.group_size = in_channels * self.factor // out_channels
def __call__(self, x):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
# Pad temporal if not divisible by factor_t
pad_t = (self.factor_t - T % self.factor_t) % self.factor_t
if pad_t > 0:
x = mx.pad(x, [(0, 0), (pad_t, 0), (0, 0), (0, 0), (0, 0)])
T = T + pad_t
ft, fs = self.factor_t, self.factor_s
# Reshape to split spatial/temporal dims
x = x.reshape(B, T // ft, ft, H // fs, fs, W // fs, fs, C)
# Move factors next to channels
x = x.transpose(0, 1, 3, 5, 7, 2, 4, 6) # [B, T', H', W', C, ft, fs, fs]
# Expand channels
x = x.reshape(B, T // ft, H // fs, W // fs, C * self.factor)
# Group and average
x = x.reshape(B, T // ft, H // fs, W // fs, self.out_channels, self.group_size)
x = x.mean(axis=-1)
return x
class Resample(nn.Module):
"""Spatial up/downsampling with optional temporal up/downsampling."""
@@ -267,6 +309,15 @@ class Resample(nn.Module):
self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim*2, (3,1,1), padding=(1,0,0))
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,))
elif mode == "downsample3d":
self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
raise ValueError(f"Unsupported mode: {mode}")
@@ -283,6 +334,12 @@ class Resample(nn.Module):
x = mx.pad(x, [(0, 0), (1, 1), (1, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight) + self.resample_bias
def _downsample_conv2d(self, x):
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
def __call__(self, x, first_chunk=False):
# x: [B, T, H, W, C]
B, T, H, W, C = x.shape
@@ -320,20 +377,37 @@ class Resample(nn.Module):
mx.eval(x)
T = x.shape[1]
# Spatial upsample in temporal chunks to limit peak memory
chunk_size = 8
chunks = []
for t_start in range(0, T, chunk_size):
t_end = min(t_start + chunk_size, T)
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
x_chunk = self._upsample2x(x_chunk)
x_chunk = self._conv2d(x_chunk)
mx.eval(x_chunk)
chunks.append(x_chunk)
if self.mode == "downsample3d" and T > 1:
# Temporal downsample via strided CausalConv3d
# Skip for T=1 (single frame) — matches official chunked encoding
# where first chunk stores cache but doesn't apply time_conv
x = self.time_conv(x)
mx.eval(x)
T = x.shape[1]
if self.mode in ("upsample2d", "upsample3d"):
# Spatial upsample in temporal chunks to limit peak memory
chunk_size = 8
chunks = []
for t_start in range(0, T, chunk_size):
t_end = min(t_start + chunk_size, T)
x_chunk = x[:, t_start:t_end].reshape(-1, H, W, C)
x_chunk = self._upsample2x(x_chunk)
x_chunk = self._conv2d(x_chunk)
mx.eval(x_chunk)
chunks.append(x_chunk)
x = mx.concatenate(chunks, axis=0)
H2, W2 = x.shape[1], x.shape[2]
x = x.reshape(B, T, H2, W2, C)
elif self.mode in ("downsample2d", "downsample3d"):
# Spatial downsample: per-frame strided Conv2d
x_flat = x.reshape(B * T, H, W, C)
x_flat = self._downsample_conv2d(x_flat)
mx.eval(x_flat)
H2, W2 = x_flat.shape[1], x_flat.shape[2]
x = x_flat.reshape(B, T, H2, W2, C)
x = mx.concatenate(chunks, axis=0)
H2, W2 = x.shape[1], x.shape[2]
x = x.reshape(B, T, H2, W2, C)
return x
@@ -383,6 +457,44 @@ class Up_ResidualBlock(nn.Module):
return x_main
class Down_ResidualBlock(nn.Module):
"""Downsampling residual block with AvgDown3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
super().__init__()
self.down_flag = down_flag
# AvgDown3D shortcut (no learnable params, always present)
self.avg_shortcut = AvgDown3D(
in_dim, out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
# Main path: ResidualBlocks + optional Resample
blocks = []
dim_in = in_dim
for _ in range(num_res_blocks):
blocks.append(ResidualBlock(dim_in, out_dim))
dim_in = out_dim
if down_flag:
mode = "downsample3d" if temperal_downsample else "downsample2d"
blocks.append(Resample(out_dim, mode=mode))
self.downsamples = blocks
def __call__(self, x):
x_shortcut = self.avg_shortcut(x)
mx.eval(x_shortcut)
for module in self.downsamples:
x = module(x)
mx.eval(x)
return x + x_shortcut
class Decoder3d(nn.Module):
"""Wan2.2 3D VAE Decoder."""
@@ -439,6 +551,63 @@ class Decoder3d(nn.Module):
return x
class Encoder3d(nn.Module):
"""Wan2.2 3D VAE Encoder. Mirror of Decoder3d with downsampling."""
def __init__(
self,
dim=160,
z_dim=96,
dim_mult=(1, 2, 4, 4),
num_res_blocks=2,
temperal_downsample=(False, True, True),
):
super().__init__()
# Channel dimensions: [160, 160, 320, 640, 640]
dims = [dim * m for m in [1] + list(dim_mult)]
# Initial conv: patchified input (12 ch) → first dim
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
# Downsample blocks
self.downsamples = []
for i in range(len(dim_mult)):
in_d, out_d = dims[i], dims[i + 1]
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
self.downsamples.append(Down_ResidualBlock(
in_dim=in_d,
out_dim=out_d,
num_res_blocks=num_res_blocks,
temperal_downsample=t_down,
down_flag=(i < len(dim_mult) - 1),
))
# Middle blocks (same as decoder)
out_dim = dims[-1]
self.middle = [
ResidualBlock(out_dim, out_dim),
AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim),
]
# Output head: RMS_norm → SiLU → CausalConv3d → z_dim channels
self.head = Head22(out_dim, out_channels=z_dim)
def __call__(self, x):
# x: [B, T, H, W, 12] (patchified)
x = self.conv1(x)
for layer in self.downsamples:
x = layer(x)
for layer in self.middle:
x = layer(x)
mx.eval(x)
x = self.head(x)
return x
class Head22(nn.Module):
"""Decoder output head: RMS_norm → SiLU → CausalConv3d(dim, 12, 3).
@@ -460,6 +629,46 @@ class Head22(nn.Module):
return x
class Wan22VAEEncoder(nn.Module):
"""Full Wan2.2 VAE encoder with patchify and normalization."""
def __init__(self, z_dim=48, dim=160):
super().__init__()
self.z_dim = z_dim
# conv1: top-level 1x1x1 conv after encoder (z_dim*2 → z_dim*2)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.encoder = Encoder3d(
dim=dim,
z_dim=z_dim * 2, # Encoder outputs z_dim*2, split into mu + log_var
dim_mult=(1, 2, 4, 4),
num_res_blocks=2,
temperal_downsample=(False, True, True),
)
def __call__(self, img):
"""Encode image/video to latent space.
Args:
img: [B, T, H, W, 3] image/video in [-1, 1]
Returns:
mu: [B, T_lat, H_lat, W_lat, z_dim] normalized latent
"""
# Patchify: [B, T, H, W, 3] → [B, T, H/2, W/2, 12]
x = _patchify(img, patch_size=2)
# Encoder: [B, T, H/2, W/2, 12] → [B, T', H', W', z_dim*2]
out = self.encoder(x)
# conv1 (pointwise) + split into mu, log_var
out = self.conv1(out)
mu = out[:, :, :, :, :self.z_dim]
# Normalize
mu = normalize_latents(mu)
return mu
class Wan22VAEDecoder(nn.Module):
"""Full Wan2.2 VAE decoder with normalization and unpatchify."""
@@ -507,6 +716,15 @@ def denormalize_latents(z, mean=None, std=None):
return z * inv_scale.reshape(1, 1, 1, 1, -1) + mean.reshape(1, 1, 1, 1, -1)
def normalize_latents(z, mean=None, std=None):
"""Normalize latents: z_norm = (z - mean) / std. Inverse of denormalize_latents."""
if mean is None:
mean = VAE22_MEAN
if std is None:
std = VAE22_STD
return (z - mean.reshape(1, 1, 1, 1, -1)) / std.reshape(1, 1, 1, 1, -1)
def _unpatchify(x, patch_size=2):
"""Convert from packed channels to spatial: [B, T, H, W, C*p*p] → [B, T, H*p, W*p, C//(p*p)]
Actually: [B, T, H, W, 12] → [B, T, H*2, W*2, 3]
@@ -527,10 +745,30 @@ def _unpatchify(x, patch_size=2):
return x
def sanitize_wan22_vae_weights(weights: dict) -> dict:
def _patchify(x, patch_size=2):
"""Convert spatial to packed channels: [B, T, H*p, W*p, C] → [B, T, H, W, C*p*p]
Inverse of _unpatchify.
PyTorch: b c f (h q) (w r) -> b (c r q) f h w
In channels-last: [B, T, H*q, W*r, C] → [B, T, H, W, C*r*q]
"""
if patch_size == 1:
return x
B, T, Hfull, Wfull, C = x.shape
H = Hfull // patch_size
W = Wfull // patch_size
# [B, T, H, q, W, r, C]
x = x.reshape(B, T, H, patch_size, W, patch_size, C)
# Rearrange to pack q,r into channels: [B, T, H, W, C, r, q]
x = x.transpose(0, 1, 2, 4, 6, 5, 3) # [B, T, H, W, C, r, q]
x = x.reshape(B, T, H, W, C * patch_size * patch_size)
return x
def sanitize_wan22_vae_weights(weights: dict, include_encoder: bool = False) -> dict:
"""Convert PyTorch Wan2.2 VAE weights to MLX format.
Only keeps decoder + conv2 weights (encoder/conv1 not needed for generation).
By default keeps decoder + conv2 weights only. Set include_encoder=True
to also keep encoder + conv1 weights (needed for I2V encoding).
Transposes conv weights from channels-first to channels-last.
Squeezes RMS_norm gamma from (dim, 1, 1, 1) or (dim, 1, 1) to (dim,).
Maps PyTorch nn.Sequential indices to our named layers.
@@ -538,9 +776,10 @@ def sanitize_wan22_vae_weights(weights: dict) -> dict:
sanitized = {}
for key, value in weights.items():
# Skip encoder and conv1 (encoder-only)
if key.startswith("encoder.") or key.startswith("conv1."):
continue
# Skip encoder and conv1 unless requested
if not include_encoder:
if key.startswith("encoder.") or key.startswith("conv1."):
continue
new_key = key

View File

@@ -1,8 +1,42 @@
import numpy as np
from pathlib import Path
from typing import Optional
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""Save video frames to MP4.
Args:
frames: Video frames [T, H, W, 3] uint8
output_path: Output file path
fps: Frames per second
"""
try:
import imageio
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
for frame in frames:
writer.append_data(frame)
writer.close()
except ImportError:
try:
import cv2
h, w = frames.shape[1], frames.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"avc1")
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
for frame in frames:
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
writer.release()
except (ImportError, Exception):
# Last resort: save as individual PNGs
from PIL import Image
out_dir = Path(output_path).parent / Path(output_path).stem
out_dir.mkdir(parents=True, exist_ok=True)
for i, frame in enumerate(frames):
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
"""Apply bilateral filter to reduce grid artifacts while preserving edges.

View File

@@ -9,6 +9,20 @@ from pathlib import Path
from huggingface_hub import snapshot_download
from PIL import Image
class Colors:
"""ANSI color codes for terminal output."""
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try: