feat(wan): Add Wan2.2 I2V support
This commit is contained in:
@@ -13,156 +13,27 @@ import mlx.nn as nn
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan.loading import (
|
||||
_clean_text,
|
||||
encode_text,
|
||||
load_t5_encoder,
|
||||
load_vae_decoder,
|
||||
load_vae_encoder,
|
||||
load_wan_model,
|
||||
)
|
||||
from mlx_video.postprocess import save_video
|
||||
from mlx_video.utils import Colors
|
||||
|
||||
class Colors:
|
||||
CYAN = "\033[96m"
|
||||
BLUE = "\033[94m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RED = "\033[91m"
|
||||
MAGENTA = "\033[95m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
|
||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None):
|
||||
"""Load and initialize WanModel, with optional quantization support.
|
||||
|
||||
Args:
|
||||
model_path: Path to model safetensors file
|
||||
config: WanModelConfig
|
||||
quantization: Optional dict with 'bits' and 'group_size' keys.
|
||||
If provided, creates QuantizedLinear stubs before loading.
|
||||
"""
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
model = WanModel(config)
|
||||
|
||||
if quantization:
|
||||
from mlx_video.convert_wan import _quantize_predicate
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=quantization["group_size"],
|
||||
bits=quantization["bits"],
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
return model
|
||||
|
||||
|
||||
def load_t5_encoder(model_path: Path, config):
|
||||
"""Load T5 text encoder.
|
||||
|
||||
Weights are upcast to float32 for maximum precision — the T5 encoder
|
||||
only runs once per generation, so performance impact is negligible.
|
||||
This matches the official which computes softmax in float32 explicitly.
|
||||
"""
|
||||
from mlx_video.models.wan.text_encoder import T5Encoder
|
||||
|
||||
encoder = T5Encoder(
|
||||
vocab_size=config.t5_vocab_size,
|
||||
dim=config.t5_dim,
|
||||
dim_attn=config.t5_dim_attn,
|
||||
dim_ffn=config.t5_dim_ffn,
|
||||
num_heads=config.t5_num_heads,
|
||||
num_layers=config.t5_num_layers,
|
||||
num_buckets=config.t5_num_buckets,
|
||||
shared_pos=False,
|
||||
)
|
||||
weights = mx.load(str(model_path))
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
encoder.load_weights(list(weights.items()))
|
||||
mx.eval(encoder.parameters())
|
||||
return encoder
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: Path, config=None):
|
||||
"""Load VAE decoder (skips encoder weights with strict=False).
|
||||
|
||||
For Wan2.2 (vae_z_dim=48), uses Wan22VAEDecoder.
|
||||
For Wan2.1 (vae_z_dim=16), uses WanVAE.
|
||||
"""
|
||||
is_wan22 = config is not None and config.vae_z_dim == 48
|
||||
|
||||
if is_wan22:
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
vae = Wan22VAEDecoder(z_dim=48)
|
||||
else:
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
# Upcast VAE weights to float32 for quality — official Wan2.2 runs VAE in float32
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
vae.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(vae.parameters())
|
||||
return vae
|
||||
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
"""Clean text matching official Wan2.2 tokenizer preprocessing.
|
||||
|
||||
Applies ftfy.fix_text (fixes mojibake, normalizes fullwidth chars),
|
||||
double HTML unescape, and whitespace normalization. Critical for
|
||||
correct tokenization of the Chinese negative prompt.
|
||||
"""
|
||||
import html
|
||||
import re
|
||||
|
||||
try:
|
||||
import ftfy
|
||||
text = ftfy.fix_text(text)
|
||||
except ImportError:
|
||||
pass
|
||||
text = html.unescape(html.unescape(text))
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
def encode_text(
|
||||
encoder,
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
text_len: int = 512,
|
||||
) -> mx.array:
|
||||
"""Encode text prompt using T5 encoder.
|
||||
|
||||
Args:
|
||||
encoder: T5Encoder model
|
||||
tokenizer: HuggingFace tokenizer
|
||||
prompt: Text prompt
|
||||
text_len: Maximum text length
|
||||
|
||||
Returns:
|
||||
Text embeddings [L, dim]
|
||||
"""
|
||||
prompt = _clean_text(prompt)
|
||||
tokens = tokenizer(
|
||||
prompt,
|
||||
max_length=text_len,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
ids = mx.array(tokens["input_ids"])
|
||||
mask = mx.array(tokens["attention_mask"])
|
||||
|
||||
embeddings = encoder(ids, mask=mask)
|
||||
|
||||
# Return only non-padding tokens
|
||||
seq_len = int(mask.sum().item())
|
||||
return embeddings[0, :seq_len]
|
||||
# Backward-compat alias (tests and external code may use the old name)
|
||||
_build_i2v_mask = build_i2v_mask
|
||||
|
||||
|
||||
def generate_video(
|
||||
model_dir: str,
|
||||
prompt: str,
|
||||
negative_prompt: str | None = None,
|
||||
image: str | None = None,
|
||||
width: int = 1280,
|
||||
height: int = 720,
|
||||
num_frames: int = 81,
|
||||
@@ -173,12 +44,13 @@ def generate_video(
|
||||
output_path: str = "output.mp4",
|
||||
scheduler: str = "unipc",
|
||||
):
|
||||
"""Generate video using Wan T2V pipeline (supports 2.1 and 2.2).
|
||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||
|
||||
Args:
|
||||
model_dir: Path to converted MLX model directory
|
||||
prompt: Text prompt
|
||||
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
|
||||
image: Path to input image for I2V (None = T2V mode)
|
||||
width: Video width
|
||||
height: Video height
|
||||
num_frames: Number of frames (must be 4n+1)
|
||||
@@ -240,6 +112,7 @@ def generate_video(
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
|
||||
is_dual = config.dual_model
|
||||
is_i2v = image is not None
|
||||
|
||||
# Validate config against actual weights (handles mismatched config.json)
|
||||
if not is_dual:
|
||||
@@ -288,6 +161,7 @@ def generate_video(
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
|
||||
# Resolve negative prompt: explicit user value > config default
|
||||
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
|
||||
# that prevents oversaturation, artifacts, and comic look. We use it by default.
|
||||
@@ -297,9 +171,11 @@ def generate_video(
|
||||
else:
|
||||
neg_prompt_resolved = negative_prompt
|
||||
print(f"{Colors.CYAN}{'='*60}")
|
||||
print(f" {version_str} Text-to-Video Generation (MLX, {mode_str})")
|
||||
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
|
||||
print(f"{'='*60}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Prompt: {prompt}")
|
||||
if is_i2v:
|
||||
print(f" Image: {image}")
|
||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
@@ -314,8 +190,22 @@ def generate_video(
|
||||
np.random.seed(seed)
|
||||
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
|
||||
|
||||
# Compute target latent shape
|
||||
# Align dimensions to patch_size * vae_stride (required for patchify)
|
||||
vae_stride = config.vae_stride
|
||||
patch_size = config.patch_size
|
||||
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
|
||||
align_w = patch_size[2] * vae_stride[2]
|
||||
if height % align_h != 0 or width % align_w != 0:
|
||||
old_h, old_w = height, width
|
||||
height = (height // align_h) * align_h
|
||||
width = (width // align_w) * align_w
|
||||
if height == 0:
|
||||
height = align_h
|
||||
if width == 0:
|
||||
width = align_w
|
||||
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
||||
|
||||
# Compute target latent shape
|
||||
z_dim = config.vae_z_dim
|
||||
t_latent = (num_frames - 1) // vae_stride[0] + 1
|
||||
h_latent = height // vae_stride[1]
|
||||
@@ -323,7 +213,6 @@ def generate_video(
|
||||
target_shape = (z_dim, t_latent, h_latent, w_latent)
|
||||
|
||||
# Sequence length for transformer
|
||||
patch_size = config.patch_size
|
||||
seq_len = math.ceil(
|
||||
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
|
||||
)
|
||||
@@ -352,6 +241,31 @@ def generate_video(
|
||||
gc.collect(); mx.clear_cache()
|
||||
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
# I2V: encode image to latent space
|
||||
z_img = None
|
||||
i2v_mask = None
|
||||
i2v_mask_tokens = None
|
||||
if is_i2v:
|
||||
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
|
||||
t_img = time.time()
|
||||
img_tensor = preprocess_image(image, width, height)
|
||||
mx.eval(img_tensor)
|
||||
|
||||
vae_path = model_dir / "vae.safetensors"
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||
mx.eval(z_img)
|
||||
|
||||
# Convert to channels-first: [z_dim, 1, H_lat, W_lat]
|
||||
z_img = z_img[0].transpose(3, 0, 1, 2)
|
||||
|
||||
# Build I2V mask
|
||||
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||
|
||||
del vae_enc, img_tensor
|
||||
gc.collect(); mx.clear_cache()
|
||||
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||
|
||||
# Load transformer models
|
||||
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
||||
if quantization:
|
||||
@@ -398,12 +312,18 @@ def generate_video(
|
||||
# Generate initial noise
|
||||
noise = mx.random.normal(target_shape)
|
||||
|
||||
# I2V: blend first-frame latent into noise
|
||||
if is_i2v:
|
||||
# Broadcast z_img [z_dim, 1, H, W] across T for first-frame conditioning
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
# Boundary for model switching (dual model only)
|
||||
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
|
||||
|
||||
# Diffusion loop
|
||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||
latents = noise
|
||||
t3 = time.time()
|
||||
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
@@ -424,10 +344,24 @@ def generate_video(
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
kv = cross_kv
|
||||
|
||||
# Build per-token timesteps for I2V (first-frame patches get t=0)
|
||||
if is_i2v:
|
||||
t_tokens = i2v_mask_tokens * timestep_val # [1, L]
|
||||
# Pad to seq_len if needed
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
# Batch for CFG: both cond and uncond get same timesteps
|
||||
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) # [2, L]
|
||||
else:
|
||||
t_batch = mx.array([timestep_val, timestep_val])
|
||||
|
||||
# CFG: batch cond + uncond into single B=2 forward pass
|
||||
preds = model(
|
||||
[latents, latents],
|
||||
t=mx.array([timestep_val, timestep_val]),
|
||||
t=t_batch,
|
||||
context=context_cfg,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
@@ -438,6 +372,10 @@ def generate_video(
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
|
||||
# I2V: re-apply mask to keep first frame frozen
|
||||
if is_i2v:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
||||
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
|
||||
mx.eval(latents)
|
||||
@@ -488,43 +426,12 @@ def generate_video(
|
||||
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
|
||||
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||
"""Save video frames to MP4.
|
||||
|
||||
Args:
|
||||
frames: Video frames [T, H, W, 3] uint8
|
||||
output_path: Output file path
|
||||
fps: Frames per second
|
||||
"""
|
||||
try:
|
||||
import imageio
|
||||
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
except ImportError:
|
||||
try:
|
||||
import cv2
|
||||
h, w = frames.shape[1], frames.shape[2]
|
||||
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
||||
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
||||
for frame in frames:
|
||||
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
writer.release()
|
||||
except (ImportError, Exception):
|
||||
# Last resort: save as individual PNGs
|
||||
from PIL import Image
|
||||
out_dir = Path(output_path).parent / Path(output_path).stem
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, frame in enumerate(frames):
|
||||
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
||||
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument("--image", type=str, default=None,
|
||||
help="Path to input image for I2V (omit for T2V mode)")
|
||||
parser.add_argument("--negative-prompt", type=str, default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||
@@ -559,6 +466,7 @@ def main():
|
||||
model_dir=args.model_dir,
|
||||
prompt=args.prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
image=args.image,
|
||||
width=args.width,
|
||||
height=args.height,
|
||||
num_frames=args.num_frames,
|
||||
|
||||
Reference in New Issue
Block a user