format
This commit is contained in:
@@ -4,18 +4,15 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
|
||||
from mlx_video.models.wan.loading import (
|
||||
_clean_text,
|
||||
encode_text,
|
||||
load_t5_encoder,
|
||||
load_vae_decoder,
|
||||
@@ -24,6 +21,7 @@ from mlx_video.models.wan.loading import (
|
||||
)
|
||||
from mlx_video.models.wan.postprocess import save_video
|
||||
|
||||
|
||||
class Colors:
|
||||
"""ANSI color codes for terminal output."""
|
||||
|
||||
@@ -37,6 +35,7 @@ class Colors:
|
||||
DIM = "\033[2m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
|
||||
# Backward-compat alias (tests and external code may use the old name)
|
||||
_build_i2v_mask = build_i2v_mask
|
||||
|
||||
@@ -143,10 +142,13 @@ def generate_video(
|
||||
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
|
||||
if key in config_dict and isinstance(config_dict[key], list):
|
||||
config_dict[key] = tuple(config_dict[key])
|
||||
config = WanModelConfig(**{
|
||||
k: v for k, v in config_dict.items()
|
||||
if k in WanModelConfig.__dataclass_fields__
|
||||
})
|
||||
config = WanModelConfig(
|
||||
**{
|
||||
k: v
|
||||
for k, v in config_dict.items()
|
||||
if k in WanModelConfig.__dataclass_fields__
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Auto-detect: dual model files → 2.2, single model → 2.1
|
||||
if (model_dir / "low_noise_model.safetensors").exists():
|
||||
@@ -182,7 +184,9 @@ def generate_video(
|
||||
if "patch_embedding_proj.weight" in k:
|
||||
actual_dim = v.shape[0]
|
||||
if actual_dim != config.dim:
|
||||
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}"
|
||||
)
|
||||
if actual_dim <= 2048:
|
||||
config = WanModelConfig.wan21_t2v_1_3b()
|
||||
else:
|
||||
@@ -192,13 +196,20 @@ def generate_video(
|
||||
|
||||
# Auto-correct Wan2.2 VAE params from stale configs
|
||||
if config.in_dim == 48 and config.vae_z_dim != 48:
|
||||
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
|
||||
config = WanModelConfig(**{
|
||||
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
|
||||
"vae_z_dim": 48,
|
||||
"vae_stride": (4, 16, 16),
|
||||
"sample_fps": 24,
|
||||
})
|
||||
print(
|
||||
f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}"
|
||||
)
|
||||
config = WanModelConfig(
|
||||
**{
|
||||
**{
|
||||
f.name: getattr(config, f.name)
|
||||
for f in config.__dataclass_fields__.values()
|
||||
},
|
||||
"vae_z_dim": 48,
|
||||
"vae_stride": (4, 16, 16),
|
||||
"sample_fps": 24,
|
||||
}
|
||||
)
|
||||
|
||||
# Apply defaults from config if not overridden
|
||||
if steps is None:
|
||||
@@ -227,7 +238,9 @@ def generate_video(
|
||||
gen_frames = num_frames
|
||||
if trim_first_frames > 0:
|
||||
gen_frames = num_frames + trim_first_frames * 4
|
||||
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}"
|
||||
)
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
@@ -247,10 +260,16 @@ def generate_video(
|
||||
if is_i2v:
|
||||
print(f" Image: {image}")
|
||||
if neg_prompt_resolved and neg_prompt_resolved.strip():
|
||||
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
|
||||
neg_display = (
|
||||
neg_prompt_resolved[:60] + "..."
|
||||
if len(neg_prompt_resolved) > 60
|
||||
else neg_prompt_resolved
|
||||
)
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
||||
print(
|
||||
f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}"
|
||||
)
|
||||
if cfg_disabled:
|
||||
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
||||
print(f"{Colors.RESET}")
|
||||
@@ -275,12 +294,16 @@ def generate_video(
|
||||
height = align_h
|
||||
if width == 0:
|
||||
width = align_w
|
||||
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}"
|
||||
)
|
||||
|
||||
# Enforce max_area constraint (model-specific resolution limit)
|
||||
if config.max_area > 0 and height * width > config.max_area:
|
||||
old_h, old_w = height, width
|
||||
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
|
||||
width, height = _best_output_size(
|
||||
width, height, align_w, align_h, config.max_area
|
||||
)
|
||||
print(
|
||||
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
||||
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
||||
@@ -309,6 +332,7 @@ def generate_video(
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
|
||||
# Encode prompts
|
||||
@@ -318,12 +342,15 @@ def generate_video(
|
||||
context_null = None
|
||||
mx.eval(context)
|
||||
else:
|
||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
||||
context_null = encode_text(
|
||||
t5_encoder, tokenizer, neg_prompt_resolved, config.text_len
|
||||
)
|
||||
mx.eval(context, context_null)
|
||||
|
||||
# Free T5 from memory
|
||||
del t5_encoder
|
||||
gc.collect(); mx.clear_cache()
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
|
||||
# I2V: encode image to latent space
|
||||
@@ -346,18 +373,25 @@ def generate_video(
|
||||
|
||||
img = Image.open(image).convert("RGB")
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
||||
img = img.resize(
|
||||
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
|
||||
)
|
||||
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
|
||||
img_arr = mx.array(
|
||||
np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0
|
||||
) # [H, W, 3]
|
||||
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
|
||||
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
|
||||
video = mx.concatenate([
|
||||
img_chw[:, None, :, :],
|
||||
mx.zeros((3, num_frames - 1, height, width)),
|
||||
], axis=1)
|
||||
video = mx.concatenate(
|
||||
[
|
||||
img_chw[:, None, :, :],
|
||||
mx.zeros((3, num_frames - 1, height, width)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
@@ -367,12 +401,17 @@ def generate_video(
|
||||
|
||||
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
|
||||
)
|
||||
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
|
||||
msk = mx.concatenate([
|
||||
mx.repeat(msk[:, :1], 4, axis=1),
|
||||
msk[:, 1:],
|
||||
], axis=1)
|
||||
msk = mx.concatenate(
|
||||
[
|
||||
mx.repeat(msk[:, :1], 4, axis=1),
|
||||
msk[:, 1:],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
@@ -395,13 +434,16 @@ def generate_video(
|
||||
|
||||
del vae_enc, img_tensor
|
||||
|
||||
gc.collect(); mx.clear_cache()
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||
|
||||
# Load transformer models
|
||||
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
|
||||
if quantization:
|
||||
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}"
|
||||
)
|
||||
t2 = time.time()
|
||||
|
||||
# Merge per-model LoRAs with shared LoRAs
|
||||
@@ -412,10 +454,16 @@ def generate_video(
|
||||
if is_dual:
|
||||
low_noise_path = model_dir / "low_noise_model.safetensors"
|
||||
high_noise_path = model_dir / "high_noise_model.safetensors"
|
||||
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
|
||||
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
|
||||
low_noise_model = load_wan_model(
|
||||
low_noise_path, config, quantization, loras=_loras_low
|
||||
)
|
||||
high_noise_model = load_wan_model(
|
||||
high_noise_path, config, quantization, loras=_loras_high
|
||||
)
|
||||
else:
|
||||
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
|
||||
single_model = load_wan_model(
|
||||
model_dir / "model.safetensors", config, quantization, loras=_loras_single
|
||||
)
|
||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||
|
||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||
@@ -437,8 +485,12 @@ def generate_video(
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
|
||||
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
|
||||
context_cfg_low = mx.concatenate(
|
||||
[context_emb_low[0:1], context_emb_low[1:2]], axis=0
|
||||
)
|
||||
context_cfg_high = mx.concatenate(
|
||||
[context_emb_high[0:1], context_emb_high[1:2]], axis=0
|
||||
)
|
||||
else:
|
||||
context_emb = single_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
@@ -534,7 +586,7 @@ def generate_video(
|
||||
rcs = rope_cos_sin
|
||||
|
||||
# Use compiled forward when available (faster after first trace)
|
||||
_call = getattr(model, '_compiled', model)
|
||||
_call = getattr(model, "_compiled", model)
|
||||
|
||||
if cfg_disabled:
|
||||
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||
@@ -552,7 +604,9 @@ def generate_video(
|
||||
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
if is_dual:
|
||||
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
|
||||
ctx = (
|
||||
context_cond_high if timestep_val >= boundary else context_cond_low
|
||||
)
|
||||
else:
|
||||
ctx = context_cond
|
||||
preds = _call(
|
||||
@@ -571,7 +625,11 @@ def generate_video(
|
||||
if is_dual:
|
||||
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
||||
else:
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
gs = (
|
||||
guide_scale
|
||||
if isinstance(guide_scale, (int, float))
|
||||
else guide_scale[0]
|
||||
)
|
||||
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
@@ -586,8 +644,10 @@ def generate_video(
|
||||
|
||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
ctx = context_cfg if not is_dual else (
|
||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
||||
ctx = (
|
||||
context_cfg
|
||||
if not is_dual
|
||||
else (context_cfg_high if timestep_val >= boundary else context_cfg_low)
|
||||
)
|
||||
preds = _call(
|
||||
[latents, latents],
|
||||
@@ -618,16 +678,24 @@ def generate_video(
|
||||
if debug_latents:
|
||||
lat_np = np.array(latents) # [C, T, H, W]
|
||||
n_t = lat_np.shape[1]
|
||||
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
|
||||
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
|
||||
print(
|
||||
f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}"
|
||||
)
|
||||
print(
|
||||
f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}"
|
||||
)
|
||||
for t_pos in range(min(n_t, 8)):
|
||||
frame = lat_np[:, t_pos, :, :]
|
||||
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
|
||||
print(
|
||||
f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}"
|
||||
)
|
||||
if n_t > 8:
|
||||
interior = lat_np[:, 4:, :, :]
|
||||
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
|
||||
print(
|
||||
f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Free transformer models and text embeddings
|
||||
@@ -646,7 +714,8 @@ def generate_video(
|
||||
del model, kv, context
|
||||
if context_null is not None:
|
||||
del context_null
|
||||
gc.collect(); mx.clear_cache()
|
||||
gc.collect()
|
||||
mx.clear_cache()
|
||||
|
||||
# Load VAE and decode
|
||||
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
|
||||
@@ -677,13 +746,25 @@ def generate_video(
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}"
|
||||
)
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
spatial_info = (
|
||||
f"{tiling_config.spatial_config.tile_size_in_pixels}px"
|
||||
if tiling_config.spatial_config
|
||||
else "none"
|
||||
)
|
||||
temporal_info = (
|
||||
f"{tiling_config.temporal_config.tile_size_in_frames}f"
|
||||
if tiling_config.temporal_config
|
||||
else "none"
|
||||
)
|
||||
print(
|
||||
f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}"
|
||||
)
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
@@ -718,7 +799,9 @@ def generate_video(
|
||||
if trim_first_frames > 0:
|
||||
trim_pixels = trim_first_frames * 4
|
||||
video = video[trim_pixels:]
|
||||
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
|
||||
print(
|
||||
f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}"
|
||||
)
|
||||
|
||||
save_video(video, output_path, fps=config.sample_fps)
|
||||
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
|
||||
@@ -727,58 +810,124 @@ def generate_video(
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
|
||||
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument("--image", type=str, default=None,
|
||||
help="Path to input image for I2V (omit for T2V mode)")
|
||||
parser.add_argument("--negative-prompt", type=str, default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
|
||||
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
|
||||
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
||||
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
|
||||
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
|
||||
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
|
||||
parser.add_argument(
|
||||
"--scheduler", type=str, default="unipc",
|
||||
"--model-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to converted MLX model directory",
|
||||
)
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to input image for I2V (omit for T2V mode)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--negative-prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-negative-prompt",
|
||||
action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width", type=int, default=1280, help="Video width (default: 1280)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=704,
|
||||
help="Video height (default: 704; 720p models use 704)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of diffusion steps (default: from config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guide-scale",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Guidance scale: single float or low,high pair",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Noise schedule shift (default: from config)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--output-path", type=str, default="output.mp4", help="Output video path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler",
|
||||
type=str,
|
||||
default="unipc",
|
||||
choices=["euler", "dpm++", "unipc"],
|
||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
"--lora",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
"--lora-high",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
"--lora-low",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tiling",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||
choices=[
|
||||
"auto",
|
||||
"none",
|
||||
"default",
|
||||
"aggressive",
|
||||
"conservative",
|
||||
"spatial",
|
||||
"temporal",
|
||||
],
|
||||
help="VAE tiling mode to reduce memory during decoding (default: auto)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-compile", action="store_true",
|
||||
"--no-compile",
|
||||
action="store_true",
|
||||
help="Disable mx.compile on models (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trim-first-frames", type=int, default=0, metavar="N",
|
||||
"--trim-first-frames",
|
||||
type=int,
|
||||
default=0,
|
||||
metavar="N",
|
||||
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
|
||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||
"Default: 0 (disabled)",
|
||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||
"Default: 0 (disabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-latents", action="store_true",
|
||||
"--debug-latents",
|
||||
action="store_true",
|
||||
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user