# Substring renames that map raw PyTorch parameter names onto the MLX module
# layout. Every pattern ends in "." and therefore only matches when another
# path segment follows it, so the renames must be applied to the complete
# "<module>.weight" key and not to the bare module path.
_LORA_KEY_RENAMES = (
    (".to_out.0.", ".to_out."),
    (".ff.net.0.proj.", ".ff.proj_in."),
    (".ff.net.2.", ".ff.proj_out."),
    (".audio_ff.net.0.proj.", ".audio_ff.proj_in."),
    (".audio_ff.net.2.", ".audio_ff.proj_out."),
    (".linear_1.", ".linear1."),
    (".linear_2.", ".linear2."),
)


def _flatten_params(tree, prefix=""):
    """Flatten a nested parameter tree into a ``{"dotted.path": leaf}`` dict.

    Dict keys and list/tuple indices both become path segments, so a leaf
    stored at ``tree["blocks"][0]["weight"]`` is keyed ``"blocks.0.weight"``.
    """
    if isinstance(tree, dict):
        children = tree.items()
    elif isinstance(tree, (list, tuple)):
        children = enumerate(tree)
    else:
        return {prefix: tree}

    flat = {}
    for name, child in children:
        child_prefix = f"{prefix}.{name}" if prefix else str(name)
        flat.update(_flatten_params(child, child_prefix))
    return flat


def _collect_lora_pairs(lora_weights):
    """Group LoRA tensors by module path.

    Returns ``(pairs, raw_format)`` where ``pairs`` maps a module path to a
    dict holding the ``"A"`` and/or ``"B"`` tensor, and ``raw_format`` is True
    when the file uses the ``diffusion_model.`` key prefix.
    """
    prefix = "diffusion_model."
    raw_format = any(key.startswith(prefix) for key in lora_weights)

    pairs = {}
    for key, tensor in lora_weights.items():
        module_key = key
        if raw_format:
            # In a prefixed file, keys without the prefix are not transformer weights.
            if not key.startswith(prefix):
                continue
            # Slice instead of str.replace so only the leading prefix is removed.
            module_key = key[len(prefix):]

        for suffix, slot in ((".lora_A.weight", "A"), (".lora_B.weight", "B")):
            if module_key.endswith(suffix):
                pairs.setdefault(module_key[: -len(suffix)], {})[slot] = tensor
                break
    return pairs, raw_format


def _target_weight_key(module_key, raw_format):
    """Return the model weight key that a LoRA module path refers to."""
    weight_key = f"{module_key}.weight"
    if raw_format:
        for old, new in _LORA_KEY_RENAMES:
            weight_key = weight_key.replace(old, new)
    return weight_key


def load_and_merge_lora(
    model: LTXModel,
    lora_path: str,
    strength: float = 1.0,
) -> None:
    """Load LoRA weights and merge them into the transformer model in-place.

    Supports two formats:
    - Raw PyTorch: keys like diffusion_model.{module}.lora_A.weight (needs sanitization)
    - Pre-converted MLX: keys like {module}.lora_A.weight (already sanitized)

    Merge formula: weight += (lora_B * strength) @ lora_A

    The merge is computed in float32 and the result is cast back to the dtype
    of the weight it replaces. LoRA entries that are missing their A or B half,
    or that do not correspond to any ``<module>.weight`` parameter of the
    model, are skipped and reported in a warning line.

    Args:
        model: The LTXModel transformer to merge into
        lora_path: Path to the LoRA safetensors file or directory containing one
        strength: LoRA strength/coefficient (default 1.0)

    Raises:
        FileNotFoundError: If ``lora_path`` is a directory that contains no
            ``.safetensors`` file.
    """
    # Resolve path: if directory, use the first safetensors file (sorted by name)
    lora_file = Path(lora_path)
    if lora_file.is_dir():
        candidates = sorted(lora_file.glob("*.safetensors"))
        if not candidates:
            raise FileNotFoundError(f"No .safetensors files found in {lora_path}")
        lora_file = candidates[0]
        console.print(f"[dim]Using LoRA file: {lora_file.name}[/]")

    lora_pairs, raw_format = _collect_lora_pairs(mx.load(str(lora_file)))

    # Current model weights keyed by dotted path
    flat_weights = _flatten_params(dict(model.parameters()))

    updates = []
    skipped = 0
    for module_key, pair in lora_pairs.items():
        weight_key = _target_weight_key(module_key, raw_format)
        if "A" not in pair or "B" not in pair or weight_key not in flat_weights:
            skipped += 1
            continue

        base_weight = flat_weights[weight_key]
        lora_a = pair["A"].astype(mx.float32)  # expected (rank, in_features)
        lora_b = pair["B"].astype(mx.float32)  # expected (out_features, rank)

        # delta = (lora_B * strength) @ lora_A, accumulated in float32
        delta = (lora_b * strength) @ lora_a
        merged_weight = base_weight.astype(mx.float32) + delta

        # Keep the parameter's own dtype instead of forcing bfloat16.
        updates.append((weight_key, merged_weight.astype(base_weight.dtype)))

    merged_count = len(updates)
    model.load_weights(updates, strict=False)
    mx.eval(model.parameters())
    console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")
    if skipped:
        console.print(
            f"[yellow]⚠️  Skipped {skipped} LoRA entries (incomplete pair or no matching model weight)[/]"
        )
