diff --git a/README.md b/README.md index da7c7aa..fdbddf9 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Supported models: ## Features - Text-to-video (T2V) and Image-to-video (I2V) generation -- Three pipeline modes: Distilled, Dev, and Dev Two-Stage +- Four pipeline modes: Distilled, Dev, Dev Two-Stage, and Dev Two-Stage HQ - Synchronized audio-video generation (experimental) - LoRA support (including HuggingFace repos) - Prompt enhancement via Gemma @@ -35,13 +35,14 @@ Supported models: ### Pipelines -mlx-video supports three pipeline types via the `--pipeline` flag: +mlx-video supports four pipeline types via the `--pipeline` flag: | Pipeline | Description | CFG | Stages | Speed | |----------|-------------|-----|--------|-------| | `distilled` (default) | Fixed sigma schedule, no CFG | No | 2 (8+3 steps) | Fastest | | `dev` | Dynamic sigmas, constant CFG | Yes | 1 (30 steps) | Medium | -| `dev-two-stage` | Dev + LoRA refinement | Yes (stage 1) | 2 (30+3 steps) | Slowest, highest quality | +| `dev-two-stage` | Dev + LoRA refinement | Yes (stage 1) | 2 (30+3 steps) | Slow | +| `dev-two-stage-hq` | res_2s sampler + LoRA both stages | Yes (stage 1) | 2 (15+3 steps) | Slow, highest quality | ### Text-to-Video @@ -52,13 +53,24 @@ uv run mlx_video.generate --prompt "Two dogs wearing sunglasses, cinematic, suns # Dev - single-stage with CFG uv run mlx_video.generate --pipeline dev --prompt "A cinematic scene" --cfg-scale 3.0 -# Dev two-stage - dev + LoRA refinement (highest quality) +# Dev two-stage - dev + LoRA refinement uv run mlx_video.generate --pipeline dev-two-stage \ --prompt "Two dogs of the poodle breed wearing sunglasses, close up, cinematic, sunset" \ -n 145 --width 1024 --height 768 \ --model-repo prince-canuma/LTX-2-dev \ --cfg-scale 3.0 --lora-strength 0.8 \ --enhance-prompt + +# Dev two-stage HQ - res_2s sampler, LoRA both stages (highest quality) +uv run mlx_video.generate --pipeline dev-two-stage-hq \ + --prompt "A cinematic scene of ocean waves at golden hour" \ + --model-repo prince-canuma/LTX-2-dev + +# HQ with custom LoRA strengths +uv run mlx_video.generate --pipeline dev-two-stage-hq \ + --prompt "A sunset over mountains" \ + --model-repo prince-canuma/LTX-2-dev \ + --lora-strength-stage-1 0.3 --lora-strength-stage-2 0.6 ``` Poodles demo @@ -124,7 +136,7 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom | Option | Default | Description | |--------|---------|-------------| | `--prompt`, `-p` | (required) | Text description of the video | -| `--pipeline` | `distilled` | Pipeline type: `distilled`, `dev`, or `dev-two-stage` | +| `--pipeline` | `distilled` | Pipeline type: `distilled`, `dev`, `dev-two-stage`, or `dev-two-stage-hq` | | `--height`, `-H` | 512 | Output height (divisible by 64 for two-stage, 32 for dev) | | `--width`, `-W` | 512 | Output width (divisible by 64 for two-stage, 32 for dev) | | `--num-frames`, `-n` | 33 | Number of frames (must be 1 + 8*k) | @@ -161,6 +173,15 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom | `--lora-path` | auto-detect | Path to LoRA file, directory, or HuggingFace repo | | `--lora-strength` | 1.0 | LoRA merge strength | +**Dev-Two-Stage HQ options:** + +| Option | Default | Description | +|--------|---------|-------------| +| `--lora-strength-stage-1` | 0.25 | LoRA strength for stage 1 | +| `--lora-strength-stage-2` | 0.5 | LoRA strength for stage 2 | + +HQ defaults: 15 steps (vs 30), `cfg-rescale` 0.45 (vs 0.7), STG disabled. Uses the res_2s second-order sampler (2 model evals per step) for better quality at the same compute budget. + ## How It Works ### Distilled Pipeline (default) @@ -179,6 +200,14 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom 3. **Stage 2**: Distilled refinement at full resolution with LoRA weights (3 steps, no CFG) 4. **Decode**: VAE decoder converts latents to RGB video +### Dev Two-Stage HQ Pipeline +1. **Stage 1**: res_2s denoising at half resolution with CFG + LoRA@0.25 (15 steps, 2 evals/step) +2. **Upsample**: 2x spatial upsampling via LatentUpsampler +3. **Stage 2**: res_2s refinement at full resolution with LoRA@0.5 (3 steps, no CFG) +4. **Decode**: VAE decoder converts latents to RGB video + +The res_2s sampler uses an exponential Rosenbrock-type Runge-Kutta integrator with SDE noise injection, producing higher quality results than Euler at the same compute budget (~30 total model evaluations). + ## Requirements - macOS with Apple Silicon diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 5945486..d6f5517 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -38,6 +38,7 @@ class PipelineType(Enum): DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG DEV = "dev" # Single-stage, dynamic sigmas, CFG DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res) + DEV_TWO_STAGE_HQ = "dev-two-stage-hq" # Two-stage: res_2s sampler, LoRA both stages # Distilled model sigma schedules @@ -1012,6 +1013,329 @@ def denoise_dev_av( return video_latents, audio_latents +def denoise_res2s_av( + video_latents: mx.array, + audio_latents: mx.array, + video_positions: mx.array, + audio_positions: mx.array, + video_embeddings_pos: mx.array, + video_embeddings_neg: mx.array, + audio_embeddings_pos: mx.array, + audio_embeddings_neg: mx.array, + transformer: LTXModel, + sigmas: mx.array, + cfg_scale: float = 3.0, + audio_cfg_scale: float = 7.0, + cfg_rescale: float = 0.45, + audio_cfg_rescale: Optional[float] = None, + verbose: bool = True, + video_state: Optional[LatentState] = None, + stg_scale: float = 0.0, + stg_video_blocks: Optional[list] = None, + stg_audio_blocks: Optional[list] = None, + modality_scale: float = 1.0, + noise_seed: int = 42, + bongmath: bool = True, + bongmath_max_iter: int = 100, +) -> tuple[mx.array, mx.array]: + """Run res_2s second-order denoising loop with CFG/STG/modality guidance. + + Two model evaluations per step (current point + midpoint), with SDE noise + injection and optional bong iteration for anchor refinement. + + Args: + audio_cfg_rescale: Separate rescale for audio. If None, uses cfg_rescale. + noise_seed: Seed for SDE noise generators. + bongmath: Enable iterative anchor refinement. + bongmath_max_iter: Max bong iterations per step. + """ + from mlx_video.models.ltx.rope import precompute_freqs_cis + from mlx_video.samplers import get_res2s_coefficients, sde_noise_step, get_new_noise + + if audio_cfg_rescale is None: + audio_cfg_rescale = cfg_rescale + + dtype = video_latents.dtype + if video_state is not None: + video_latents = video_state.latent + + video_latents = video_latents.astype(mx.float32) + audio_latents = audio_latents.astype(mx.float32) + + sigmas_list = sigmas.tolist() + use_cfg = cfg_scale != 1.0 + use_stg = stg_scale != 0.0 and stg_video_blocks is not None + use_modality = modality_scale != 1.0 + n_full_steps = len(sigmas_list) - 1 + + # Pad sigmas if last is 0 (avoid division by zero in RK steps) + if sigmas_list[-1] == 0: + sigmas_list = sigmas_list[:-1] + [0.0011, 0.0] + + # Compute step sizes in log-space for the main loop steps only. + # After padding, sigmas_list may have an extra [0.0011, 0.0] tail; + # we only need hs for the n_full_steps pairs the loop actually uses. + hs = [-math.log(sigmas_list[i + 1] / sigmas_list[i]) for i in range(n_full_steps)] + + # Precompute RoPE + precomputed_video_rope = precompute_freqs_cis( + video_positions, + dim=transformer.inner_dim, + theta=transformer.positional_embedding_theta, + max_pos=transformer.positional_embedding_max_pos, + use_middle_indices_grid=transformer.use_middle_indices_grid, + num_attention_heads=transformer.num_attention_heads, + rope_type=transformer.rope_type, + double_precision=transformer.config.double_precision_rope, + ) + precomputed_audio_rope = precompute_freqs_cis( + audio_positions, + dim=transformer.audio_inner_dim, + theta=transformer.positional_embedding_theta, + max_pos=transformer.audio_positional_embedding_max_pos, + use_middle_indices_grid=transformer.use_middle_indices_grid, + num_attention_heads=transformer.audio_num_attention_heads, + rope_type=transformer.rope_type, + double_precision=transformer.config.double_precision_rope, + ) + mx.eval(precomputed_video_rope, precomputed_audio_rope) + + phi_cache = {} + c2 = 0.5 + + # Noise key management: step noise and substep noise use different keys + step_noise_key = mx.random.key(noise_seed) + substep_noise_key = mx.random.key(noise_seed + 10000) + + def _eval_guided_denoise(v_latents, a_latents, sigma): + """Run all guidance passes and return (video_denoised, audio_denoised) in float32 spatial format.""" + b, c, f, h, w = v_latents.shape + num_video_tokens = f * h * w + video_flat = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)).astype(dtype) + + ab, ac, at, af = a_latents.shape + audio_flat = mx.transpose(a_latents, (0, 2, 1, 3)) + audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype) + + # Timesteps + if video_state is not None: + denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1)) + denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) + denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) + video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat + else: + video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) + audio_timesteps = mx.full((ab, at), sigma, dtype=dtype) + + sigma_array = mx.full((b,), sigma, dtype=dtype) + audio_sigma_array = mx.full((ab,), sigma, dtype=dtype) + + # Pass 1: Positive conditioning + video_modality_pos = Modality( + latent=video_flat, timesteps=video_timesteps, positions=video_positions, + context=video_embeddings_pos, context_mask=None, enabled=True, + positional_embeddings=precomputed_video_rope, sigma=sigma_array, + ) + audio_modality_pos = Modality( + latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, + context=audio_embeddings_pos, context_mask=None, enabled=True, + positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array, + ) + video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos) + mx.eval(video_vel_pos, audio_vel_pos) + + # Convert velocity to x0 + video_flat_f32 = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)) + audio_flat_f32 = mx.reshape(mx.transpose(a_latents, (0, 2, 1, 3)), (ab, at, ac * af)) + video_ts_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1) + audio_ts_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1) + + video_x0_pos = video_flat_f32 - video_ts_f32 * video_vel_pos.astype(mx.float32) + audio_x0_pos = audio_flat_f32 - audio_ts_f32 * audio_vel_pos.astype(mx.float32) + + video_x0_guided = video_x0_pos + audio_x0_guided = audio_x0_pos + + # Pass 2: CFG + if use_cfg: + video_modality_neg = Modality( + latent=video_flat, timesteps=video_timesteps, positions=video_positions, + context=video_embeddings_neg, context_mask=None, enabled=True, + positional_embeddings=precomputed_video_rope, sigma=sigma_array, + ) + audio_modality_neg = Modality( + latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions, + context=audio_embeddings_neg, context_mask=None, enabled=True, + positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array, + ) + video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg) + mx.eval(video_vel_neg, audio_vel_neg) + + video_x0_neg = video_flat_f32 - video_ts_f32 * video_vel_neg.astype(mx.float32) + audio_x0_neg = audio_flat_f32 - audio_ts_f32 * audio_vel_neg.astype(mx.float32) + + video_x0_guided = video_x0_pos + (cfg_scale - 1.0) * (video_x0_pos - video_x0_neg) + audio_x0_guided = audio_x0_pos + (audio_cfg_scale - 1.0) * (audio_x0_pos - audio_x0_neg) + + # Pass 3: STG + if use_stg: + video_vel_ptb, audio_vel_ptb = transformer( + video=video_modality_pos, audio=audio_modality_pos, + stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks, + ) + mx.eval(video_vel_ptb, audio_vel_ptb) + + video_x0_ptb = video_flat_f32 - video_ts_f32 * video_vel_ptb.astype(mx.float32) + audio_x0_ptb = audio_flat_f32 - audio_ts_f32 * audio_vel_ptb.astype(mx.float32) + + video_x0_guided = video_x0_guided + stg_scale * (video_x0_pos - video_x0_ptb) + audio_x0_guided = audio_x0_guided + stg_scale * (audio_x0_pos - audio_x0_ptb) + + # Pass 4: Modality isolation + if use_modality: + video_vel_iso, audio_vel_iso = transformer( + video=video_modality_pos, audio=audio_modality_pos, + skip_cross_modal=True, + ) + mx.eval(video_vel_iso, audio_vel_iso) + + video_x0_iso = video_flat_f32 - video_ts_f32 * video_vel_iso.astype(mx.float32) + audio_x0_iso = audio_flat_f32 - audio_ts_f32 * audio_vel_iso.astype(mx.float32) + + video_x0_guided = video_x0_guided + (modality_scale - 1.0) * (video_x0_pos - video_x0_iso) + audio_x0_guided = audio_x0_guided + (modality_scale - 1.0) * (audio_x0_pos - audio_x0_iso) + + # Rescale (separate factors for video and audio) + if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality): + v_factor = video_x0_pos.std() / (video_x0_guided.std() + 1e-8) + v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) + video_x0_guided = video_x0_guided * v_factor + if audio_cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality): + a_factor = audio_x0_pos.std() / (audio_x0_guided.std() + 1e-8) + a_factor = audio_cfg_rescale * a_factor + (1.0 - audio_cfg_rescale) + audio_x0_guided = audio_x0_guided * a_factor + + # Reshape to spatial + video_denoised = mx.reshape(mx.transpose(video_x0_guided, (0, 2, 1)), (b, c, f, h, w)) + audio_denoised = mx.reshape(audio_x0_guided, (ab, at, ac, af)) + audio_denoised = mx.transpose(audio_denoised, (0, 2, 1, 3)) + + # Post-process with mask + if video_state is not None: + clean_f32 = video_state.clean_latent.astype(mx.float32) + mask_f32 = video_state.denoise_mask.astype(mx.float32) + video_denoised = video_denoised * mask_f32 + clean_f32 * (1.0 - mask_f32) + + mx.eval(video_denoised, audio_denoised) + return video_denoised, audio_denoised + + # Main res_2s loop + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(), + console=console, + disable=not verbose, + ) as progress: + passes = ["res2s"] + if use_cfg: passes.append("CFG") + if use_stg: passes.append("STG") + if use_modality: passes.append("Mod") + label = "+".join(passes) + task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=n_full_steps) + + for step_idx in range(n_full_steps): + sigma = sigmas_list[step_idx] + sigma_next = sigmas_list[step_idx + 1] + h = hs[step_idx] + + # Initialize anchor + x_anchor_video = video_latents + x_anchor_audio = audio_latents + + # ============================================================ + # Stage 1: Evaluate denoiser at current sigma + # ============================================================ + denoised_video_1, denoised_audio_1 = _eval_guided_denoise( + video_latents, audio_latents, sigma + ) + + # RK coefficients + a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2) + + # Substep sigma (geometric midpoint for c2=0.5) + sub_sigma = math.sqrt(sigma * sigma_next) + + # Compute midpoint + eps_1_video = denoised_video_1 - x_anchor_video + eps_1_audio = denoised_audio_1 - x_anchor_audio + + x_mid_video = x_anchor_video + h * a21 * eps_1_video + x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio + + # SDE noise injection at substep + substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3) + substep_noise_v = get_new_noise(video_latents.shape, key1) + substep_noise_a = get_new_noise(audio_latents.shape, key2) + + x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v) + x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a) + mx.eval(x_mid_video, x_mid_audio) + + # ============================================================ + # Bong iteration: refine anchor (pure arithmetic, no model calls) + # ============================================================ + if bongmath and h < 0.5 and sigma > 0.03: + for _ in range(bongmath_max_iter): + x_anchor_video = x_mid_video - h * a21 * eps_1_video + eps_1_video = denoised_video_1 - x_anchor_video + x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio + eps_1_audio = denoised_audio_1 - x_anchor_audio + mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio) + + # ============================================================ + # Stage 2: Evaluate denoiser at midpoint sigma + # ============================================================ + denoised_video_2, denoised_audio_2 = _eval_guided_denoise( + x_mid_video.astype(mx.float32), x_mid_audio.astype(mx.float32), sub_sigma + ) + + # ============================================================ + # Final combination with RK coefficients + # ============================================================ + eps_2_video = denoised_video_2 - x_anchor_video + eps_2_audio = denoised_audio_2 - x_anchor_audio + + x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video) + x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio) + + # SDE noise injection at step level + step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3) + step_noise_v = get_new_noise(video_latents.shape, key1) + step_noise_a = get_new_noise(audio_latents.shape, key2) + + x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v) + x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a) + + video_latents = x_next_video.astype(mx.float32) + audio_latents = x_next_audio.astype(mx.float32) + mx.eval(video_latents, audio_latents) + progress.advance(task) + + # Final clean step if original schedule ended at 0 + if sigmas.tolist()[-1] == 0: + denoised_video, denoised_audio = _eval_guided_denoise( + video_latents, audio_latents, sigmas_list[n_full_steps] + ) + video_latents = denoised_video + audio_latents = denoised_audio + mx.eval(video_latents, audio_latents) + + return video_latents, audio_latents + + # ============================================================================= # Audio Loading and Processing # ============================================================================= @@ -1117,13 +1441,16 @@ def generate_video( modality_scale: float = 1.0, lora_path: Optional[str] = None, lora_strength: float = 1.0, + lora_strength_stage_1: Optional[float] = None, + lora_strength_stage_2: Optional[float] = None, ): """Generate video using LTX-2 models. - Supports three pipelines: + Supports four pipelines: - DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG - DEV: Single-stage generation with dynamic sigmas and CFG - DEV_TWO_STAGE: Stage 1 dev (half res, CFG) + upsample + stage 2 distilled with LoRA (full res, no CFG) + - DEV_TWO_STAGE_HQ: res_2s sampler, LoRA both stages (0.25/0.5), lower rescale Args: model_repo: Model repository ID @@ -1158,7 +1485,7 @@ def generate_video( start_time = time.time() # Validate dimensions - is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE) + is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ) divisor = 64 if is_two_stage else 32 assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}" assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}" @@ -1177,13 +1504,14 @@ def generate_video( PipelineType.DISTILLED: "DISTILLED", PipelineType.DEV: "DEV", PipelineType.DEV_TWO_STAGE: "DEV-TWO-STAGE", + PipelineType.DEV_TWO_STAGE_HQ: "DEV-TWO-STAGE-HQ", } pipeline_name = pipeline_names[pipeline] header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height} • {num_frames} frames[/]" console.print(Panel(header, expand=False)) console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]") - if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): + if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ): audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else "" stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else "" mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else "" @@ -1237,14 +1565,14 @@ def generate_video( # Encode prompts - always get audio embeddings since the model was trained # with joint audio-video processing (PyTorch unconditionally generates audio) - if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE): + if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ): # Dev/dev-two-stage pipelines need positive and negative embeddings for CFG video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True) video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True) model_dtype = video_embeddings_pos.dtype mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg) # For dev-two-stage, stage 2 uses single positive embedding (no CFG) - if pipeline == PipelineType.DEV_TWO_STAGE: + if pipeline in (PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ): text_embeddings = video_embeddings_pos else: # Distilled pipeline - single embedding @@ -1655,6 +1983,190 @@ def generate_video( audio_embeddings=audio_embeddings_pos, ) + elif pipeline == PipelineType.DEV_TWO_STAGE_HQ: + # ====================================================================== + # DEV TWO-STAGE HQ PIPELINE: + # Stage 1: res_2s denoising at half resolution with CFG + LoRA@0.25 + # Upsample: 2x spatial via LatentUpsampler + # Stage 2: res_2s refinement at full resolution with LoRA@0.5, no CFG + # ====================================================================== + + # HQ defaults + hq_lora_strength_s1 = lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25 + hq_lora_strength_s2 = lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5 + hq_cfg_rescale = cfg_rescale if cfg_rescale != 0.7 else 0.45 # Override default 0.7 → 0.45 + hq_steps = num_inference_steps if num_inference_steps != 30 else 15 # Override default 30 → 15 + + # Load VAE encoder for I2V + stage1_image_latent = None + stage2_image_latent = None + if is_i2v: + with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"): + vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder") + + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) + stage1_image_latent = vae_encoder(stage1_image_tensor) + mx.eval(stage1_image_latent) + + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) + stage2_image_latent = vae_encoder(stage2_image_tensor) + mx.eval(stage2_image_latent) + + del vae_encoder + mx.clear_cache() + console.print("[green]✓[/] VAE encoder loaded and image encoded") + + # Auto-detect and merge LoRA for stage 1 (strength 0.25) + if lora_path is None: + lora_files = sorted(model_path.glob("*distilled-lora*.safetensors")) + if lora_files: + lora_path = str(lora_files[0]) + console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]") + else: + console.print("[yellow]Warning: No LoRA file found. HQ pipeline works best with distilled LoRA.[/]") + + if lora_path is not None: + with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"): + load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1) + + # Stage 1: res_2s denoising at half resolution with CFG + # HQ passes actual token count to scheduler (unlike regular dev-two-stage) + num_tokens = latent_frames * stage1_h * stage1_w + sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens) + mx.eval(sigmas) + console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]") + + console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})") + mx.random.seed(seed) + + positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) + mx.eval(positions) + + audio_positions = create_audio_position_grid(1, audio_frames) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype) + mx.eval(audio_positions, audio_latents) + + # Apply I2V conditioning for stage 1 + state1 = None + stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w) + if is_i2v and stage1_image_latent is not None: + state1 = LatentState( + latent=mx.zeros(stage1_shape, dtype=model_dtype), + clean_latent=mx.zeros(stage1_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength) + state1 = apply_conditioning(state1, [conditioning]) + + noise = mx.random.normal(stage1_shape, dtype=model_dtype) + noise_scale = sigmas[0] + scaled_mask = state1.denoise_mask * noise_scale + state1 = LatentState( + latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=state1.clean_latent, + denoise_mask=state1.denoise_mask, + ) + latents = state1.latent + mx.eval(latents) + else: + latents = mx.random.normal(stage1_shape, dtype=model_dtype) + mx.eval(latents) + + # Stage 1: res_2s with CFG (STG disabled for HQ by default) + latents, audio_latents = denoise_res2s_av( + latents, audio_latents, + positions, audio_positions, + video_embeddings_pos, video_embeddings_neg, + audio_embeddings_pos, audio_embeddings_neg, + transformer, sigmas, cfg_scale=cfg_scale, + audio_cfg_scale=audio_cfg_scale, + cfg_rescale=hq_cfg_rescale, audio_cfg_rescale=1.0, + verbose=verbose, video_state=state1, + stg_scale=stg_scale, stg_video_blocks=stg_blocks, + stg_audio_blocks=stg_blocks, modality_scale=modality_scale, + noise_seed=seed, + ) + + mx.eval(audio_latents) + + # Upsample latents 2x + with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"): + upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors")) + if not upscaler_files: + raise FileNotFoundError(f"No spatial upscaler found in {model_path}") + upsampler = load_upsampler(str(upscaler_files[0])) + mx.eval(upsampler.parameters()) + + vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder")) + + latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std) + mx.eval(latents) + + del upsampler + mx.clear_cache() + console.print("[green]✓[/] Latents upsampled") + + # Merge additional LoRA for stage 2 (additive: 0.25 + 0.25 = 0.5 total) + if lora_path is not None: + additional_strength = hq_lora_strength_s2 - hq_lora_strength_s1 + if additional_strength > 0: + with console.status(f"[blue]Adjusting LoRA (stage 2, total={hq_lora_strength_s2})...[/]", spinner="dots"): + load_and_merge_lora(transformer, lora_path, strength=additional_strength) + + # Stage 2: res_2s refinement at full resolution (no CFG) + console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)") + positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) + mx.eval(positions) + + state2 = None + if is_i2v and stage2_image_latent is not None: + state2 = LatentState( + latent=latents, + clean_latent=mx.zeros_like(latents), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), + ) + conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength) + state2 = apply_conditioning(state2, [conditioning]) + + noise = mx.random.normal(latents.shape).astype(model_dtype) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + scaled_mask = state2.denoise_mask * noise_scale + state2 = LatentState( + latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), + clean_latent=state2.clean_latent, + denoise_mask=state2.denoise_mask, + ) + latents = state2.latent + mx.eval(latents) + else: + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) + noise = mx.random.normal(latents.shape).astype(model_dtype) + latents = noise * noise_scale + latents * one_minus_scale + mx.eval(latents) + + # Re-noise audio at sigma=0.909375 for joint refinement + if audio_latents is not None: + audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype) + audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale) + mx.eval(audio_latents) + + # Stage 2: res_2s with no CFG (positive embeddings only) + stage2_sigmas = mx.array(STAGE_2_SIGMAS, dtype=mx.float32) + latents, audio_latents = denoise_res2s_av( + latents, audio_latents, + positions, audio_positions, + video_embeddings_pos, video_embeddings_pos, # both pos (no neg for stage 2) + audio_embeddings_pos, audio_embeddings_pos, + transformer, stage2_sigmas, cfg_scale=1.0, # no CFG + audio_cfg_scale=1.0, + cfg_rescale=0.0, verbose=verbose, video_state=state2, + noise_seed=seed + 1, + ) + del transformer mx.clear_cache() @@ -1857,8 +2369,8 @@ Examples: ) parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate") - parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage"], - help="Pipeline type: distilled (two-stage, fast), dev (single-stage, CFG), or dev-two-stage (dev + LoRA refinement)") + parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage", "dev-two-stage-hq"], + help="Pipeline type: distilled (fast), dev (CFG), dev-two-stage (dev + LoRA), dev-two-stage-hq (res_2s + LoRA both stages)") parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT, help="Negative prompt for CFG (dev pipeline only)") parser.add_argument("--height", "-H", type=int, default=512, help="Output video height") @@ -1895,12 +2407,15 @@ Examples: parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)") parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)") parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)") + parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)") + parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)") args = parser.parse_args() pipeline_map = { "distilled": PipelineType.DISTILLED, "dev": PipelineType.DEV, "dev-two-stage": PipelineType.DEV_TWO_STAGE, + "dev-two-stage-hq": PipelineType.DEV_TWO_STAGE_HQ, } pipeline = pipeline_map[args.pipeline] @@ -1940,6 +2455,8 @@ Examples: modality_scale=args.modality_scale, lora_path=args.lora_path, lora_strength=args.lora_strength, + lora_strength_stage_1=args.lora_strength_stage_1, + lora_strength_stage_2=args.lora_strength_stage_2, ) diff --git a/mlx_video/samplers.py b/mlx_video/samplers.py new file mode 100644 index 0000000..489780b --- /dev/null +++ b/mlx_video/samplers.py @@ -0,0 +1,181 @@ +"""Second-order res_2s sampler for diffusion models. + +Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE +noise injection, ported from the LTX-2 PyTorch implementation. +""" + +import math +from typing import Optional + +import mlx.core as mx + + +# --------------------------------------------------------------------------- +# Phi functions and RK coefficients (pure Python math, no MLX needed) +# --------------------------------------------------------------------------- + +def phi(j: int, neg_h: float) -> float: + """Compute phi_j(z) where z = -h (negative step size in log-space). + + phi_1(z) = (e^z - 1) / z + phi_2(z) = (e^z - 1 - z) / z^2 + phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j + """ + if abs(neg_h) < 1e-10: + return 1.0 / math.factorial(j) + + remainder = sum(neg_h**k / math.factorial(k) for k in range(j)) + return (math.exp(neg_h) - remainder) / (neg_h**j) + + +def get_res2s_coefficients( + h: float, + phi_cache: dict, + c2: float = 0.5, +) -> tuple[float, float, float]: + """Compute res_2s Runge-Kutta coefficients for a given step size. + + Args: + h: Step size in log-space = log(sigma / sigma_next) + phi_cache: Dictionary to cache phi function results. + c2: Substep position (default 0.5 = midpoint) + + Returns: + (a21, b1, b2): RK coefficients. + """ + def get_phi(j: int, neg_h: float) -> float: + cache_key = (j, neg_h) + if cache_key in phi_cache: + return phi_cache[cache_key] + result = phi(j, neg_h) + phi_cache[cache_key] = result + return result + + neg_h_c2 = -h * c2 + phi_1_c2 = get_phi(1, neg_h_c2) + a21 = c2 * phi_1_c2 + + neg_h_full = -h + phi_2_full = get_phi(2, neg_h_full) + b2 = phi_2_full / c2 + + phi_1_full = get_phi(1, neg_h_full) + b1 = phi_1_full - b2 + + return a21, b1, b2 + + +# --------------------------------------------------------------------------- +# SDE noise injection +# --------------------------------------------------------------------------- + +def get_sde_coeff( + sigma_next: float, +) -> tuple[float, float, float]: + """Compute SDE coefficients for variance-preserving noise injection. + + Uses sigma_up = sigma_next * 0.5 (hardcoded in PyTorch Res2sDiffusionStep). + + Returns: + (alpha_ratio, sigma_down, sigma_up) + """ + sigma_up = sigma_next * 0.5 + # Clamp sigma_up to avoid sqrt(negative) + sigma_up = min(sigma_up, sigma_next * 0.9999) + + sigma_signal = 1.0 - sigma_next # sigma_max=1 + sigma_residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0)) + alpha_ratio = sigma_signal + sigma_residual + + if alpha_ratio == 0: + sigma_down = sigma_next + else: + sigma_down = sigma_residual / alpha_ratio + + # Handle NaN edge cases + if math.isnan(sigma_up): + sigma_up = 0.0 + if math.isnan(sigma_down): + sigma_down = sigma_next + if math.isnan(alpha_ratio): + alpha_ratio = 1.0 + + return alpha_ratio, sigma_down, sigma_up + + +def sde_noise_step( + sample: mx.array, + denoised_sample: mx.array, + sigma: float, + sigma_next: float, + noise: mx.array, +) -> mx.array: + """Apply SDE noise injection step. + + Advances sample from sigma to sigma_next with stochastic noise injection. + + Args: + sample: Current sample (anchor point) + denoised_sample: Denoised prediction at this step + sigma: Current noise level + sigma_next: Next noise level + noise: Pre-generated noise tensor (channel-wise normalized) + + Returns: + Noised sample at sigma_next + """ + alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next) + + if sigma_up == 0 or sigma_next == 0: + return denoised_sample + + # Float32 arithmetic + sample_f32 = sample.astype(mx.float32) + denoised_f32 = denoised_sample.astype(mx.float32) + noise_f32 = noise.astype(mx.float32) + + # Extract epsilon prediction + eps_next = (sample_f32 - denoised_f32) / (sigma - sigma_next) + denoised_next = sample_f32 - sigma * eps_next + + # Mix deterministic and stochastic components + x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32 + + return x_noised + + +# --------------------------------------------------------------------------- +# Noise generation +# --------------------------------------------------------------------------- + +def channelwise_normalize(x: mx.array) -> mx.array: + """Normalize each channel to zero mean and unit variance over spatial dims. + + Operates on the last 2 dimensions (spatial H, W or time, freq). + """ + mean = mx.mean(x, axis=(-2, -1), keepdims=True) + x = x - mean + std = mx.sqrt(mx.mean(x * x, axis=(-2, -1), keepdims=True) + 1e-8) + x = x / std + return x + + +def get_new_noise(shape: tuple, key: mx.array) -> mx.array: + """Generate channel-wise normalized Gaussian noise. + + PyTorch uses float64; we use float32 (MLX doesn't support float64). + The channel-wise normalization is the key quality-affecting step. + + Args: + shape: Shape of the noise tensor + key: MLX random key for deterministic generation + + Returns: + Channel-wise normalized noise in float32 + """ + noise = mx.random.normal(shape, dtype=mx.float32, key=key) + # Global normalization + noise = (noise - mx.mean(noise)) / (mx.sqrt(mx.mean(noise * noise)) + 1e-8) + # Channel-wise normalization + noise = channelwise_normalize(noise) + return noise