From 7435facc527fa6be9d33256455c64db214fe1ed0 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 13 Mar 2026 01:22:45 +0100 Subject: [PATCH] Add support for DEV_TWO_STAGE pipeline and implement LoRA merging functionality in generate.py. Enhance video generation capabilities by allowing LoRA weights to be loaded and merged into the model, improving flexibility in model configurations. Update pipeline handling to accommodate the new two-stage generation process. --- mlx_video/generate.py | 333 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 318 insertions(+), 15 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 998b7f8..b99ab7b 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -35,8 +35,9 @@ from mlx_video.conditioning.latent import LatentState, apply_denoise_mask class PipelineType(Enum): """Pipeline type selector.""" - DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG - DEV = "dev" # Single-stage, dynamic sigmas, CFG + DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG + DEV = "dev" # Single-stage, dynamic sigmas, CFG + DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res) # Distilled model sigma schedules @@ -61,6 +62,111 @@ AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_L DEFAULT_NEGATIVE_PROMPT = "worst quality, inconsistent motion, blurry, jittery, distorted" +def load_and_merge_lora( + model: LTXModel, + lora_path: str, + strength: float = 1.0, +) -> None: + """Load LoRA weights and merge them into the transformer model in-place. + + Supports two formats: + - Raw PyTorch: keys like diffusion_model.{module}.lora_A.weight (needs sanitization) + - Pre-converted MLX: keys like {module}.lora_A.weight (already sanitized) + + Merge formula: weight += (lora_B * strength) @ lora_A + + Args: + model: The LTXModel transformer to merge into + lora_path: Path to the LoRA safetensors file or directory containing one + strength: LoRA strength/coefficient (default 1.0) + """ + # Resolve path: if directory, find the safetensors file inside + lora_file = Path(lora_path) + if lora_file.is_dir(): + candidates = sorted(lora_file.glob("*.safetensors")) + if not candidates: + raise FileNotFoundError(f"No .safetensors files found in {lora_path}") + lora_file = candidates[0] + console.print(f"[dim]Using LoRA file: {lora_file.name}[/]") + + # Load LoRA weights + lora_weights = mx.load(str(lora_file)) + + # Detect format: raw PyTorch has 'diffusion_model.' prefix + has_prefix = any(k.startswith("diffusion_model.") for k in lora_weights) + + # Group into A/B pairs by module name + lora_pairs = {} + for key in lora_weights: + module_key = key + if has_prefix: + if not key.startswith("diffusion_model."): + continue + module_key = key.replace("diffusion_model.", "") + + if module_key.endswith(".lora_A.weight"): + base_key = module_key.replace(".lora_A.weight", "") + lora_pairs.setdefault(base_key, {})["A"] = lora_weights[key] + elif module_key.endswith(".lora_B.weight"): + base_key = module_key.replace(".lora_B.weight", "") + lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key] + + # Apply key sanitization only for raw PyTorch format + if has_prefix: + sanitized_pairs = {} + for key, pair in lora_pairs.items(): + new_key = key + new_key = new_key.replace(".to_out.0.", ".to_out.") + new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") + new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") + new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") + new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") + new_key = new_key.replace(".linear_1.", ".linear1.") + new_key = new_key.replace(".linear_2.", ".linear2.") + sanitized_pairs[new_key] = pair + else: + sanitized_pairs = lora_pairs + + # Get current model weights as a flat dict + def flatten_params(params, prefix=""): + flat = {} + for k, v in params.items(): + full_key = f"{prefix}.{k}" if prefix else k + if isinstance(v, dict): + flat.update(flatten_params(v, full_key)) + else: + flat[full_key] = v + return flat + + flat_weights = flatten_params(dict(model.parameters())) + + # Merge LoRA deltas + merged_count = 0 + updates = [] + for module_key, pair in sanitized_pairs.items(): + if "A" not in pair or "B" not in pair: + continue + + weight_key = f"{module_key}.weight" + if weight_key not in flat_weights: + continue + + lora_a = pair["A"].astype(mx.float32) # (rank, in_features) + lora_b = pair["B"].astype(mx.float32) # (out_features, rank) + + # delta = (lora_B * strength) @ lora_A + delta = (lora_b * strength) @ lora_a + + base_weight = flat_weights[weight_key].astype(mx.float32) + merged_weight = base_weight + delta + updates.append((weight_key, merged_weight.astype(mx.bfloat16))) + merged_count += 1 + + model.load_weights(updates, strict=False) + mx.eval(model.parameters()) + console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})") + + def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array: """Compute CFG delta for classifier-free guidance. @@ -888,12 +994,15 @@ def generate_video( use_apg: bool = False, apg_eta: float = 1.0, apg_norm_threshold: float = 0.0, + lora_path: Optional[str] = None, + lora_strength: float = 1.0, ): """Generate video using LTX-2 models. - Supports two pipelines: + Supports three pipelines: - DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG - DEV: Single-stage generation with dynamic sigmas and CFG + - DEV_TWO_STAGE: Stage 1 dev (half res, CFG) + upsample + stage 2 distilled with LoRA (full res, no CFG) Args: model_repo: Model repository ID @@ -928,7 +1037,8 @@ def generate_video( start_time = time.time() # Validate dimensions - divisor = 64 if pipeline == PipelineType.DISTILLED else 32 + is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE) + divisor = 64 if is_two_stage else 32 assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}" assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}" @@ -942,12 +1052,17 @@ def generate_video( if audio: mode_str += "+Audio" - pipeline_name = "DEV" if pipeline == PipelineType.DEV else "DISTILLED" + pipeline_names = { + PipelineType.DISTILLED: "DISTILLED", + PipelineType.DEV: "DEV", + PipelineType.DEV_TWO_STAGE: "DEV-TWO-STAGE", + } + pipeline_name = pipeline_names[pipeline] header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height} • {num_frames} frames[/]" console.print(Panel(header, expand=False)) console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") - if pipeline == PipelineType.DEV: + if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]") if is_i2v: @@ -962,9 +1077,8 @@ def generate_video( model_path = get_model_path(model_repo) text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) - # Model weight file # Calculate latent dimensions - if pipeline == PipelineType.DISTILLED: + if is_two_stage: stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 stage2_h, stage2_w = height // 32, width // 32 else: @@ -996,8 +1110,8 @@ def generate_video( console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]") # Encode prompts - if pipeline == PipelineType.DEV: - # Dev pipeline needs positive and negative embeddings + if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): + # Dev/dev-two-stage pipelines need positive and negative embeddings for CFG if audio: video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) @@ -1009,6 +1123,9 @@ def generate_video( audio_embeddings_pos = audio_embeddings_neg = None model_dtype = video_embeddings_pos.dtype mx.eval(video_embeddings_pos, video_embeddings_neg) + # For dev-two-stage, stage 2 uses single positive embedding (no CFG) + if pipeline == PipelineType.DEV_TWO_STAGE: + text_embeddings = video_embeddings_pos else: # Distilled pipeline - single embedding if audio: @@ -1172,7 +1289,7 @@ def generate_video( audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, ) - else: + elif pipeline == PipelineType.DEV: # ====================================================================== # DEV PIPELINE: Single-stage with CFG # ====================================================================== @@ -1193,7 +1310,6 @@ def generate_video( console.print("[green]✓[/] VAE encoder loaded and image encoded") # Generate sigma schedule with token-count-dependent shifting - num_tokens = latent_frames * latent_h * latent_w sigmas = ltx2_scheduler(steps=num_inference_steps) mx.eval(sigmas) console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") @@ -1261,6 +1377,181 @@ def generate_video( # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) + elif pipeline == PipelineType.DEV_TWO_STAGE: + # ====================================================================== + # DEV TWO-STAGE PIPELINE: + # Stage 1: Dev denoising at half resolution with CFG + # Upsample: 2x spatial via LatentUpsampler + # Stage 2: Distilled denoising at full resolution with LoRA, no CFG + # ====================================================================== + + # Load VAE encoder for I2V + stage1_image_latent = None + stage2_image_latent = None + if is_i2v: + with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): + vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") + + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) + stage1_image_latent = vae_encoder(stage1_image_tensor) + mx.eval(stage1_image_latent) + + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) + stage2_image_latent = vae_encoder(stage2_image_tensor) + mx.eval(stage2_image_latent) + + del vae_encoder + mx.clear_cache() + console.print("[green]✓[/] VAE encoder loaded and image encoded") + + # Stage 1: Dev denoising at half resolution with CFG + sigmas = ltx2_scheduler(steps=num_inference_steps) + mx.eval(sigmas) + console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") + + console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {width//2}x{height//2} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})") + mx.random.seed(seed) + + positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) + mx.eval(positions) + + audio_positions = None + audio_latents = None + if audio: + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + mx.eval(audio_positions, audio_latents) + + # Apply I2V conditioning for stage 1 + state1 = None + stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w) + if is_i2v and stage1_image_latent is not None: + state1 = LatentState( + latent=mx.zeros(stage1_shape, dtype=model_dtype), + clean_latent=mx.zeros(stage1_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength) + state1 = apply_conditioning(state1, [conditioning]) + + noise = mx.random.normal(stage1_shape, dtype=model_dtype) + noise_scale = sigmas[0] + scaled_mask = state1.denoise_mask * noise_scale + state1 = LatentState( + latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=state1.clean_latent, + denoise_mask=state1.denoise_mask, + ) + latents = state1.latent + mx.eval(latents) + else: + latents = mx.random.normal(stage1_shape, dtype=model_dtype) + mx.eval(latents) + + # Run stage 1 with dev-style CFG denoising + if audio: + latents, audio_latents = denoise_dev_av( + latents, audio_latents, + positions, audio_positions, + video_embeddings_pos, video_embeddings_neg, + audio_embeddings_pos, audio_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, + cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1, + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold + ) + else: + latents = denoise_dev( + latents, positions, + video_embeddings_pos, video_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, + verbose=verbose, state=state1, + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold + ) + + # Upsample latents 2x + with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): + upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors")) + if not upscaler_files: + raise FileNotFoundError(f"No spatial upscaler found in {model_path}") + upsampler = load_upsampler(str(upscaler_files[0])) + mx.eval(upsampler.parameters()) + + vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) + + latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std) + mx.eval(latents) + + del upsampler + mx.clear_cache() + console.print("[green]✓[/] Latents upsampled") + + # Merge LoRA weights for stage 2 (distilled refinement) + if lora_path is None: + # Auto-detect LoRA file in model directory + lora_files = sorted(model_path.glob("*distilled-lora*.safetensors")) + if lora_files: + lora_path = str(lora_files[0]) + console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]") + else: + console.print("[yellow]⚠️ No LoRA file found. Stage 2 will use base weights.[/]") + + if lora_path is not None: + with console.status("[blue]🔧 Merging distilled LoRA weights...[/]", spinner="dots"): + load_and_merge_lora(transformer, lora_path, strength=lora_strength) + + # Stage 2: Distilled refinement at full resolution (no CFG) + console.print(f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)") + positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) + mx.eval(positions) + + state2 = None + if is_i2v and stage2_image_latent is not None: + state2 = LatentState( + latent=latents, + clean_latent=mx.zeros_like(latents), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength) + state2 = apply_conditioning(state2, [conditioning]) + + noise = mx.random.normal(latents.shape).astype(model_dtype) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + scaled_mask = state2.denoise_mask * noise_scale + state2 = LatentState( + latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=state2.clean_latent, + denoise_mask=state2.denoise_mask, + ) + latents = state2.latent + mx.eval(latents) + + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale + mx.eval(audio_latents) + else: + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) + noise = mx.random.normal(latents.shape).astype(model_dtype) + latents = noise * noise_scale + latents * one_minus_scale + mx.eval(latents) + + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale + mx.eval(audio_latents) + + # Stage 2 uses distilled denoising (no CFG) + latents, audio_latents = denoise_distilled( + latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, + verbose=verbose, state=state2, + audio_latents=audio_latents, audio_positions=audio_positions, + audio_embeddings=audio_embeddings_pos if audio else None, + ) + del transformer mx.clear_cache() @@ -1445,6 +1736,9 @@ Examples: python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 3.0 python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 40 + # Dev two-stage pipeline (dev + LoRA refinement) + python -m mlx_video.generate --prompt "A cat walking" --pipeline dev-two-stage --cfg-scale 3.0 + # Image-to-Video (works with both pipelines) python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --pipeline dev @@ -1456,8 +1750,8 @@ Examples: ) parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate") - parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev"], - help="Pipeline type: distilled (two-stage, fast) or dev (single-stage, CFG)") + parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage"], + help="Pipeline type: distilled (two-stage, fast), dev (single-stage, CFG), or dev-two-stage (dev + LoRA refinement)") parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT, help="Negative prompt for CFG (dev pipeline only)") parser.add_argument("--height", "-H", type=int, default=512, help="Output video height") @@ -1488,9 +1782,16 @@ Examples: parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)") + parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)") + parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)") args = parser.parse_args() - pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED + pipeline_map = { + "distilled": PipelineType.DISTILLED, + "dev": PipelineType.DEV, + "dev-two-stage": PipelineType.DEV_TWO_STAGE, + } + pipeline = pipeline_map[args.pipeline] generate_video( model_repo=args.model_repo, @@ -1522,6 +1823,8 @@ Examples: use_apg=args.apg, apg_eta=args.apg_eta, apg_norm_threshold=args.apg_norm_threshold, + lora_path=args.lora_path, + lora_strength=args.lora_strength, )