frames[/]" console.print(Panel(header, expand=False)) console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") - if pipeline == PipelineType.DEV: + if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}, Rescale: {cfg_rescale}[/]") if is_i2v: @@ -962,9 +1077,8 @@ def generate_video( model_path = get_model_path(model_repo) text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) - # Model weight file # Calculate latent dimensions - if pipeline == PipelineType.DISTILLED: + if is_two_stage: stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 stage2_h, stage2_w = height // 32, width // 32 else: @@ -996,8 +1110,8 @@ def generate_video( console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]") # Encode prompts - if pipeline == PipelineType.DEV: - # Dev pipeline needs positive and negative embeddings + if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): + # Dev/dev-two-stage pipelines need positive and negative embeddings for CFG if audio: video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) @@ -1009,6 +1123,9 @@ def generate_video( audio_embeddings_pos = audio_embeddings_neg = None model_dtype = video_embeddings_pos.dtype mx.eval(video_embeddings_pos, video_embeddings_neg) + # For dev-two-stage, stage 2 uses single positive embedding (no CFG) + if pipeline == PipelineType.DEV_TWO_STAGE: + text_embeddings = video_embeddings_pos else: # Distilled pipeline - single embedding if audio: @@ -1172,7 +1289,7 @@ def generate_video( audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings, ) - else: + elif pipeline == PipelineType.DEV: # ====================================================================== # 
DEV PIPELINE: Single-stage with CFG # ====================================================================== @@ -1193,7 +1310,6 @@ def generate_video( console.print("[green]✓[/] VAE encoder loaded and image encoded") # Generate sigma schedule with token-count-dependent shifting - num_tokens = latent_frames * latent_h * latent_w sigmas = ltx2_scheduler(steps=num_inference_steps) mx.eval(sigmas) console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") @@ -1261,6 +1377,181 @@ def generate_video( # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) + elif pipeline == PipelineType.DEV_TWO_STAGE: + # ====================================================================== + # DEV TWO-STAGE PIPELINE: + # Stage 1: Dev denoising at half resolution with CFG + # Upsample: 2x spatial via LatentUpsampler + # Stage 2: Distilled denoising at full resolution with LoRA, no CFG + # ====================================================================== + + # Load VAE encoder for I2V + stage1_image_latent = None + stage2_image_latent = None + if is_i2v: + with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"): + vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") + + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) + stage1_image_latent = vae_encoder(stage1_image_tensor) + mx.eval(stage1_image_latent) + + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) + stage2_image_latent = vae_encoder(stage2_image_tensor) + mx.eval(stage2_image_latent) + + del vae_encoder + mx.clear_cache() + 
console.print("[green]✓[/] VAE encoder loaded and image encoded") + + # Stage 1: Dev denoising at half resolution with CFG + sigmas = ltx2_scheduler(steps=num_inference_steps) + mx.eval(sigmas) + console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]") + + console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {width//2}x{height//2} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})") + mx.random.seed(seed) + + positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) + mx.eval(positions) + + audio_positions = None + audio_latents = None + if audio: + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + mx.eval(audio_positions, audio_latents) + + # Apply I2V conditioning for stage 1 + state1 = None + stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w) + if is_i2v and stage1_image_latent is not None: + state1 = LatentState( + latent=mx.zeros(stage1_shape, dtype=model_dtype), + clean_latent=mx.zeros(stage1_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength) + state1 = apply_conditioning(state1, [conditioning]) + + noise = mx.random.normal(stage1_shape, dtype=model_dtype) + noise_scale = sigmas[0] + scaled_mask = state1.denoise_mask * noise_scale + state1 = LatentState( + latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=state1.clean_latent, + denoise_mask=state1.denoise_mask, + ) + latents = state1.latent + mx.eval(latents) + else: + latents = mx.random.normal(stage1_shape, dtype=model_dtype) + mx.eval(latents) + + # Run stage 1 with dev-style CFG denoising + if audio: + latents, audio_latents = 
denoise_dev_av( + latents, audio_latents, + positions, audio_positions, + video_embeddings_pos, video_embeddings_neg, + audio_embeddings_pos, audio_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, + cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1, + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold + ) + else: + latents = denoise_dev( + latents, positions, + video_embeddings_pos, video_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, + verbose=verbose, state=state1, + use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold + ) + + # Upsample latents 2x + with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"): + upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors")) + if not upscaler_files: + raise FileNotFoundError(f"No spatial upscaler found in {model_path}") + upsampler = load_upsampler(str(upscaler_files[0])) + mx.eval(upsampler.parameters()) + + vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) + + latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std) + mx.eval(latents) + + del upsampler + mx.clear_cache() + console.print("[green]✓[/] Latents upsampled") + + # Merge LoRA weights for stage 2 (distilled refinement) + if lora_path is None: + # Auto-detect LoRA file in model directory + lora_files = sorted(model_path.glob("*distilled-lora*.safetensors")) + if lora_files: + lora_path = str(lora_files[0]) + console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]") + else: + console.print("[yellow]⚠️ No LoRA file found. 
Stage 2 will use base weights.[/]") + + if lora_path is not None: + with console.status("[blue]🔧 Merging distilled LoRA weights...[/]", spinner="dots"): + load_and_merge_lora(transformer, lora_path, strength=lora_strength) + + # Stage 2: Distilled refinement at full resolution (no CFG) + console.print(f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)") + positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) + mx.eval(positions) + + state2 = None + if is_i2v and stage2_image_latent is not None: + state2 = LatentState( + latent=latents, + clean_latent=mx.zeros_like(latents), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength) + state2 = apply_conditioning(state2, [conditioning]) + + noise = mx.random.normal(latents.shape).astype(model_dtype) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + scaled_mask = state2.denoise_mask * noise_scale + state2 = LatentState( + latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=state2.clean_latent, + denoise_mask=state2.denoise_mask, + ) + latents = state2.latent + mx.eval(latents) + + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale + mx.eval(audio_latents) + else: + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) + noise = mx.random.normal(latents.shape).astype(model_dtype) + latents = noise * noise_scale + latents * one_minus_scale + mx.eval(latents) + + if audio and audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) 
+ audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale + mx.eval(audio_latents) + + # Stage 2 uses distilled denoising (no CFG) + latents, audio_latents = denoise_distilled( + latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, + verbose=verbose, state=state2, + audio_latents=audio_latents, audio_positions=audio_positions, + audio_embeddings=audio_embeddings_pos if audio else None, + ) + del transformer mx.clear_cache() @@ -1445,6 +1736,9 @@ Examples: python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 3.0 python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 40 + # Dev two-stage pipeline (dev + LoRA refinement) + python -m mlx_video.generate --prompt "A cat walking" --pipeline dev-two-stage --cfg-scale 3.0 + # Image-to-Video (works with both pipelines) python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --pipeline dev @@ -1456,8 +1750,8 @@ Examples: ) parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate") - parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev"], - help="Pipeline type: distilled (two-stage, fast) or dev (single-stage, CFG)") + parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage"], + help="Pipeline type: distilled (two-stage, fast), dev (single-stage, CFG), or dev-two-stage (dev + LoRA refinement)") parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT, help="Negative prompt for CFG (dev pipeline only)") parser.add_argument("--height", "-H", type=int, default=512, help="Output video height") @@ -1488,9 +1782,16 @@ Examples: parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)") parser.add_argument("--apg-eta", 
type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)") parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)") + parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)") + parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)") args = parser.parse_args() - pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED + pipeline_map = { + "distilled": PipelineType.DISTILLED, + "dev": PipelineType.DEV, + "dev-two-stage": PipelineType.DEV_TWO_STAGE, + } + pipeline = pipeline_map[args.pipeline] generate_video( model_repo=args.model_repo, @@ -1522,6 +1823,8 @@ Examples: use_apg=args.apg, apg_eta=args.apg_eta, apg_norm_threshold=args.apg_norm_threshold, + lora_path=args.lora_path, + lora_strength=args.lora_strength, )