Files
mlx-video/mlx_video/models/wan/generate.py

829 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Wan2.2 Text-to-Video generation pipeline for MLX."""
import argparse
import gc
import math
import random
import sys
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan.loading import (
_clean_text,
encode_text,
load_t5_encoder,
load_vae_decoder,
load_vae_encoder,
load_wan_model,
)
from mlx_video.models.wan.postprocess import save_video
class Colors:
"""ANSI color codes for terminal output."""
CYAN = "\033[96m"
BLUE = "\033[94m"
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
MAGENTA = "\033[95m"
BOLD = "\033[1m"
DIM = "\033[2m"
RESET = "\033[0m"
# Backward-compat alias (tests and external code may use the old name)
_build_i2v_mask = build_i2v_mask
def _best_output_size(w, h, dw, dh, max_area):
"""Compute the best output resolution that fits within max_area while
preserving the input aspect ratio and satisfying alignment constraints.
Matches the reference implementation's best_output_size().
"""
ratio = w / h
ow = (max_area * ratio) ** 0.5
oh = max_area / ow
# Option 1: process width first
ow1 = int(ow // dw * dw)
oh1 = int(max_area / ow1 // dh * dh)
ratio1 = ow1 / oh1
# Option 2: process height first
oh2 = int(oh // dh * dh)
ow2 = int(max_area / oh2 // dw * dw)
ratio2 = ow2 / oh2
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
return ow1, oh1
return ow2, oh2
def generate_video(
model_dir: str,
prompt: str,
negative_prompt: str | None = None,
image: str | None = None,
width: int = 1280,
height: int = 704,
num_frames: int = 81,
steps: int = None,
guide_scale: str | float | tuple = None,
shift: float = None,
seed: int = -1,
output_path: str = "output.mp4",
scheduler: str = "unipc",
loras: list | None = None,
loras_high: list | None = None,
loras_low: list | None = None,
tiling: str = "auto",
no_compile: bool = False,
trim_first_frames: int = 0,
debug_latents: bool = False,
):
"""Generate video using Wan pipeline (supports T2V and I2V).
Args:
model_dir: Path to converted MLX model directory
prompt: Text prompt
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
image: Path to input image for I2V (None = T2V mode)
width: Video width
height: Video height
num_frames: Number of frames (must be 4n+1)
steps: Number of diffusion steps (None = use config default)
guide_scale: Guidance scale: float for single, (low,high) for dual (None = config default)
shift: Noise schedule shift (None = use config default)
seed: Random seed (-1 for random)
output_path: Output video path
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
loras: Optional list of (path, strength) tuples applied to all models
loras_high: Optional list of (path, strength) tuples for high-noise model only
loras_low: Optional list of (path, strength) tuples for low-noise model only
tiling: Tiling mode for VAE decoding. Options:
- "auto": Automatically determine tiling based on video size (default)
- "none": Disable tiling
- "default", "aggressive", "conservative": Preset tiling configs
- "spatial": Spatial tiling only
- "temporal": Temporal tiling only
no_compile: If True, skip mx.compile on models (useful for debugging)
trim_first_frames: Number of temporal latent positions to generate extra
and discard from the start. Each position = 4 pixel frames. Use 1
to fix first-frame artifacts on 14B models (generates 4 extra frames,
discards first 4). Use 2 for more aggressive trimming. Default: 0.
debug_latents: If True, print per-temporal-position latent statistics
after denoising for diagnosing first-frame artifacts.
"""
import json
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
model_dir = Path(model_dir)
# Load config from model dir if available, otherwise auto-detect
config_path = model_dir / "config.json"
quantization = None
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Extract quantization config (not a model config field)
quantization = config_dict.pop("quantization", None)
# Handle tuple fields stored as lists in JSON
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**{
k: v for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
})
else:
# Auto-detect: dual model files → 2.2, single model → 2.1
if (model_dir / "low_noise_model.safetensors").exists():
config = WanModelConfig.wan22_t2v_14b()
else:
# Detect 1.3B vs 14B from weight shapes
model_path = model_dir / "model.safetensors"
if model_path.exists():
probe = mx.load(str(model_path), return_metadata=False)
for k, v in probe.items():
if "patch_embedding_proj.weight" in k:
dim = v.shape[0]
if dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
config = WanModelConfig.wan21_t2v_14b()
break
else:
config = WanModelConfig.wan21_t2v_14b()
del probe
else:
config = WanModelConfig.wan21_t2v_14b()
is_dual = config.dual_model
is_i2v = image is not None
# Validate config against actual weights (handles mismatched config.json)
if not is_dual:
model_path = model_dir / "model.safetensors"
if model_path.exists():
probe = mx.load(str(model_path), return_metadata=False)
for k, v in probe.items():
if "patch_embedding_proj.weight" in k:
actual_dim = v.shape[0]
if actual_dim != config.dim:
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
if actual_dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
config = WanModelConfig.wan21_t2v_14b()
break
del probe
# Auto-correct Wan2.2 VAE params from stale configs
if config.in_dim == 48 and config.vae_z_dim != 48:
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
config = WanModelConfig(**{
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
})
# Apply defaults from config if not overridden
if steps is None:
steps = config.sample_steps
if shift is None:
shift = config.sample_shift
if guide_scale is None:
guide_scale = config.sample_guide_scale
# Normalize guide_scale
if isinstance(guide_scale, (int, float)):
guide_scale = float(guide_scale)
elif isinstance(guide_scale, str):
parts = [float(x) for x in guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
if isinstance(guide_scale, tuple):
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
else:
cfg_disabled = guide_scale <= 1.0
# Validate frame count
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
gen_frames = num_frames
if trim_first_frames > 0:
gen_frames = num_frames + trim_first_frames * 4
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
# Resolve negative prompt: explicit user value > config default
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
# that prevents oversaturation, artifacts, and comic look. We use it by default.
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
if negative_prompt is None:
neg_prompt_resolved = config.sample_neg_prompt
else:
neg_prompt_resolved = negative_prompt
print(f"{Colors.CYAN}{'='*60}")
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
print(f"{'='*60}{Colors.RESET}")
print(f"{Colors.DIM} Prompt: {prompt}")
if is_i2v:
print(f" Image: {image}")
if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
if cfg_disabled:
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
print(f"{Colors.RESET}")
# Seed
if seed < 0:
seed = random.randint(0, 2**32 - 1)
mx.random.seed(seed)
np.random.seed(seed)
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
# Align dimensions to patch_size * vae_stride (required for patchify)
vae_stride = config.vae_stride
patch_size = config.patch_size
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
align_w = patch_size[2] * vae_stride[2]
if height % align_h != 0 or width % align_w != 0:
old_h, old_w = height, width
height = (height // align_h) * align_h
width = (width // align_w) * align_w
if height == 0:
height = align_h
if width == 0:
width = align_w
print(f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
# Enforce max_area constraint (model-specific resolution limit)
if config.max_area > 0 and height * width > config.max_area:
old_h, old_w = height, width
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
print(
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
)
# Compute target latent shape
z_dim = config.vae_z_dim
t_latent = (gen_frames - 1) // vae_stride[0] + 1
h_latent = height // vae_stride[1]
w_latent = width // vae_stride[2]
target_shape = (z_dim, t_latent, h_latent, w_latent)
# Sequence length for transformer
seq_len = math.ceil(
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
)
print(f"{Colors.DIM} Latent shape: {target_shape}")
print(f" Sequence length: {seq_len}{Colors.RESET}")
# Load T5 encoder
t1 = time.time()
print(f"\n{Colors.BLUE}Loading T5 encoder...{Colors.RESET}")
t5_path = model_dir / "t5_encoder.safetensors"
t5_encoder = load_t5_encoder(t5_path, config)
# Load tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
# Encode prompts
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
if cfg_disabled:
context_null = None
mx.eval(context)
else:
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
mx.eval(context, context_null)
# Free T5 from memory
del t5_encoder
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
# I2V: encode image to latent space
z_img = None
i2v_mask = None
i2v_mask_tokens = None
y_i2v = None
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
if is_i2v:
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
t_img = time.time()
vae_path = model_dir / "vae.safetensors"
if is_i2v_channel_concat:
# I2V-14B: encode full video (first frame = image, rest = zeros)
# and construct y tensor with mask + encoded latents
from PIL import Image
img = Image.open(image).convert("RGB")
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
video = mx.concatenate([
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
], axis=1)
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
vae_enc = load_vae_encoder(vae_path, config)
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
mx.eval(z_video)
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
msk = mx.concatenate([
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
], axis=1)
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
y_i2v = mx.concatenate([msk, z_video], axis=0)
mx.eval(y_i2v)
del vae_enc, img_arr, img_chw, video, z_video, msk
else:
# TI2V-5B: encode single image, blend with noise via mask
img_tensor = preprocess_image(image, width, height)
mx.eval(img_tensor)
vae_enc = load_vae_encoder(vae_path, config)
z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
mx.eval(z_img)
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
del vae_enc, img_tensor
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
# Load transformer models
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
if quantization:
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
t2 = time.time()
# Merge per-model LoRAs with shared LoRAs
_loras_low = (loras or []) + (loras_low or []) or None
_loras_high = (loras or []) + (loras_high or []) or None
_loras_single = loras
if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step)
# Each model has its own text_embedding weights, so dual models need separate embeddings
if cfg_disabled:
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
if is_dual:
context_emb_low = low_noise_model.embed_text([context])
context_emb_high = high_noise_model.embed_text([context])
mx.eval(context_emb_low, context_emb_high)
context_cond_low = context_emb_low[0:1]
context_cond_high = context_emb_high[0:1]
else:
context_emb = single_model.embed_text([context])
mx.eval(context_emb)
context_cond = context_emb[0:1]
else:
if is_dual:
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb)
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
# Precompute cross-attention K/V caches (constant across all steps)
if cfg_disabled:
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cond)
mx.eval(cross_kv)
else:
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv)
# Precompute RoPE frequencies (grid sizes are constant across all steps)
f_grid = t_latent // patch_size[0]
h_grid = h_latent // patch_size[1]
w_grid = w_latent // patch_size[2]
if cfg_disabled:
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
else:
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
if is_dual:
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
else:
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
mx.eval(rope_cos_sin)
# Setup scheduler
_schedulers = {
"euler": FlowMatchEulerScheduler,
"dpm++": FlowDPMPP2MScheduler,
"unipc": FlowUniPCScheduler,
}
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
sched.set_timesteps(steps, shift=shift)
# Generate initial noise
noise = mx.random.normal(target_shape)
# I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
else:
latents = noise
# Boundary for model switching (dual model only)
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
# Diffusion loop
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
t3 = time.time()
# Compile model forward for faster denoising
if not no_compile:
models_to_compile = (
[high_noise_model, low_noise_model] if is_dual else [single_model]
)
for m in models_to_compile:
m._compiled = mx.compile(m)
# Pre-convert timesteps to Python list to avoid .item() sync each step
timestep_list = sched.timesteps.tolist()
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = timestep_list[i]
# Select model, cached K/V, and precomputed RoPE
if is_dual:
if timestep_val >= boundary:
model = high_noise_model
kv = cross_kv_high
rcs = rope_cos_sin_high
else:
model = low_noise_model
kv = cross_kv_low
rcs = rope_cos_sin_low
else:
model = single_model
kv = cross_kv
rcs = rope_cos_sin
# Use compiled forward when available (faster after first trace)
_call = getattr(model, '_compiled', model)
if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = t_tokens # [1, L]
else:
t_batch = mx.array([timestep_val])
y_arg = [y_i2v] if is_i2v_channel_concat else None
if is_dual:
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
else:
ctx = context_cond
preds = _call(
[latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred = preds[0]
del preds
else:
# CFG: batch cond + uncond into single B=2 forward pass
if is_dual:
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
else:
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
pad_len = seq_len - t_tokens.shape[1]
if pad_len > 0:
t_tokens = mx.concatenate(
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
)
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
else:
t_batch = mx.array([timestep_val, timestep_val])
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
)
preds = _call(
[latents, latents],
t=t_batch,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
del noise_pred_cond, noise_pred_uncond, preds
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# TI2V-5B: re-apply mask to keep first frame frozen
if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
# Release temporaries before eval to free memory for graph execution
del noise_pred
mx.eval(latents)
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
# Diagnostic: per-temporal-position latent statistics
if debug_latents:
lat_np = np.array(latents) # [C, T, H, W]
n_t = lat_np.shape[1]
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
for t_pos in range(min(n_t, 8)):
frame = lat_np[:, t_pos, :, :]
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
if n_t > 8:
interior = lat_np[:, 4:, :, :]
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
print()
# Free transformer models and text embeddings
if is_dual:
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
if cfg_disabled:
del context_cond_low, context_cond_high
else:
del context_cfg_low, context_cfg_high
else:
del single_model, cross_kv
if cfg_disabled:
del context_cond
else:
del context_cfg
del model, kv, context
if context_null is not None:
del context_null
gc.collect(); mx.clear_cache()
# Load VAE and decode
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
t4 = time.time()
vae_path = model_dir / "vae.safetensors"
vae = load_vae_decoder(vae_path, config)
is_wan22_vae = config.vae_z_dim == 48
# Temporal extend: prepend reflected latent frames to the VAE input so that
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
# This gives the first real frame a full temporal receptive field of real data.
# Select tiling configuration
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None]
z = denormalize_latents(z)
if tiling_config is not None:
video = vae.decode_tiled(z, tiling_config)
else:
video = vae(z)
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [T', H', W', 3]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
else:
if tiling_config is not None:
video = vae.decode_tiled(latents[None], tiling_config)
else:
video = vae.decode(latents[None])
mx.eval(video)
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
video = np.array(video[0]) # [3, T', H, W]
video = (video + 1.0) / 2.0
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
# Trim first N temporal chunks if requested (avoids first-frame artifacts)
if trim_first_frames > 0:
trim_pixels = trim_first_frames * 4
video = video[trim_pixels:]
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
save_video(video, output_path, fps=config.sample_fps)
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--image", type=str, default=None,
help="Path to input image for I2V (omit for T2V mode)")
parser.add_argument("--negative-prompt", type=str, default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true",
help="Disable negative prompt (use empty string instead of config default)")
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
parser.add_argument(
"--scheduler", type=str, default="unipc",
choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
)
parser.add_argument(
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
)
parser.add_argument(
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--tiling",
type=str,
default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="VAE tiling mode to reduce memory during decoding (default: auto)",
)
parser.add_argument(
"--no-compile", action="store_true",
help="Disable mx.compile on models (for debugging)",
)
parser.add_argument(
"--trim-first-frames", type=int, default=0, metavar="N",
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
"Default: 0 (disabled)",
)
parser.add_argument(
"--debug-latents", action="store_true",
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
)
args = parser.parse_args()
# Parse guide scale
guide_scale = None
if args.guide_scale is not None:
parts = [float(x) for x in args.guide_scale.split(",")]
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
neg_prompt = args.negative_prompt
if args.no_negative_prompt:
neg_prompt = ""
# Parse LoRA configs: convert [path, strength_str] → (path, float)
def _parse_lora_args(lora_list):
if not lora_list:
return None
return [(path, float(strength)) for path, strength in lora_list]
generate_video(
model_dir=args.model_dir,
prompt=args.prompt,
negative_prompt=neg_prompt,
image=args.image,
width=args.width,
height=args.height,
num_frames=args.num_frames,
steps=args.steps,
guide_scale=guide_scale,
shift=args.shift,
seed=args.seed,
output_path=args.output_path,
scheduler=args.scheduler,
loras=_parse_lora_args(args.lora),
loras_high=_parse_lora_args(args.lora_high),
loras_low=_parse_lora_args(args.lora_low),
tiling=args.tiling,
no_compile=args.no_compile,
trim_first_frames=args.trim_first_frames,
debug_latents=args.debug_latents,
)
if __name__ == "__main__":
main()