Remove deprecated stubs for video conversion and generation; introduce new weight conversion and generation scripts for Wan2.2 models in MLX.
This commit is contained in:
773
mlx_video/models/wan/convert.py
Normal file
773
mlx_video/models/wan/convert.py
Normal file
@@ -0,0 +1,773 @@
|
||||
"""Weight conversion for Wan2.2 models (PyTorch -> MLX)."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.utils
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_torch_weights(path: str) -> Dict[str, mx.array]:
|
||||
"""Load PyTorch .pth weights and convert to MLX arrays.
|
||||
|
||||
Args:
|
||||
path: Path to .pth file
|
||||
|
||||
Returns:
|
||||
Dictionary of MLX arrays
|
||||
"""
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("PyTorch is required to load .pth weights: pip install torch")
|
||||
|
||||
logging.info(f"Loading weights from {path}")
|
||||
state_dict = torch.load(path, map_location="cpu", weights_only=True)
|
||||
|
||||
weights = {}
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
np_val = value.detach().float().numpy()
|
||||
weights[key] = mx.array(np_val)
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
|
||||
"""Load safetensors weights as MLX arrays.
|
||||
|
||||
Args:
|
||||
path: Path to directory with safetensors files or single file
|
||||
|
||||
Returns:
|
||||
Dictionary of MLX arrays
|
||||
"""
|
||||
path = Path(path)
|
||||
weights = {}
|
||||
if path.is_file():
|
||||
weights = mx.load(str(path))
|
||||
elif path.is_dir():
|
||||
for sf in sorted(path.glob("*.safetensors")):
|
||||
weights.update(mx.load(str(sf)))
|
||||
return weights
|
||||
|
||||
|
||||
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 transformer weight keys to MLX model structure.
|
||||
|
||||
Wan2.2 keys follow the pattern:
|
||||
patch_embedding.weight/bias
|
||||
text_embedding.{0,2}.weight/bias
|
||||
time_embedding.{0,2}.weight/bias
|
||||
time_projection.1.weight/bias
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
|
||||
blocks.{i}.self_attn.norm_q.weight
|
||||
blocks.{i}.self_attn.norm_k.weight
|
||||
blocks.{i}.norm3.weight/bias (if cross_attn_norm)
|
||||
blocks.{i}.cross_attn.{q,k,v,o}.weight/bias
|
||||
blocks.{i}.cross_attn.norm_q.weight
|
||||
blocks.{i}.cross_attn.norm_k.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.{0,2}.weight/bias
|
||||
blocks.{i}.modulation
|
||||
head.norm.weight
|
||||
head.head.weight/bias
|
||||
head.modulation
|
||||
freqs (buffer)
|
||||
|
||||
MLX model uses:
|
||||
patch_embedding_proj.weight/bias (after patchify reshape)
|
||||
text_embedding_0.weight/bias, text_embedding_1.weight/bias
|
||||
time_embedding_0.weight/bias, time_embedding_1.weight/bias
|
||||
time_projection.weight/bias
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.self_attn.{q,k,v,o}.weight/bias
|
||||
etc.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Patch embedding: Conv3d(16, 5120, (1,2,2)) weight is [O, I, D, H, W]
|
||||
# MLX Linear expects [O, I*D*H*W] after we flatten in patchify
|
||||
if key == "patch_embedding.weight":
|
||||
# Original: [dim, in_dim, 1, 2, 2] -> reshape to [dim, in_dim*1*2*2]
|
||||
value = value.reshape(value.shape[0], -1)
|
||||
new_key = "patch_embedding_proj.weight"
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key == "patch_embedding.bias":
|
||||
new_key = "patch_embedding_proj.bias"
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Text embedding Sequential: 0=Linear, 1=GELU(no params), 2=Linear
|
||||
if key.startswith("text_embedding.0."):
|
||||
new_key = key.replace("text_embedding.0.", "text_embedding_0.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key.startswith("text_embedding.2."):
|
||||
new_key = key.replace("text_embedding.2.", "text_embedding_1.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Time embedding Sequential: 0=Linear, 1=SiLU(no params), 2=Linear
|
||||
if key.startswith("time_embedding.0."):
|
||||
new_key = key.replace("time_embedding.0.", "time_embedding_0.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
if key.startswith("time_embedding.2."):
|
||||
new_key = key.replace("time_embedding.2.", "time_embedding_1.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# Time projection Sequential: 0=SiLU(no params), 1=Linear
|
||||
if key.startswith("time_projection.1."):
|
||||
new_key = key.replace("time_projection.1.", "time_projection.")
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
# FFN: Sequential(Linear, GELU, Linear) -> ffn.{0,2} -> ffn.fc1, ffn.fc2
|
||||
new_key = new_key.replace(".ffn.0.", ".ffn.fc1.")
|
||||
new_key = new_key.replace(".ffn.2.", ".ffn.fc2.")
|
||||
|
||||
# Skip the freqs buffer (we compute it in the model)
|
||||
if key == "freqs":
|
||||
consumed.add(key)
|
||||
continue
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed transformer weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_wan_t5_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 T5 encoder weight keys to MLX T5Encoder structure.
|
||||
|
||||
Wan2.2 T5 keys:
|
||||
token_embedding.weight
|
||||
pos_embedding.embedding.weight (if shared_pos)
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.attn.{q,k,v,o}.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.gate.0.weight (gate linear)
|
||||
blocks.{i}.ffn.fc1.weight
|
||||
blocks.{i}.ffn.fc2.weight
|
||||
blocks.{i}.pos_embedding.embedding.weight (if not shared_pos)
|
||||
norm.weight
|
||||
|
||||
MLX T5Encoder structure:
|
||||
token_embedding.weight
|
||||
blocks.{i}.norm1.weight
|
||||
blocks.{i}.attn.{q,k,v,o}.weight
|
||||
blocks.{i}.norm2.weight
|
||||
blocks.{i}.ffn.gate_proj.weight (mapped from gate.0)
|
||||
blocks.{i}.ffn.fc1.weight
|
||||
blocks.{i}.ffn.fc2.weight
|
||||
blocks.{i}.pos_embedding.embedding.weight
|
||||
norm.weight
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Map gate.0 -> gate_proj (the GELU is a separate module, not a parameter)
|
||||
new_key = new_key.replace(".ffn.gate.0.", ".ffn.gate_proj.")
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed T5 weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_wan_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Convert Wan2.2 VAE weight keys to MLX WanVAE structure.
|
||||
|
||||
Handles Conv3d and Conv2d weight transpositions for MLX format.
|
||||
"""
|
||||
sanitized = {}
|
||||
consumed = set()
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle Conv3d: PyTorch [O, I, D, H, W] -> MLX CausalConv3d weight [O, D, H, W, I]
|
||||
if "weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d: PyTorch [O, I, H, W] -> MLX [O, H, W, I]
|
||||
if "weight" in key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
# Map decoder keys to MLX decoder structure
|
||||
# Wan2.2 uses encoder/decoder with downsamples/upsamples
|
||||
# Need to adapt naming for our simplified structure
|
||||
|
||||
sanitized[new_key] = value
|
||||
consumed.add(key)
|
||||
|
||||
unconsumed = set(weights.keys()) - consumed
|
||||
if unconsumed:
|
||||
logger.warning("Unconsumed VAE weight keys: %s", sorted(unconsumed))
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def _load_lora_configs(
|
||||
lora_configs: List[Tuple[str, float]],
|
||||
) -> Dict[str, list]:
|
||||
"""Load LoRA weights from config tuples, returning module_to_loras dict.
|
||||
|
||||
Shared between weight-merging and runtime-wrapping paths.
|
||||
"""
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
from mlx_video.generate_wan import Colors
|
||||
|
||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||
|
||||
configs = []
|
||||
for lora_path, strength in lora_configs:
|
||||
try:
|
||||
config = LoRAConfig(path=lora_path, strength=strength)
|
||||
configs.append(config)
|
||||
print(f" - {Path(lora_path).name} (strength: {strength})")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}Error loading LoRA {lora_path}: {e}{Colors.RESET}")
|
||||
raise
|
||||
|
||||
module_to_loras = load_multiple_loras(configs)
|
||||
|
||||
if not module_to_loras:
|
||||
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
|
||||
|
||||
return module_to_loras
|
||||
|
||||
|
||||
def load_and_apply_loras(
|
||||
model_weights: Dict[str, mx.array],
|
||||
lora_configs: Optional[List[Tuple[str, float]]] = None,
|
||||
verbose: bool = False,
|
||||
quantization_bits: int = 0,
|
||||
) -> Dict[str, mx.array]:
|
||||
"""Load and apply LoRA weights to model weights by merging into weight dict.
|
||||
|
||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||
"""
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
from mlx_video.generate_wan import Colors
|
||||
|
||||
if not lora_configs:
|
||||
return model_weights
|
||||
|
||||
module_to_loras = _load_lora_configs(lora_configs)
|
||||
if not module_to_loras:
|
||||
return model_weights
|
||||
|
||||
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
|
||||
if verbose:
|
||||
print(f" Model has {len(model_weights)} weight keys")
|
||||
|
||||
modified_weights = apply_loras_to_weights(
|
||||
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
|
||||
)
|
||||
|
||||
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
|
||||
|
||||
return modified_weights
|
||||
|
||||
|
||||
def convert_wan_checkpoint(
|
||||
checkpoint_dir: str,
|
||||
output_dir: str,
|
||||
dtype: str = "bfloat16",
|
||||
model_version: str = "auto",
|
||||
quantize: bool = False,
|
||||
bits: int = 4,
|
||||
group_size: int = 64,
|
||||
):
|
||||
"""Convert a Wan2.1 or Wan2.2 checkpoint directory to MLX format.
|
||||
|
||||
Wan2.2 expected structure:
|
||||
checkpoint_dir/
|
||||
models_t5_umt5-xxl-enc-bf16.pth
|
||||
Wan2.1_VAE.pth
|
||||
low_noise_model/ (safetensors)
|
||||
high_noise_model/ (safetensors)
|
||||
|
||||
Wan2.1 expected structure:
|
||||
checkpoint_dir/
|
||||
models_t5_umt5-xxl-enc-bf16.pth
|
||||
Wan2.1_VAE.pth
|
||||
diffusion_pytorch_model*.safetensors (single model)
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to Wan checkpoint directory
|
||||
output_dir: Path to output MLX model directory
|
||||
dtype: Target dtype
|
||||
model_version: "2.1", "2.2", or "auto" (detect from directory)
|
||||
quantize: Whether to quantize the transformer weights
|
||||
bits: Quantization bits (4 or 8)
|
||||
group_size: Quantization group size (32, 64, or 128)
|
||||
"""
|
||||
import json
|
||||
|
||||
checkpoint_dir = Path(checkpoint_dir)
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dtype_map = {
|
||||
"float16": mx.float16,
|
||||
"float32": mx.float32,
|
||||
"bfloat16": mx.bfloat16,
|
||||
}
|
||||
target_dtype = dtype_map.get(dtype, mx.bfloat16)
|
||||
|
||||
# Auto-detect version
|
||||
if model_version == "auto":
|
||||
if (checkpoint_dir / "low_noise_model").exists():
|
||||
model_version = "2.2"
|
||||
elif (checkpoint_dir / "Wan2.2_VAE.pth").exists():
|
||||
model_version = "2.2"
|
||||
else:
|
||||
model_version = "2.1"
|
||||
print(f"Auto-detected Wan{model_version} checkpoint")
|
||||
|
||||
is_dual = (checkpoint_dir / "low_noise_model").exists()
|
||||
|
||||
if is_dual:
|
||||
# Wan2.2: Convert dual transformer models
|
||||
low_noise_path = checkpoint_dir / "low_noise_model"
|
||||
if low_noise_path.exists():
|
||||
print("Converting low-noise transformer...")
|
||||
weights = load_safetensors_weights(str(low_noise_path))
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "low_noise_model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
|
||||
high_noise_path = checkpoint_dir / "high_noise_model"
|
||||
if high_noise_path.exists():
|
||||
print("Converting high-noise transformer...")
|
||||
weights = load_safetensors_weights(str(high_noise_path))
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "high_noise_model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
else:
|
||||
# Wan2.1: Convert single transformer model
|
||||
# Try safetensors in the checkpoint dir itself
|
||||
print("Converting transformer (single model)...")
|
||||
weights = load_safetensors_weights(str(checkpoint_dir))
|
||||
if not weights:
|
||||
# Fallback: look for .pth files
|
||||
for pth in sorted(checkpoint_dir.glob("*.pth")):
|
||||
if "t5" not in pth.name.lower() and "vae" not in pth.name.lower():
|
||||
print(f" Loading from {pth.name}...")
|
||||
weights = load_torch_weights(str(pth))
|
||||
break
|
||||
if weights:
|
||||
weights = sanitize_wan_transformer_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "model.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
else:
|
||||
print(" Warning: No transformer weights found!")
|
||||
|
||||
# Save config — detect model size from source config.json or transformer weights
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
def _detect_config():
|
||||
"""Detect config from source config.json or transformer weight shapes."""
|
||||
if is_dual:
|
||||
# Check source config.json for model_type (I2V vs T2V)
|
||||
src_cfg_path = checkpoint_dir / "high_noise_model" / "config.json"
|
||||
if src_cfg_path.exists():
|
||||
with open(src_cfg_path) as f:
|
||||
src_config = json.load(f)
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
if src_model_type == "i2v" or src_config.get("in_dim") == 36:
|
||||
return WanModelConfig.wan22_i2v_14b()
|
||||
return WanModelConfig.wan22_t2v_14b()
|
||||
|
||||
# Try reading source config.json first (most reliable)
|
||||
src_cfg_path = checkpoint_dir / "config.json"
|
||||
src_config = None
|
||||
if src_cfg_path.exists():
|
||||
with open(src_cfg_path) as f:
|
||||
src_config = json.load(f)
|
||||
|
||||
if src_config and "dim" in src_config:
|
||||
src_dim = src_config.get("dim", 5120)
|
||||
src_in_dim = src_config.get("in_dim", 16)
|
||||
src_out_dim = src_config.get("out_dim", 16)
|
||||
src_ffn_dim = src_config.get("ffn_dim", 13824)
|
||||
src_num_heads = src_config.get("num_heads", 40)
|
||||
src_num_layers = src_config.get("num_layers", 40)
|
||||
src_model_type = src_config.get("model_type", "t2v")
|
||||
src_text_len = src_config.get("text_len", 512)
|
||||
|
||||
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
|
||||
f"heads={src_num_heads}, type={src_model_type}")
|
||||
|
||||
# Use preset for known TI2V 5B configuration
|
||||
if src_model_type == "ti2v" and src_dim == 3072:
|
||||
return WanModelConfig.wan22_ti2v_5b()
|
||||
|
||||
is_22 = model_version == "2.2"
|
||||
|
||||
# Wan2.2 uses different VAE with z_dim=48 and stride (4,16,16)
|
||||
vae_z = 48 if is_22 else 16
|
||||
vae_s = (4, 16, 16) if is_22 else (4, 8, 8)
|
||||
fps = 24 if is_22 else 16
|
||||
|
||||
return WanModelConfig(
|
||||
model_type=src_model_type,
|
||||
model_version=model_version,
|
||||
dim=src_dim,
|
||||
ffn_dim=src_ffn_dim,
|
||||
in_dim=src_in_dim,
|
||||
out_dim=src_out_dim,
|
||||
num_heads=src_num_heads,
|
||||
num_layers=src_num_layers,
|
||||
text_len=src_text_len,
|
||||
vae_z_dim=vae_z,
|
||||
vae_stride=vae_s,
|
||||
dual_model=False,
|
||||
boundary=0.0,
|
||||
sample_shift=5.0,
|
||||
sample_steps=50,
|
||||
sample_guide_scale=5.0,
|
||||
sample_fps=fps,
|
||||
)
|
||||
|
||||
# Fallback: detect from saved transformer weight shapes
|
||||
saved_model = output_dir / "model.safetensors"
|
||||
if saved_model.exists():
|
||||
det_weights = mx.load(str(saved_model))
|
||||
dim = None
|
||||
for k, v in det_weights.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
dim = v.shape[0]
|
||||
break
|
||||
del det_weights
|
||||
if dim is not None and dim <= 2048:
|
||||
print(f" Auto-detected 1.3B model (dim={dim})")
|
||||
return WanModelConfig.wan21_t2v_1_3b()
|
||||
|
||||
return WanModelConfig.wan21_t2v_14b()
|
||||
|
||||
config = _detect_config()
|
||||
config_path = output_dir / "config.json"
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
print(f" Saved config to {config_path}")
|
||||
|
||||
# Convert T5 encoder
|
||||
t5_path = checkpoint_dir / "models_t5_umt5-xxl-enc-bf16.pth"
|
||||
if t5_path.exists():
|
||||
print("Converting T5 encoder...")
|
||||
weights = load_torch_weights(str(t5_path))
|
||||
weights = sanitize_wan_t5_weights(weights)
|
||||
weights = {k: v.astype(target_dtype) for k, v in weights.items()}
|
||||
out_path = output_dir / "t5_encoder.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path}")
|
||||
|
||||
# Convert VAE (check both naming conventions)
|
||||
vae_path = checkpoint_dir / "Wan2.1_VAE.pth"
|
||||
is_wan22_vae = False
|
||||
if not vae_path.exists():
|
||||
vae_path = checkpoint_dir / "Wan2.2_VAE.pth"
|
||||
is_wan22_vae = True
|
||||
if vae_path.exists():
|
||||
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
|
||||
else:
|
||||
weights = sanitize_wan_vae_weights(weights)
|
||||
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
|
||||
# float32 (dtype=torch.float). Saving in bfloat16 loses precision
|
||||
# that cannot be recovered by upcasting at load time.
|
||||
weights = {k: v.astype(mx.float32) for k, v in weights.items()}
|
||||
out_path = output_dir / "vae.safetensors"
|
||||
mx.save_safetensors(str(out_path), weights)
|
||||
print(f" Saved {len(weights)} weight tensors to {out_path} (float32)")
|
||||
|
||||
# Quantize transformer weights if requested
|
||||
if quantize:
|
||||
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
|
||||
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
|
||||
|
||||
print(f"\nConversion complete! Output: {output_dir}")
|
||||
|
||||
|
||||
def _quantize_predicate(path: str, module) -> bool:
|
||||
"""Return True for layers that should be quantized.
|
||||
|
||||
Targets heavyweight Linear layers in attention and FFN blocks.
|
||||
Skips embeddings, norms, head, and modulation (small, precision-sensitive).
|
||||
"""
|
||||
if not hasattr(module, "to_quantized"):
|
||||
return False
|
||||
# Quantize attention Q/K/V/O and FFN fc1/fc2
|
||||
quantize_patterns = (
|
||||
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
|
||||
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
|
||||
".ffn.fc1", ".ffn.fc2",
|
||||
)
|
||||
return any(path.endswith(p) for p in quantize_patterns)
|
||||
|
||||
|
||||
def _quantize_saved_model(
|
||||
output_dir: Path,
|
||||
config,
|
||||
is_dual: bool,
|
||||
bits: int,
|
||||
group_size: int,
|
||||
source_dir: Path = None,
|
||||
):
|
||||
"""Load saved bf16 model, quantize, and re-save.
|
||||
|
||||
Args:
|
||||
output_dir: Directory to write quantized weights to.
|
||||
config: WanModelConfig for creating the model.
|
||||
is_dual: Whether this is a dual-expert model.
|
||||
bits: Quantization bits.
|
||||
group_size: Quantization group size.
|
||||
source_dir: Directory to read bf16 weights from. Defaults to output_dir.
|
||||
"""
|
||||
import json
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
|
||||
if source_dir is None:
|
||||
source_dir = output_dir
|
||||
|
||||
model_names = []
|
||||
if is_dual:
|
||||
for name in ["low_noise_model.safetensors", "high_noise_model.safetensors"]:
|
||||
if (source_dir / name).exists():
|
||||
model_names.append(name)
|
||||
else:
|
||||
if (source_dir / "model.safetensors").exists():
|
||||
model_names.append("model.safetensors")
|
||||
|
||||
for name in model_names:
|
||||
print(f" Quantizing {name}...")
|
||||
model = WanModel(config)
|
||||
weights = mx.load(str(source_dir / name))
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
mx.eval(model.parameters())
|
||||
del weights
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Apply quantization to targeted layers
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=group_size,
|
||||
bits=bits,
|
||||
class_predicate=lambda path, m: _quantize_predicate(path, m),
|
||||
)
|
||||
|
||||
# Save quantized weights
|
||||
weights_dict = dict(mlx.utils.tree_flatten(model.parameters()))
|
||||
|
||||
# Validate: check for NaN/Inf in bias tensors (corruption canary)
|
||||
bad_keys = []
|
||||
for k, v in weights_dict.items():
|
||||
if k.endswith(".bias") and not k.endswith(".biases"):
|
||||
mx.eval(v)
|
||||
if mx.any(mx.isnan(v)).item() or mx.any(mx.isinf(v)).item():
|
||||
bad_keys.append(k)
|
||||
if bad_keys:
|
||||
raise RuntimeError(
|
||||
f"Quantization produced corrupted weights in {model_path.name}: "
|
||||
f"{len(bad_keys)} bias tensors contain NaN/Inf "
|
||||
f"(e.g. {bad_keys[0]}). Try re-running with more available memory."
|
||||
)
|
||||
|
||||
mx.save_safetensors(str(output_dir / name), weights_dict)
|
||||
n_quantized = sum(1 for k in weights_dict if ".scales" in k)
|
||||
print(f" {n_quantized} layers quantized, {len(weights_dict)} tensors saved")
|
||||
|
||||
# Free model before processing next file
|
||||
del model, weights_dict
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Update config.json with quantization metadata
|
||||
config_path = output_dir / "config.json"
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
cfg["quantization"] = {
|
||||
"group_size": group_size,
|
||||
"bits": bits,
|
||||
}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(cfg, f, indent=2)
|
||||
print(f" Updated config.json with quantization metadata")
|
||||
|
||||
|
||||
def quantize_mlx_model(
|
||||
mlx_model_dir: str,
|
||||
output_dir: str,
|
||||
bits: int = 4,
|
||||
group_size: int = 64,
|
||||
):
|
||||
"""Quantize an already-converted MLX model (skips PyTorch conversion).
|
||||
|
||||
Args:
|
||||
mlx_model_dir: Path to existing MLX model directory (bf16/fp16).
|
||||
output_dir: Path to output quantized model directory.
|
||||
bits: Quantization bits (4 or 8).
|
||||
group_size: Quantization group size (32, 64, or 128).
|
||||
"""
|
||||
import json
|
||||
import shutil
|
||||
|
||||
src = Path(mlx_model_dir)
|
||||
dst = Path(output_dir)
|
||||
|
||||
config_path = src / "config.json"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"No config.json found in {src}")
|
||||
|
||||
with open(config_path) as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
if cfg.get("quantization"):
|
||||
raise ValueError(
|
||||
f"Model at {src} is already quantized "
|
||||
f"({cfg['quantization']['bits']}-bit). Use a bf16/fp16 source."
|
||||
)
|
||||
|
||||
# Detect dual vs single expert
|
||||
is_dual = (src / "low_noise_model.safetensors").exists() and (
|
||||
src / "high_noise_model.safetensors"
|
||||
).exists()
|
||||
|
||||
# Build model config
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
|
||||
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(**config_dict)
|
||||
|
||||
# Copy non-transformer files to output dir (skip large model weights)
|
||||
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
|
||||
if dst.resolve() != src.resolve():
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for f in src.iterdir():
|
||||
if f.is_file() and f.name not in transformer_files:
|
||||
shutil.copy2(f, dst / f.name)
|
||||
print(f"Copied non-transformer files from {src} to {dst}")
|
||||
|
||||
print(f"Quantizing transformer weights ({bits}-bit, group_size={group_size})...")
|
||||
_quantize_saved_model(dst, config, is_dual, bits, group_size, source_dir=src)
|
||||
|
||||
print(f"\nQuantization complete! Output: {dst}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Wan model to MLX format")
|
||||
parser.add_argument(
|
||||
"--checkpoint-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to Wan checkpoint directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="wan_mlx_model",
|
||||
help="Output path for MLX model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
choices=["float16", "float32", "bfloat16"],
|
||||
default="bfloat16",
|
||||
help="Target dtype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-version",
|
||||
type=str,
|
||||
choices=["2.1", "2.2", "auto"],
|
||||
default="auto",
|
||||
help="Wan model version (auto-detect by default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize",
|
||||
action="store_true",
|
||||
help="Quantize transformer weights for faster inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize-only",
|
||||
action="store_true",
|
||||
help="Quantize an already-converted MLX model (skips PyTorch conversion)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bits",
|
||||
type=int,
|
||||
choices=[4, 8],
|
||||
default=4,
|
||||
help="Quantization bits (default: 4)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-size",
|
||||
type=int,
|
||||
choices=[32, 64, 128],
|
||||
default=64,
|
||||
help="Quantization group size (default: 64)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.quantize_only:
|
||||
quantize_mlx_model(
|
||||
args.checkpoint_dir, args.output_dir,
|
||||
bits=args.bits, group_size=args.group_size,
|
||||
)
|
||||
else:
|
||||
convert_wan_checkpoint(
|
||||
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
|
||||
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
|
||||
)
|
||||
828
mlx_video/models/wan/generate.py
Normal file
828
mlx_video/models/wan/generate.py
Normal file
@@ -0,0 +1,828 @@
|
||||
"""Wan2.2 Text-to-Video generation pipeline for MLX."""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan.loading import (
|
||||
_clean_text,
|
||||
encode_text,
|
||||
load_t5_encoder,
|
||||
load_vae_decoder,
|
||||
load_vae_encoder,
|
||||
load_wan_model,
|
||||
)
|
||||
from mlx_video.models.wan.postprocess import save_video
|
||||
|
||||
class Colors:
|
||||
"""ANSI color codes for terminal output."""
|
||||
|
||||
CYAN = "\033[96m"
|
||||
BLUE = "\033[94m"
|
||||
GREEN = "\033[92m"
|
||||
YELLOW = "\033[93m"
|
||||
RED = "\033[91m"
|
||||
MAGENTA = "\033[95m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
# Backward-compat alias (tests and external code may use the old name)
|
||||
_build_i2v_mask = build_i2v_mask
|
||||
|
||||
|
||||
def _best_output_size(w, h, dw, dh, max_area):
|
||||
"""Compute the best output resolution that fits within max_area while
|
||||
preserving the input aspect ratio and satisfying alignment constraints.
|
||||
Matches the reference implementation's best_output_size().
|
||||
"""
|
||||
ratio = w / h
|
||||
ow = (max_area * ratio) ** 0.5
|
||||
oh = max_area / ow
|
||||
|
||||
# Option 1: process width first
|
||||
ow1 = int(ow // dw * dw)
|
||||
oh1 = int(max_area / ow1 // dh * dh)
|
||||
ratio1 = ow1 / oh1
|
||||
|
||||
# Option 2: process height first
|
||||
oh2 = int(oh // dh * dh)
|
||||
ow2 = int(max_area / oh2 // dw * dw)
|
||||
ratio2 = ow2 / oh2
|
||||
|
||||
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
|
||||
return ow1, oh1
|
||||
return ow2, oh2
|
||||
|
||||
|
||||
def generate_video(
|
||||
model_dir: str,
|
||||
prompt: str,
|
||||
negative_prompt: str | None = None,
|
||||
image: str | None = None,
|
||||
width: int = 1280,
|
||||
height: int = 704,
|
||||
num_frames: int = 81,
|
||||
steps: int = None,
|
||||
guide_scale: str | float | tuple = None,
|
||||
shift: float = None,
|
||||
seed: int = -1,
|
||||
output_path: str = "output.mp4",
|
||||
scheduler: str = "unipc",
|
||||
loras: list | None = None,
|
||||
loras_high: list | None = None,
|
||||
loras_low: list | None = None,
|
||||
tiling: str = "auto",
|
||||
no_compile: bool = False,
|
||||
trim_first_frames: int = 0,
|
||||
debug_latents: bool = False,
|
||||
):
|
||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||
|
||||
Args:
|
||||
model_dir: Path to converted MLX model directory
|
||||
prompt: Text prompt
|
||||
negative_prompt: Negative prompt (None = use config default, "" = no negative prompt)
|
||||
image: Path to input image for I2V (None = T2V mode)
|
||||
width: Video width
|
||||
height: Video height
|
||||
num_frames: Number of frames (must be 4n+1)
|
||||
steps: Number of diffusion steps (None = use config default)
|
||||
guide_scale: Guidance scale: float for single, (low,high) for dual (None = config default)
|
||||
shift: Noise schedule shift (None = use config default)
|
||||
seed: Random seed (-1 for random)
|
||||
output_path: Output video path
|
||||
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
|
||||
loras: Optional list of (path, strength) tuples applied to all models
|
||||
loras_high: Optional list of (path, strength) tuples for high-noise model only
|
||||
loras_low: Optional list of (path, strength) tuples for low-noise model only
|
||||
tiling: Tiling mode for VAE decoding. Options:
|
||||
- "auto": Automatically determine tiling based on video size (default)
|
||||
- "none": Disable tiling
|
||||
- "default", "aggressive", "conservative": Preset tiling configs
|
||||
- "spatial": Spatial tiling only
|
||||
- "temporal": Temporal tiling only
|
||||
no_compile: If True, skip mx.compile on models (useful for debugging)
|
||||
trim_first_frames: Number of temporal latent positions to generate extra
|
||||
and discard from the start. Each position = 4 pixel frames. Use 1
|
||||
to fix first-frame artifacts on 14B models (generates 4 extra frames,
|
||||
discards first 4). Use 2 for more aggressive trimming. Default: 0.
|
||||
debug_latents: If True, print per-temporal-position latent statistics
|
||||
after denoising for diagnosing first-frame artifacts.
|
||||
"""
|
||||
import json
|
||||
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan.scheduler import (
|
||||
FlowDPMPP2MScheduler,
|
||||
FlowMatchEulerScheduler,
|
||||
FlowUniPCScheduler,
|
||||
)
|
||||
|
||||
model_dir = Path(model_dir)
|
||||
|
||||
# Load config from model dir if available, otherwise auto-detect
|
||||
config_path = model_dir / "config.json"
|
||||
quantization = None
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
# Extract quantization config (not a model config field)
|
||||
quantization = config_dict.pop("quantization", None)
|
||||
# Handle tuple fields stored as lists in JSON
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(**{
|
||||
k: v for k, v in config_dict.items()
|
||||
if k in WanModelConfig.__dataclass_fields__
|
||||
})
|
||||
else:
|
||||
# Auto-detect: dual model files → 2.2, single model → 2.1
|
||||
if (model_dir / "low_noise_model.safetensors").exists():
|
||||
config = WanModelConfig.wan22_t2v_14b()
|
||||
else:
|
||||
# Detect 1.3B vs 14B from weight shapes
|
||||
model_path = model_dir / "model.safetensors"
|
||||
if model_path.exists():
|
||||
probe = mx.load(str(model_path), return_metadata=False)
|
||||
for k, v in probe.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
dim = v.shape[0]
|
||||
if dim <= 2048:
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
break
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
del probe
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
|
||||
is_dual = config.dual_model
|
||||
is_i2v = image is not None
|
||||
|
||||
# Validate config against actual weights (handles mismatched config.json)
|
||||
if not is_dual:
|
||||
model_path = model_dir / "model.safetensors"
|
||||
if model_path.exists():
|
||||
probe = mx.load(str(model_path), return_metadata=False)
|
||||
for k, v in probe.items():
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
actual_dim = v.shape[0]
|
||||
if actual_dim != config.dim:
|
||||
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
|
||||
if actual_dim <= 2048:
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
else:
|
||||
config = WanModelConfig.wan21_t2v_14b()
|
||||
break
|
||||
del probe
|
||||
|
||||
# Auto-correct Wan2.2 VAE params from stale configs
|
||||
if config.in_dim == 48 and config.vae_z_dim != 48:
|
||||
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
|
||||
config = WanModelConfig(**{
|
||||
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
|
||||
"vae_z_dim": 48,
|
||||
"vae_stride": (4, 16, 16),
|
||||
"sample_fps": 24,
|
||||
})
|
||||
|
||||
# Apply defaults from config if not overridden
|
||||
if steps is None:
|
||||
steps = config.sample_steps
|
||||
if shift is None:
|
||||
shift = config.sample_shift
|
||||
if guide_scale is None:
|
||||
guide_scale = config.sample_guide_scale
|
||||
|
||||
# Normalize guide_scale
|
||||
if isinstance(guide_scale, (int, float)):
|
||||
guide_scale = float(guide_scale)
|
||||
elif isinstance(guide_scale, str):
|
||||
parts = [float(x) for x in guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
|
||||
if isinstance(guide_scale, tuple):
|
||||
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
|
||||
else:
|
||||
cfg_disabled = guide_scale <= 1.0
|
||||
|
||||
# Validate frame count
|
||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||
|
||||
gen_frames = num_frames
|
||||
if trim_first_frames > 0:
|
||||
gen_frames = num_frames + trim_first_frames * 4
|
||||
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
pipeline_str = "Image-to-Video" if is_i2v else "Text-to-Video"
|
||||
# Resolve negative prompt: explicit user value > config default
|
||||
# The official Wan2.2 uses a Chinese negative prompt (config.sample_neg_prompt)
|
||||
# that prevents oversaturation, artifacts, and comic look. We use it by default.
|
||||
# Text cleaning (_clean_text) normalizes fullwidth chars to match official tokenization.
|
||||
if negative_prompt is None:
|
||||
neg_prompt_resolved = config.sample_neg_prompt
|
||||
else:
|
||||
neg_prompt_resolved = negative_prompt
|
||||
print(f"{Colors.CYAN}{'='*60}")
|
||||
print(f" {version_str} {pipeline_str} Generation (MLX, {mode_str})")
|
||||
print(f"{'='*60}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Prompt: {prompt}")
|
||||
if is_i2v:
|
||||
print(f" Image: {image}")
|
||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
||||
if cfg_disabled:
|
||||
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
||||
print(f"{Colors.RESET}")
|
||||
|
||||
# Seed
|
||||
if seed < 0:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
mx.random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
print(f"{Colors.DIM} Seed: {seed}{Colors.RESET}")
|
||||
|
||||
# Align dimensions to patch_size * vae_stride (required for patchify)
|
||||
vae_stride = config.vae_stride
|
||||
patch_size = config.patch_size
|
||||
align_h = patch_size[1] * vae_stride[1] # e.g. 2*16=32
|
||||
align_w = patch_size[2] * vae_stride[2]
|
||||
if height % align_h != 0 or width % align_w != 0:
|
||||
old_h, old_w = height, width
|
||||
height = (height // align_h) * align_h
|
||||
width = (width // align_w) * align_w
|
||||
if height == 0:
|
||||
height = align_h
|
||||
if width == 0:
|
||||
width = align_w
|
||||
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
||||
|
||||
# Enforce max_area constraint (model-specific resolution limit)
|
||||
if config.max_area > 0 and height * width > config.max_area:
|
||||
old_h, old_w = height, width
|
||||
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
|
||||
print(
|
||||
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
||||
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Compute target latent shape
|
||||
z_dim = config.vae_z_dim
|
||||
t_latent = (gen_frames - 1) // vae_stride[0] + 1
|
||||
h_latent = height // vae_stride[1]
|
||||
w_latent = width // vae_stride[2]
|
||||
target_shape = (z_dim, t_latent, h_latent, w_latent)
|
||||
|
||||
# Sequence length for transformer
|
||||
seq_len = math.ceil(
|
||||
(h_latent * w_latent) / (patch_size[1] * patch_size[2]) * t_latent
|
||||
)
|
||||
|
||||
print(f"{Colors.DIM} Latent shape: {target_shape}")
|
||||
print(f" Sequence length: {seq_len}{Colors.RESET}")
|
||||
|
||||
# Load T5 encoder
|
||||
t1 = time.time()
|
||||
print(f"\n{Colors.BLUE}Loading T5 encoder...{Colors.RESET}")
|
||||
t5_path = model_dir / "t5_encoder.safetensors"
|
||||
t5_encoder = load_t5_encoder(t5_path, config)
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
|
||||
# Encode prompts
|
||||
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
|
||||
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
|
||||
if cfg_disabled:
|
||||
context_null = None
|
||||
mx.eval(context)
|
||||
else:
|
||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
||||
mx.eval(context, context_null)
|
||||
|
||||
# Free T5 from memory
|
||||
del t5_encoder
|
||||
gc.collect(); mx.clear_cache()
|
||||
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
# I2V: encode image to latent space
|
||||
z_img = None
|
||||
i2v_mask = None
|
||||
i2v_mask_tokens = None
|
||||
y_i2v = None
|
||||
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
|
||||
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
|
||||
if is_i2v:
|
||||
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
|
||||
t_img = time.time()
|
||||
|
||||
vae_path = model_dir / "vae.safetensors"
|
||||
|
||||
if is_i2v_channel_concat:
|
||||
# I2V-14B: encode full video (first frame = image, rest = zeros)
|
||||
# and construct y tensor with mask + encoded latents
|
||||
from PIL import Image
|
||||
|
||||
img = Image.open(image).convert("RGB")
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
||||
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
|
||||
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
|
||||
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
|
||||
video = mx.concatenate([
|
||||
img_chw[:, None, :, :],
|
||||
mx.zeros((3, num_frames - 1, height, width)),
|
||||
], axis=1)
|
||||
|
||||
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
|
||||
mx.eval(z_video)
|
||||
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
|
||||
|
||||
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
|
||||
msk = mx.concatenate([
|
||||
mx.repeat(msk[:, :1], 4, axis=1),
|
||||
msk[:, 1:],
|
||||
], axis=1)
|
||||
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
|
||||
y_i2v = mx.concatenate([msk, z_video], axis=0)
|
||||
mx.eval(y_i2v)
|
||||
|
||||
del vae_enc, img_arr, img_chw, video, z_video, msk
|
||||
else:
|
||||
# TI2V-5B: encode single image, blend with noise via mask
|
||||
img_tensor = preprocess_image(image, width, height)
|
||||
mx.eval(img_tensor)
|
||||
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||
mx.eval(z_img)
|
||||
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
|
||||
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||
|
||||
del vae_enc, img_tensor
|
||||
|
||||
gc.collect(); mx.clear_cache()
|
||||
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||
|
||||
# Load transformer models
|
||||
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
||||
if quantization:
|
||||
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
|
||||
t2 = time.time()
|
||||
|
||||
# Merge per-model LoRAs with shared LoRAs
|
||||
_loras_low = (loras or []) + (loras_low or []) or None
|
||||
_loras_high = (loras or []) + (loras_high or []) or None
|
||||
_loras_single = loras
|
||||
|
||||
if is_dual:
|
||||
low_noise_path = model_dir / "low_noise_model.safetensors"
|
||||
high_noise_path = model_dir / "high_noise_model.safetensors"
|
||||
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
|
||||
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
|
||||
else:
|
||||
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
|
||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||
|
||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||
# Each model has its own text_embedding weights, so dual models need separate embeddings
|
||||
if cfg_disabled:
|
||||
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context])
|
||||
context_emb_high = high_noise_model.embed_text([context])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cond_low = context_emb_low[0:1]
|
||||
context_cond_high = context_emb_high[0:1]
|
||||
else:
|
||||
context_emb = single_model.embed_text([context])
|
||||
mx.eval(context_emb)
|
||||
context_cond = context_emb[0:1]
|
||||
else:
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
|
||||
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
|
||||
else:
|
||||
context_emb = single_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
|
||||
|
||||
# Precompute cross-attention K/V caches (constant across all steps)
|
||||
if cfg_disabled:
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cond)
|
||||
mx.eval(cross_kv)
|
||||
else:
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cfg)
|
||||
mx.eval(cross_kv)
|
||||
|
||||
# Precompute RoPE frequencies (grid sizes are constant across all steps)
|
||||
f_grid = t_latent // patch_size[0]
|
||||
h_grid = h_latent // patch_size[1]
|
||||
w_grid = w_latent // patch_size[2]
|
||||
if cfg_disabled:
|
||||
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
|
||||
else:
|
||||
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
|
||||
if is_dual:
|
||||
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
|
||||
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
||||
else:
|
||||
rope_cos_sin = single_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin)
|
||||
|
||||
# Setup scheduler
|
||||
_schedulers = {
|
||||
"euler": FlowMatchEulerScheduler,
|
||||
"dpm++": FlowDPMPP2MScheduler,
|
||||
"unipc": FlowUniPCScheduler,
|
||||
}
|
||||
sched_cls = _schedulers.get(scheduler, FlowUniPCScheduler)
|
||||
sched = sched_cls(num_train_timesteps=config.num_train_timesteps)
|
||||
sched.set_timesteps(steps, shift=shift)
|
||||
|
||||
# Generate initial noise
|
||||
noise = mx.random.normal(target_shape)
|
||||
|
||||
# I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
|
||||
if is_i2v_mask_blend:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
# Boundary for model switching (dual model only)
|
||||
boundary = (config.boundary * config.num_train_timesteps) if is_dual else None
|
||||
|
||||
# Diffusion loop
|
||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||
t3 = time.time()
|
||||
|
||||
# Compile model forward for faster denoising
|
||||
if not no_compile:
|
||||
models_to_compile = (
|
||||
[high_noise_model, low_noise_model] if is_dual else [single_model]
|
||||
)
|
||||
for m in models_to_compile:
|
||||
m._compiled = mx.compile(m)
|
||||
|
||||
# Pre-convert timesteps to Python list to avoid .item() sync each step
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
# Select model, cached K/V, and precomputed RoPE
|
||||
if is_dual:
|
||||
if timestep_val >= boundary:
|
||||
model = high_noise_model
|
||||
kv = cross_kv_high
|
||||
rcs = rope_cos_sin_high
|
||||
else:
|
||||
model = low_noise_model
|
||||
kv = cross_kv_low
|
||||
rcs = rope_cos_sin_low
|
||||
else:
|
||||
model = single_model
|
||||
kv = cross_kv
|
||||
rcs = rope_cos_sin
|
||||
|
||||
# Use compiled forward when available (faster after first trace)
|
||||
_call = getattr(model, '_compiled', model)
|
||||
|
||||
if cfg_disabled:
|
||||
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
t_batch = t_tokens # [1, L]
|
||||
else:
|
||||
t_batch = mx.array([timestep_val])
|
||||
|
||||
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
if is_dual:
|
||||
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
|
||||
else:
|
||||
ctx = context_cond
|
||||
preds = _call(
|
||||
[latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred = preds[0]
|
||||
del preds
|
||||
else:
|
||||
# CFG: batch cond + uncond into single B=2 forward pass
|
||||
if is_dual:
|
||||
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
||||
else:
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
|
||||
else:
|
||||
t_batch = mx.array([timestep_val, timestep_val])
|
||||
|
||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
ctx = context_cfg if not is_dual else (
|
||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
||||
)
|
||||
preds = _call(
|
||||
[latents, latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
del noise_pred_cond, noise_pred_uncond, preds
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
|
||||
# TI2V-5B: re-apply mask to keep first frame frozen
|
||||
if is_i2v_mask_blend:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
||||
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
del noise_pred
|
||||
mx.eval(latents)
|
||||
|
||||
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
|
||||
|
||||
# Diagnostic: per-temporal-position latent statistics
|
||||
if debug_latents:
|
||||
lat_np = np.array(latents) # [C, T, H, W]
|
||||
n_t = lat_np.shape[1]
|
||||
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
|
||||
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
|
||||
for t_pos in range(min(n_t, 8)):
|
||||
frame = lat_np[:, t_pos, :, :]
|
||||
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
|
||||
if n_t > 8:
|
||||
interior = lat_np[:, 4:, :, :]
|
||||
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
|
||||
print()
|
||||
|
||||
# Free transformer models and text embeddings
|
||||
if is_dual:
|
||||
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
|
||||
if cfg_disabled:
|
||||
del context_cond_low, context_cond_high
|
||||
else:
|
||||
del context_cfg_low, context_cfg_high
|
||||
else:
|
||||
del single_model, cross_kv
|
||||
if cfg_disabled:
|
||||
del context_cond
|
||||
else:
|
||||
del context_cfg
|
||||
del model, kv, context
|
||||
if context_null is not None:
|
||||
del context_null
|
||||
gc.collect(); mx.clear_cache()
|
||||
|
||||
# Load VAE and decode
|
||||
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
|
||||
t4 = time.time()
|
||||
vae_path = model_dir / "vae.safetensors"
|
||||
vae = load_vae_decoder(vae_path, config)
|
||||
|
||||
is_wan22_vae = config.vae_z_dim == 48
|
||||
|
||||
# Temporal extend: prepend reflected latent frames to the VAE input so that
|
||||
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
|
||||
# This gives the first real frame a full temporal receptive field of real data.
|
||||
# Select tiling configuration
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||
|
||||
if tiling == "none":
|
||||
tiling_config = None
|
||||
elif tiling == "auto":
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
elif tiling == "default":
|
||||
tiling_config = TilingConfig.default()
|
||||
elif tiling == "aggressive":
|
||||
tiling_config = TilingConfig.aggressive()
|
||||
elif tiling == "conservative":
|
||||
tiling_config = TilingConfig.conservative()
|
||||
elif tiling == "spatial":
|
||||
tiling_config = TilingConfig.spatial_only()
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
|
||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||
z = latents.transpose(1, 2, 3, 0)[None]
|
||||
z = denormalize_latents(z)
|
||||
if tiling_config is not None:
|
||||
video = vae.decode_tiled(z, tiling_config)
|
||||
else:
|
||||
video = vae(z)
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [T', H', W', 3]
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
else:
|
||||
if tiling_config is not None:
|
||||
video = vae.decode_tiled(latents[None], tiling_config)
|
||||
else:
|
||||
video = vae.decode(latents[None])
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [3, T', H, W]
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
||||
|
||||
# Trim first N temporal chunks if requested (avoids first-frame artifacts)
|
||||
if trim_first_frames > 0:
|
||||
trim_pixels = trim_first_frames * 4
|
||||
video = video[trim_pixels:]
|
||||
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
|
||||
|
||||
save_video(video, output_path, fps=config.sample_fps)
|
||||
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument("--image", type=str, default=None,
|
||||
help="Path to input image for I2V (omit for T2V mode)")
|
||||
parser.add_argument("--negative-prompt", type=str, default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
|
||||
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
|
||||
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
||||
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
|
||||
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
|
||||
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
|
||||
parser.add_argument(
|
||||
"--scheduler", type=str, default="unipc",
|
||||
choices=["euler", "dpm++", "unipc"],
|
||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiling",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-compile", action="store_true",
|
||||
help="Disable mx.compile on models (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trim-first-frames", type=int, default=0, metavar="N",
|
||||
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
|
||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||
"Default: 0 (disabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-latents", action="store_true",
|
||||
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse guide scale
|
||||
guide_scale = None
|
||||
if args.guide_scale is not None:
|
||||
parts = [float(x) for x in args.guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Handle negative prompt: --no-negative-prompt forces empty, otherwise pass through
|
||||
neg_prompt = args.negative_prompt
|
||||
if args.no_negative_prompt:
|
||||
neg_prompt = ""
|
||||
|
||||
# Parse LoRA configs: convert [path, strength_str] → (path, float)
|
||||
def _parse_lora_args(lora_list):
|
||||
if not lora_list:
|
||||
return None
|
||||
return [(path, float(strength)) for path, strength in lora_list]
|
||||
|
||||
generate_video(
|
||||
model_dir=args.model_dir,
|
||||
prompt=args.prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
image=args.image,
|
||||
width=args.width,
|
||||
height=args.height,
|
||||
num_frames=args.num_frames,
|
||||
steps=args.steps,
|
||||
guide_scale=guide_scale,
|
||||
shift=args.shift,
|
||||
seed=args.seed,
|
||||
output_path=args.output_path,
|
||||
scheduler=args.scheduler,
|
||||
loras=_parse_lora_args(args.lora),
|
||||
loras_high=_parse_lora_args(args.lora_high),
|
||||
loras_low=_parse_lora_args(args.lora_low),
|
||||
tiling=args.tiling,
|
||||
no_compile=args.no_compile,
|
||||
trim_first_frames=args.trim_first_frames,
|
||||
debug_latents=args.debug_latents,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user