feat(wan): Add LoRA with improved quantization pipeline
This commit is contained in:
@@ -43,6 +43,10 @@ def generate_video(
|
||||
seed: int = -1,
|
||||
output_path: str = "output.mp4",
|
||||
scheduler: str = "unipc",
|
||||
loras: list | None = None,
|
||||
loras_high: list | None = None,
|
||||
loras_low: list | None = None,
|
||||
|
||||
):
|
||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||
|
||||
@@ -60,6 +64,10 @@ def generate_video(
|
||||
seed: Random seed (-1 for random)
|
||||
output_path: Output video path
|
||||
scheduler: Solver type: 'euler', 'dpm++', or 'unipc' (default)
|
||||
loras: Optional list of (path, strength) tuples applied to all models
|
||||
loras_high: Optional list of (path, strength) tuples for high-noise model only
|
||||
loras_low: Optional list of (path, strength) tuples for low-noise model only
|
||||
|
||||
"""
|
||||
import json
|
||||
|
||||
@@ -156,6 +164,12 @@ def generate_video(
|
||||
parts = [float(x) for x in guide_scale.split(",")]
|
||||
guide_scale = tuple(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# Detect CFG-disabled mode (guide_scale=1.0 for all models → skip uncond pass for 2x speedup)
|
||||
if isinstance(guide_scale, tuple):
|
||||
cfg_disabled = all(gs <= 1.0 for gs in guide_scale)
|
||||
else:
|
||||
cfg_disabled = guide_scale <= 1.0
|
||||
|
||||
# Validate frame count
|
||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||
|
||||
@@ -181,6 +195,8 @@ def generate_video(
|
||||
print(f" Neg prompt: {neg_display}")
|
||||
print(f" Size: {width}x{height}, Frames: {num_frames}")
|
||||
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
|
||||
if cfg_disabled:
|
||||
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
|
||||
print(f"{Colors.RESET}")
|
||||
|
||||
# Seed
|
||||
@@ -233,8 +249,12 @@ def generate_video(
|
||||
# Encode prompts
|
||||
print(f"{Colors.BLUE}Encoding text...{Colors.RESET}")
|
||||
context = encode_text(t5_encoder, tokenizer, prompt, config.text_len)
|
||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
||||
mx.eval(context, context_null)
|
||||
if cfg_disabled:
|
||||
context_null = None
|
||||
mx.eval(context)
|
||||
else:
|
||||
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
|
||||
mx.eval(context, context_null)
|
||||
|
||||
# Free T5 from memory
|
||||
del t5_encoder
|
||||
@@ -319,48 +339,78 @@ def generate_video(
|
||||
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
|
||||
t2 = time.time()
|
||||
|
||||
# Merge per-model LoRAs with shared LoRAs
|
||||
_loras_low = (loras or []) + (loras_low or []) or None
|
||||
_loras_high = (loras or []) + (loras_high or []) or None
|
||||
_loras_single = loras
|
||||
|
||||
if is_dual:
|
||||
low_noise_path = model_dir / "low_noise_model.safetensors"
|
||||
high_noise_path = model_dir / "high_noise_model.safetensors"
|
||||
low_noise_model = load_wan_model(low_noise_path, config, quantization)
|
||||
high_noise_model = load_wan_model(high_noise_path, config, quantization)
|
||||
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
|
||||
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
|
||||
else:
|
||||
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization)
|
||||
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
|
||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||
|
||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||
# Each model has its own text_embedding weights, so dual models need separate embeddings
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
|
||||
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
|
||||
if cfg_disabled:
|
||||
# No CFG: only compute cond embeddings (B=1 forward pass, 2x faster)
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context])
|
||||
context_emb_high = high_noise_model.embed_text([context])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cond_low = context_emb_low[0:1]
|
||||
context_cond_high = context_emb_high[0:1]
|
||||
else:
|
||||
context_emb = single_model.embed_text([context])
|
||||
mx.eval(context_emb)
|
||||
context_cond = context_emb[0:1]
|
||||
else:
|
||||
context_emb = single_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
|
||||
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
|
||||
else:
|
||||
context_emb = single_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
|
||||
|
||||
# Precompute cross-attention K/V caches (constant across all steps)
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
if cfg_disabled:
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cond_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cond_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cond)
|
||||
mx.eval(cross_kv)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cfg)
|
||||
mx.eval(cross_kv)
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cfg)
|
||||
mx.eval(cross_kv)
|
||||
|
||||
# Precompute RoPE frequencies (grid sizes are constant across all steps)
|
||||
f_grid = t_latent // patch_size[0]
|
||||
h_grid = h_latent // patch_size[1]
|
||||
w_grid = w_latent // patch_size[2]
|
||||
cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
|
||||
if cfg_disabled:
|
||||
rope_grid_sizes = [(f_grid, h_grid, w_grid)]
|
||||
else:
|
||||
rope_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
|
||||
if is_dual:
|
||||
rope_cos_sin_low = low_noise_model.prepare_rope(cfg_grid_sizes)
|
||||
rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes)
|
||||
rope_cos_sin_low = low_noise_model.prepare_rope(rope_grid_sizes)
|
||||
rope_cos_sin_high = high_noise_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
||||
else:
|
||||
rope_cos_sin = ref_model.prepare_rope(cfg_grid_sizes)
|
||||
rope_cos_sin = ref_model.prepare_rope(rope_grid_sizes)
|
||||
mx.eval(rope_cos_sin)
|
||||
|
||||
# Setup scheduler
|
||||
@@ -395,58 +445,86 @@ def generate_video(
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
# Select model, guide scale, cached K/V, and precomputed RoPE
|
||||
# Select model, cached K/V, and precomputed RoPE
|
||||
if is_dual:
|
||||
if timestep_val >= boundary:
|
||||
model = high_noise_model
|
||||
gs = guide_scale[1]
|
||||
kv = cross_kv_high
|
||||
rcs = rope_cos_sin_high
|
||||
else:
|
||||
model = low_noise_model
|
||||
gs = guide_scale[0]
|
||||
kv = cross_kv_low
|
||||
rcs = rope_cos_sin_low
|
||||
else:
|
||||
model = single_model
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
kv = cross_kv
|
||||
rcs = rope_cos_sin
|
||||
|
||||
# Build per-token timesteps for TI2V-5B (first-frame patches get t=0)
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val # [1, L]
|
||||
# Pad to seq_len if needed
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
# Batch for CFG: both cond and uncond get same timesteps
|
||||
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0) # [2, L]
|
||||
if cfg_disabled:
|
||||
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
t_batch = t_tokens # [1, L]
|
||||
else:
|
||||
t_batch = mx.array([timestep_val])
|
||||
|
||||
y_arg = [y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
if is_dual:
|
||||
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
|
||||
else:
|
||||
ctx = context_cond
|
||||
preds = model(
|
||||
[latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred = preds[0]
|
||||
del preds
|
||||
else:
|
||||
t_batch = mx.array([timestep_val, timestep_val])
|
||||
# CFG: batch cond + uncond into single B=2 forward pass
|
||||
if is_dual:
|
||||
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
|
||||
else:
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
|
||||
# I2V-14B: pass y conditioning to model (same y for cond and uncond)
|
||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
if pad_len > 0:
|
||||
t_tokens = mx.concatenate(
|
||||
[t_tokens, mx.full((1, pad_len), timestep_val)], axis=1
|
||||
)
|
||||
t_batch = mx.concatenate([t_tokens, t_tokens], axis=0)
|
||||
else:
|
||||
t_batch = mx.array([timestep_val, timestep_val])
|
||||
|
||||
# CFG: batch cond + uncond into single B=2 forward pass
|
||||
ctx = context_cfg if not is_dual else (
|
||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
||||
)
|
||||
preds = model(
|
||||
[latents, latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
# Classifier-free guidance + scheduler step
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
ctx = context_cfg if not is_dual else (
|
||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
||||
)
|
||||
preds = model(
|
||||
[latents, latents],
|
||||
t=t_batch,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
del noise_pred_cond, noise_pred_uncond, preds
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
|
||||
@@ -455,7 +533,7 @@ def generate_video(
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
||||
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
|
||||
del noise_pred
|
||||
mx.eval(latents)
|
||||
|
||||
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
|
||||
@@ -463,11 +541,19 @@ def generate_video(
|
||||
# Free transformer models and text embeddings
|
||||
if is_dual:
|
||||
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
|
||||
del context_cfg_low, context_cfg_high
|
||||
if cfg_disabled:
|
||||
del context_cond_low, context_cond_high
|
||||
else:
|
||||
del context_cfg_low, context_cfg_high
|
||||
else:
|
||||
del single_model, cross_kv
|
||||
del context_cfg
|
||||
del model, kv, context, context_null
|
||||
if cfg_disabled:
|
||||
del context_cond
|
||||
else:
|
||||
del context_cfg
|
||||
del model, kv, context
|
||||
if context_null is not None:
|
||||
del context_null
|
||||
gc.collect(); mx.clear_cache()
|
||||
|
||||
# Load VAE and decode
|
||||
@@ -478,25 +564,36 @@ def generate_video(
|
||||
|
||||
is_wan22_vae = config.vae_z_dim == 48
|
||||
|
||||
# Warm-up: prepend a copy of the first latent frame to provide temporal
|
||||
# context for the real first frame. Causal convolutions in the VAE decoder
|
||||
# pad with zeros on the left, so the first few output frames have degraded
|
||||
# quality (no temporal context). By duplicating the first latent, the real
|
||||
# first frame sees its own features as left context instead of zeros.
|
||||
# We trim the extra output frames after decoding.
|
||||
warmup_trim = vae_stride[0] # 4 frames per latent temporal position
|
||||
latents_for_decode = mx.concatenate([latents[:, 0:1], latents], axis=1)
|
||||
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
|
||||
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
|
||||
z = latents.transpose(1, 2, 3, 0)[None] # [1, T, H, W, C]
|
||||
z = latents_for_decode.transpose(1, 2, 3, 0)[None] # [1, T+1, H, W, C]
|
||||
z = denormalize_latents(z)
|
||||
video = vae(z) # [1, T', H', W', 3]
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [T', H', W', 3]
|
||||
video = video[warmup_trim:] # Trim warm-up frames
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
else:
|
||||
video = vae.decode(latents[None]) # [1, 3, T, H, W]
|
||||
video = vae.decode(latents_for_decode[None]) # [1, 3, T+1*4, H, W]
|
||||
mx.eval(video)
|
||||
print(f"{Colors.DIM} VAE decode: {time.time() - t4:.1f}s{Colors.RESET}")
|
||||
|
||||
video = np.array(video[0]) # [3, T, H, W]
|
||||
video = np.array(video[0]) # [3, T', H, W]
|
||||
video = video[:, warmup_trim:] # Trim warm-up frames (channels-first)
|
||||
video = (video + 1.0) / 2.0
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
||||
@@ -529,6 +626,19 @@ def main():
|
||||
choices=["euler", "dpm++", "unipc"],
|
||||
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
|
||||
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse guide scale
|
||||
@@ -542,6 +652,12 @@ def main():
|
||||
if args.no_negative_prompt:
|
||||
neg_prompt = ""
|
||||
|
||||
# Parse LoRA configs: convert [path, strength_str] → (path, float)
|
||||
def _parse_lora_args(lora_list):
|
||||
if not lora_list:
|
||||
return None
|
||||
return [(path, float(strength)) for path, strength in lora_list]
|
||||
|
||||
generate_video(
|
||||
model_dir=args.model_dir,
|
||||
prompt=args.prompt,
|
||||
@@ -556,6 +672,10 @@ def main():
|
||||
seed=args.seed,
|
||||
output_path=args.output_path,
|
||||
scheduler=args.scheduler,
|
||||
loras=_parse_lora_args(args.lora),
|
||||
loras_high=_parse_lora_args(args.lora_high),
|
||||
loras_low=_parse_lora_args(args.lora_low),
|
||||
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user