Add Dev Two-Stage HQ pipeline mode

This commit is contained in:
Prince Canuma
2026-03-16 00:34:13 +01:00
parent df81bc852f
commit f53b9e0807
3 changed files with 739 additions and 12 deletions

View File

@@ -24,7 +24,7 @@ Supported models:
## Features
- Text-to-video (T2V) and Image-to-video (I2V) generation
- Three pipeline modes: Distilled, Dev, and Dev Two-Stage
- Four pipeline modes: Distilled, Dev, Dev Two-Stage, and Dev Two-Stage HQ
- Synchronized audio-video generation (experimental)
- LoRA support (including HuggingFace repos)
- Prompt enhancement via Gemma
@@ -35,13 +35,14 @@ Supported models:
### Pipelines
mlx-video supports three pipeline types via the `--pipeline` flag:
mlx-video supports four pipeline types via the `--pipeline` flag:
| Pipeline | Description | CFG | Stages | Speed |
|----------|-------------|-----|--------|-------|
| `distilled` (default) | Fixed sigma schedule, no CFG | No | 2 (8+3 steps) | Fastest |
| `dev` | Dynamic sigmas, constant CFG | Yes | 1 (30 steps) | Medium |
| `dev-two-stage` | Dev + LoRA refinement | Yes (stage 1) | 2 (30+3 steps) | Slowest, highest quality |
| `dev-two-stage` | Dev + LoRA refinement | Yes (stage 1) | 2 (30+3 steps) | Slow |
| `dev-two-stage-hq` | res_2s sampler + LoRA both stages | Yes (stage 1) | 2 (15+3 steps) | Slow, highest quality |
### Text-to-Video
@@ -52,13 +53,24 @@ uv run mlx_video.generate --prompt "Two dogs wearing sunglasses, cinematic, suns
# Dev - single-stage with CFG
uv run mlx_video.generate --pipeline dev --prompt "A cinematic scene" --cfg-scale 3.0
# Dev two-stage - dev + LoRA refinement (highest quality)
# Dev two-stage - dev + LoRA refinement
uv run mlx_video.generate --pipeline dev-two-stage \
--prompt "Two dogs of the poodle breed wearing sunglasses, close up, cinematic, sunset" \
-n 145 --width 1024 --height 768 \
--model-repo prince-canuma/LTX-2-dev \
--cfg-scale 3.0 --lora-strength 0.8 \
--enhance-prompt
# Dev two-stage HQ - res_2s sampler, LoRA both stages (highest quality)
uv run mlx_video.generate --pipeline dev-two-stage-hq \
--prompt "A cinematic scene of ocean waves at golden hour" \
--model-repo prince-canuma/LTX-2-dev
# HQ with custom LoRA strengths
uv run mlx_video.generate --pipeline dev-two-stage-hq \
--prompt "A sunset over mountains" \
--model-repo prince-canuma/LTX-2-dev \
--lora-strength-stage-1 0.3 --lora-strength-stage-2 0.6
```
<img src="https://github.com/Blaizzy/mlx-video/raw/main/examples/poodles.gif" width="512" alt="Poodles demo">
@@ -124,7 +136,7 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
| Option | Default | Description |
|--------|---------|-------------|
| `--prompt`, `-p` | (required) | Text description of the video |
| `--pipeline` | `distilled` | Pipeline type: `distilled`, `dev`, or `dev-two-stage` |
| `--pipeline` | `distilled` | Pipeline type: `distilled`, `dev`, `dev-two-stage`, or `dev-two-stage-hq` |
| `--height`, `-H` | 512 | Output height (divisible by 64 for two-stage, 32 for dev) |
| `--width`, `-W` | 512 | Output width (divisible by 64 for two-stage, 32 for dev) |
| `--num-frames`, `-n` | 33 | Number of frames (must be 1 + 8*k) |
@@ -161,6 +173,15 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
| `--lora-path` | auto-detect | Path to LoRA file, directory, or HuggingFace repo |
| `--lora-strength` | 1.0 | LoRA merge strength |
**Dev-Two-Stage HQ options:**
| Option | Default | Description |
|--------|---------|-------------|
| `--lora-strength-stage-1` | 0.25 | LoRA strength for stage 1 |
| `--lora-strength-stage-2` | 0.5 | LoRA strength for stage 2 |
HQ defaults: 15 steps (vs 30), `cfg-rescale` 0.45 (vs 0.7), STG disabled. Uses the res_2s second-order sampler (2 model evals per step) for better quality at the same compute budget.
## How It Works
### Distilled Pipeline (default)
@@ -179,6 +200,14 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
3. **Stage 2**: Distilled refinement at full resolution with LoRA weights (3 steps, no CFG)
4. **Decode**: VAE decoder converts latents to RGB video
### Dev Two-Stage HQ Pipeline
1. **Stage 1**: res_2s denoising at half resolution with CFG + LoRA@0.25 (15 steps, 2 evals/step)
2. **Upsample**: 2x spatial upsampling via LatentUpsampler
3. **Stage 2**: res_2s refinement at full resolution with LoRA@0.5 (3 steps, no CFG)
4. **Decode**: VAE decoder converts latents to RGB video
The res_2s sampler uses an exponential Rosenbrock-type Runge-Kutta integrator with SDE noise injection, producing higher quality results than Euler at the same compute budget (~30 total model evaluations).
## Requirements
- macOS with Apple Silicon

View File

@@ -38,6 +38,7 @@ class PipelineType(Enum):
DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG
DEV = "dev" # Single-stage, dynamic sigmas, CFG
DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res)
DEV_TWO_STAGE_HQ = "dev-two-stage-hq" # Two-stage: res_2s sampler, LoRA both stages
# Distilled model sigma schedules
@@ -1012,6 +1013,329 @@ def denoise_dev_av(
return video_latents, audio_latents
def denoise_res2s_av(
video_latents: mx.array,
audio_latents: mx.array,
video_positions: mx.array,
audio_positions: mx.array,
video_embeddings_pos: mx.array,
video_embeddings_neg: mx.array,
audio_embeddings_pos: mx.array,
audio_embeddings_neg: mx.array,
transformer: LTXModel,
sigmas: mx.array,
cfg_scale: float = 3.0,
audio_cfg_scale: float = 7.0,
cfg_rescale: float = 0.45,
audio_cfg_rescale: Optional[float] = None,
verbose: bool = True,
video_state: Optional[LatentState] = None,
stg_scale: float = 0.0,
stg_video_blocks: Optional[list] = None,
stg_audio_blocks: Optional[list] = None,
modality_scale: float = 1.0,
noise_seed: int = 42,
bongmath: bool = True,
bongmath_max_iter: int = 100,
) -> tuple[mx.array, mx.array]:
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
Two model evaluations per step (current point + midpoint), with SDE noise
injection and optional bong iteration for anchor refinement.
Args:
audio_cfg_rescale: Separate rescale for audio. If None, uses cfg_rescale.
noise_seed: Seed for SDE noise generators.
bongmath: Enable iterative anchor refinement.
bongmath_max_iter: Max bong iterations per step.
"""
from mlx_video.models.ltx.rope import precompute_freqs_cis
from mlx_video.samplers import get_res2s_coefficients, sde_noise_step, get_new_noise
if audio_cfg_rescale is None:
audio_cfg_rescale = cfg_rescale
dtype = video_latents.dtype
if video_state is not None:
video_latents = video_state.latent
video_latents = video_latents.astype(mx.float32)
audio_latents = audio_latents.astype(mx.float32)
sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0
use_stg = stg_scale != 0.0 and stg_video_blocks is not None
use_modality = modality_scale != 1.0
n_full_steps = len(sigmas_list) - 1
# Pad sigmas if last is 0 (avoid division by zero in RK steps)
if sigmas_list[-1] == 0:
sigmas_list = sigmas_list[:-1] + [0.0011, 0.0]
# Compute step sizes in log-space for the main loop steps only.
# After padding, sigmas_list may have an extra [0.0011, 0.0] tail;
# we only need hs for the n_full_steps pairs the loop actually uses.
hs = [-math.log(sigmas_list[i + 1] / sigmas_list[i]) for i in range(n_full_steps)]
# Precompute RoPE
precomputed_video_rope = precompute_freqs_cis(
video_positions,
dim=transformer.inner_dim,
theta=transformer.positional_embedding_theta,
max_pos=transformer.positional_embedding_max_pos,
use_middle_indices_grid=transformer.use_middle_indices_grid,
num_attention_heads=transformer.num_attention_heads,
rope_type=transformer.rope_type,
double_precision=transformer.config.double_precision_rope,
)
precomputed_audio_rope = precompute_freqs_cis(
audio_positions,
dim=transformer.audio_inner_dim,
theta=transformer.positional_embedding_theta,
max_pos=transformer.audio_positional_embedding_max_pos,
use_middle_indices_grid=transformer.use_middle_indices_grid,
num_attention_heads=transformer.audio_num_attention_heads,
rope_type=transformer.rope_type,
double_precision=transformer.config.double_precision_rope,
)
mx.eval(precomputed_video_rope, precomputed_audio_rope)
phi_cache = {}
c2 = 0.5
# Noise key management: step noise and substep noise use different keys
step_noise_key = mx.random.key(noise_seed)
substep_noise_key = mx.random.key(noise_seed + 10000)
def _eval_guided_denoise(v_latents, a_latents, sigma):
"""Run all guidance passes and return (video_denoised, audio_denoised) in float32 spatial format."""
b, c, f, h, w = v_latents.shape
num_video_tokens = f * h * w
video_flat = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
ab, ac, at, af = a_latents.shape
audio_flat = mx.transpose(a_latents, (0, 2, 1, 3))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
# Timesteps
if video_state is not None:
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
sigma_array = mx.full((b,), sigma, dtype=dtype)
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
# Pass 1: Positive conditioning
video_modality_pos = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_pos, context_mask=None, enabled=True,
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
)
audio_modality_pos = Modality(
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
context=audio_embeddings_pos, context_mask=None, enabled=True,
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
)
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
mx.eval(video_vel_pos, audio_vel_pos)
# Convert velocity to x0
video_flat_f32 = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1))
audio_flat_f32 = mx.reshape(mx.transpose(a_latents, (0, 2, 1, 3)), (ab, at, ac * af))
video_ts_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
audio_ts_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1)
video_x0_pos = video_flat_f32 - video_ts_f32 * video_vel_pos.astype(mx.float32)
audio_x0_pos = audio_flat_f32 - audio_ts_f32 * audio_vel_pos.astype(mx.float32)
video_x0_guided = video_x0_pos
audio_x0_guided = audio_x0_pos
# Pass 2: CFG
if use_cfg:
video_modality_neg = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_neg, context_mask=None, enabled=True,
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
)
audio_modality_neg = Modality(
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
context=audio_embeddings_neg, context_mask=None, enabled=True,
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
)
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
mx.eval(video_vel_neg, audio_vel_neg)
video_x0_neg = video_flat_f32 - video_ts_f32 * video_vel_neg.astype(mx.float32)
audio_x0_neg = audio_flat_f32 - audio_ts_f32 * audio_vel_neg.astype(mx.float32)
video_x0_guided = video_x0_pos + (cfg_scale - 1.0) * (video_x0_pos - video_x0_neg)
audio_x0_guided = audio_x0_pos + (audio_cfg_scale - 1.0) * (audio_x0_pos - audio_x0_neg)
# Pass 3: STG
if use_stg:
video_vel_ptb, audio_vel_ptb = transformer(
video=video_modality_pos, audio=audio_modality_pos,
stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks,
)
mx.eval(video_vel_ptb, audio_vel_ptb)
video_x0_ptb = video_flat_f32 - video_ts_f32 * video_vel_ptb.astype(mx.float32)
audio_x0_ptb = audio_flat_f32 - audio_ts_f32 * audio_vel_ptb.astype(mx.float32)
video_x0_guided = video_x0_guided + stg_scale * (video_x0_pos - video_x0_ptb)
audio_x0_guided = audio_x0_guided + stg_scale * (audio_x0_pos - audio_x0_ptb)
# Pass 4: Modality isolation
if use_modality:
video_vel_iso, audio_vel_iso = transformer(
video=video_modality_pos, audio=audio_modality_pos,
skip_cross_modal=True,
)
mx.eval(video_vel_iso, audio_vel_iso)
video_x0_iso = video_flat_f32 - video_ts_f32 * video_vel_iso.astype(mx.float32)
audio_x0_iso = audio_flat_f32 - audio_ts_f32 * audio_vel_iso.astype(mx.float32)
video_x0_guided = video_x0_guided + (modality_scale - 1.0) * (video_x0_pos - video_x0_iso)
audio_x0_guided = audio_x0_guided + (modality_scale - 1.0) * (audio_x0_pos - audio_x0_iso)
# Rescale (separate factors for video and audio)
if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
v_factor = video_x0_pos.std() / (video_x0_guided.std() + 1e-8)
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
video_x0_guided = video_x0_guided * v_factor
if audio_cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
a_factor = audio_x0_pos.std() / (audio_x0_guided.std() + 1e-8)
a_factor = audio_cfg_rescale * a_factor + (1.0 - audio_cfg_rescale)
audio_x0_guided = audio_x0_guided * a_factor
# Reshape to spatial
video_denoised = mx.reshape(mx.transpose(video_x0_guided, (0, 2, 1)), (b, c, f, h, w))
audio_denoised = mx.reshape(audio_x0_guided, (ab, at, ac, af))
audio_denoised = mx.transpose(audio_denoised, (0, 2, 1, 3))
# Post-process with mask
if video_state is not None:
clean_f32 = video_state.clean_latent.astype(mx.float32)
mask_f32 = video_state.denoise_mask.astype(mx.float32)
video_denoised = video_denoised * mask_f32 + clean_f32 * (1.0 - mask_f32)
mx.eval(video_denoised, audio_denoised)
return video_denoised, audio_denoised
# Main res_2s loop
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
console=console,
disable=not verbose,
) as progress:
passes = ["res2s"]
if use_cfg: passes.append("CFG")
if use_stg: passes.append("STG")
if use_modality: passes.append("Mod")
label = "+".join(passes)
task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=n_full_steps)
for step_idx in range(n_full_steps):
sigma = sigmas_list[step_idx]
sigma_next = sigmas_list[step_idx + 1]
h = hs[step_idx]
# Initialize anchor
x_anchor_video = video_latents
x_anchor_audio = audio_latents
# ============================================================
# Stage 1: Evaluate denoiser at current sigma
# ============================================================
denoised_video_1, denoised_audio_1 = _eval_guided_denoise(
video_latents, audio_latents, sigma
)
# RK coefficients
a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2)
# Substep sigma (geometric midpoint for c2=0.5)
sub_sigma = math.sqrt(sigma * sigma_next)
# Compute midpoint
eps_1_video = denoised_video_1 - x_anchor_video
eps_1_audio = denoised_audio_1 - x_anchor_audio
x_mid_video = x_anchor_video + h * a21 * eps_1_video
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
# SDE noise injection at substep
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
substep_noise_v = get_new_noise(video_latents.shape, key1)
substep_noise_a = get_new_noise(audio_latents.shape, key2)
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
mx.eval(x_mid_video, x_mid_audio)
# ============================================================
# Bong iteration: refine anchor (pure arithmetic, no model calls)
# ============================================================
if bongmath and h < 0.5 and sigma > 0.03:
for _ in range(bongmath_max_iter):
x_anchor_video = x_mid_video - h * a21 * eps_1_video
eps_1_video = denoised_video_1 - x_anchor_video
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
eps_1_audio = denoised_audio_1 - x_anchor_audio
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
# ============================================================
# Stage 2: Evaluate denoiser at midpoint sigma
# ============================================================
denoised_video_2, denoised_audio_2 = _eval_guided_denoise(
x_mid_video.astype(mx.float32), x_mid_audio.astype(mx.float32), sub_sigma
)
# ============================================================
# Final combination with RK coefficients
# ============================================================
eps_2_video = denoised_video_2 - x_anchor_video
eps_2_audio = denoised_audio_2 - x_anchor_audio
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
# SDE noise injection at step level
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
step_noise_v = get_new_noise(video_latents.shape, key1)
step_noise_a = get_new_noise(audio_latents.shape, key2)
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
video_latents = x_next_video.astype(mx.float32)
audio_latents = x_next_audio.astype(mx.float32)
mx.eval(video_latents, audio_latents)
progress.advance(task)
# Final clean step if original schedule ended at 0
if sigmas.tolist()[-1] == 0:
denoised_video, denoised_audio = _eval_guided_denoise(
video_latents, audio_latents, sigmas_list[n_full_steps]
)
video_latents = denoised_video
audio_latents = denoised_audio
mx.eval(video_latents, audio_latents)
return video_latents, audio_latents
# =============================================================================
# Audio Loading and Processing
# =============================================================================
@@ -1117,13 +1441,16 @@ def generate_video(
modality_scale: float = 1.0,
lora_path: Optional[str] = None,
lora_strength: float = 1.0,
lora_strength_stage_1: Optional[float] = None,
lora_strength_stage_2: Optional[float] = None,
):
"""Generate video using LTX-2 models.
Supports three pipelines:
Supports four pipelines:
- DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG
- DEV: Single-stage generation with dynamic sigmas and CFG
- DEV_TWO_STAGE: Stage 1 dev (half res, CFG) + upsample + stage 2 distilled with LoRA (full res, no CFG)
- DEV_TWO_STAGE_HQ: res_2s sampler, LoRA both stages (0.25/0.5), lower rescale
Args:
model_repo: Model repository ID
@@ -1158,7 +1485,7 @@ def generate_video(
start_time = time.time()
# Validate dimensions
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE)
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ)
divisor = 64 if is_two_stage else 32
assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}"
assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}"
@@ -1177,13 +1504,14 @@ def generate_video(
PipelineType.DISTILLED: "DISTILLED",
PipelineType.DEV: "DEV",
PipelineType.DEV_TWO_STAGE: "DEV-TWO-STAGE",
PipelineType.DEV_TWO_STAGE_HQ: "DEV-TWO-STAGE-HQ",
}
pipeline_name = pipeline_names[pipeline]
header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height}{num_frames} frames[/]"
console.print(Panel(header, expand=False))
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else ""
stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else ""
mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else ""
@@ -1237,14 +1565,14 @@ def generate_video(
# Encode prompts - always get audio embeddings since the model was trained
# with joint audio-video processing (PyTorch unconditionally generates audio)
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
if pipeline == PipelineType.DEV_TWO_STAGE:
if pipeline in (PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
text_embeddings = video_embeddings_pos
else:
# Distilled pipeline - single embedding
@@ -1655,6 +1983,190 @@ def generate_video(
audio_embeddings=audio_embeddings_pos,
)
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
# ======================================================================
# DEV TWO-STAGE HQ PIPELINE:
# Stage 1: res_2s denoising at half resolution with CFG + LoRA@0.25
# Upsample: 2x spatial via LatentUpsampler
# Stage 2: res_2s refinement at full resolution with LoRA@0.5, no CFG
# ======================================================================
# HQ defaults
hq_lora_strength_s1 = lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25
hq_lora_strength_s2 = lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5
hq_cfg_rescale = cfg_rescale if cfg_rescale != 0.7 else 0.45 # Override default 0.7 → 0.45
hq_steps = num_inference_steps if num_inference_steps != 30 else 15 # Override default 30 → 15
# Load VAE encoder for I2V
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
del vae_encoder
mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Auto-detect and merge LoRA for stage 1 (strength 0.25)
if lora_path is None:
lora_files = sorted(model_path.glob("*distilled-lora*.safetensors"))
if lora_files:
lora_path = str(lora_files[0])
console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]")
else:
console.print("[yellow]Warning: No LoRA file found. HQ pipeline works best with distilled LoRA.[/]")
if lora_path is not None:
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
# Stage 1: res_2s denoising at half resolution with CFG
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
num_tokens = latent_frames * stage1_h * stage1_w
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
mx.eval(sigmas)
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
mx.random.seed(seed)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning for stage 1
state1 = None
stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w)
if is_i2v and stage1_image_latent is not None:
state1 = LatentState(
latent=mx.zeros(stage1_shape, dtype=model_dtype),
clean_latent=mx.zeros(stage1_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength)
state1 = apply_conditioning(state1, [conditioning])
noise = mx.random.normal(stage1_shape, dtype=model_dtype)
noise_scale = sigmas[0]
scaled_mask = state1.denoise_mask * noise_scale
state1 = LatentState(
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state1.clean_latent,
denoise_mask=state1.denoise_mask,
)
latents = state1.latent
mx.eval(latents)
else:
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
mx.eval(latents)
# Stage 1: res_2s with CFG (STG disabled for HQ by default)
latents, audio_latents = denoise_res2s_av(
latents, audio_latents,
positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
audio_cfg_scale=audio_cfg_scale,
cfg_rescale=hq_cfg_rescale, audio_cfg_rescale=1.0,
verbose=verbose, video_state=state1,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
noise_seed=seed,
)
mx.eval(audio_latents)
# Upsample latents 2x
with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"):
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if not upscaler_files:
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
upsampler = load_upsampler(str(upscaler_files[0]))
mx.eval(upsampler.parameters())
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
mx.eval(latents)
del upsampler
mx.clear_cache()
console.print("[green]✓[/] Latents upsampled")
# Merge additional LoRA for stage 2 (additive: 0.25 + 0.25 = 0.5 total)
if lora_path is not None:
additional_strength = hq_lora_strength_s2 - hq_lora_strength_s1
if additional_strength > 0:
with console.status(f"[blue]Adjusting LoRA (stage 2, total={hq_lora_strength_s2})...[/]", spinner="dots"):
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
# Stage 2: res_2s refinement at full resolution (no CFG)
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
state2 = None
if is_i2v and stage2_image_latent is not None:
state2 = LatentState(
latent=latents,
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength)
state2 = apply_conditioning(state2, [conditioning])
noise = mx.random.normal(latents.shape).astype(model_dtype)
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = state2.denoise_mask * noise_scale
state2 = LatentState(
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state2.clean_latent,
denoise_mask=state2.denoise_mask,
)
latents = state2.latent
mx.eval(latents)
else:
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
noise = mx.random.normal(latents.shape).astype(model_dtype)
latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement
if audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
mx.eval(audio_latents)
# Stage 2: res_2s with no CFG (positive embeddings only)
stage2_sigmas = mx.array(STAGE_2_SIGMAS, dtype=mx.float32)
latents, audio_latents = denoise_res2s_av(
latents, audio_latents,
positions, audio_positions,
video_embeddings_pos, video_embeddings_pos, # both pos (no neg for stage 2)
audio_embeddings_pos, audio_embeddings_pos,
transformer, stage2_sigmas, cfg_scale=1.0, # no CFG
audio_cfg_scale=1.0,
cfg_rescale=0.0, verbose=verbose, video_state=state2,
noise_seed=seed + 1,
)
del transformer
mx.clear_cache()
@@ -1857,8 +2369,8 @@ Examples:
)
parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate")
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage"],
help="Pipeline type: distilled (two-stage, fast), dev (single-stage, CFG), or dev-two-stage (dev + LoRA refinement)")
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage", "dev-two-stage-hq"],
help="Pipeline type: distilled (fast), dev (CFG), dev-two-stage (dev + LoRA), dev-two-stage-hq (res_2s + LoRA both stages)")
parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
help="Negative prompt for CFG (dev pipeline only)")
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
@@ -1895,12 +2407,15 @@ Examples:
parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)")
parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)")
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
args = parser.parse_args()
pipeline_map = {
"distilled": PipelineType.DISTILLED,
"dev": PipelineType.DEV,
"dev-two-stage": PipelineType.DEV_TWO_STAGE,
"dev-two-stage-hq": PipelineType.DEV_TWO_STAGE_HQ,
}
pipeline = pipeline_map[args.pipeline]
@@ -1940,6 +2455,8 @@ Examples:
modality_scale=args.modality_scale,
lora_path=args.lora_path,
lora_strength=args.lora_strength,
lora_strength_stage_1=args.lora_strength_stage_1,
lora_strength_stage_2=args.lora_strength_stage_2,
)

181
mlx_video/samplers.py Normal file
View File

@@ -0,0 +1,181 @@
"""Second-order res_2s sampler for diffusion models.
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
noise injection, ported from the LTX-2 PyTorch implementation.
"""
import math
from typing import Optional
import mlx.core as mx
# ---------------------------------------------------------------------------
# Phi functions and RK coefficients (pure Python math, no MLX needed)
# ---------------------------------------------------------------------------
def phi(j: int, neg_h: float) -> float:
"""Compute phi_j(z) where z = -h (negative step size in log-space).
phi_1(z) = (e^z - 1) / z
phi_2(z) = (e^z - 1 - z) / z^2
phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j
"""
if abs(neg_h) < 1e-10:
return 1.0 / math.factorial(j)
remainder = sum(neg_h**k / math.factorial(k) for k in range(j))
return (math.exp(neg_h) - remainder) / (neg_h**j)
def get_res2s_coefficients(
h: float,
phi_cache: dict,
c2: float = 0.5,
) -> tuple[float, float, float]:
"""Compute res_2s Runge-Kutta coefficients for a given step size.
Args:
h: Step size in log-space = log(sigma / sigma_next)
phi_cache: Dictionary to cache phi function results.
c2: Substep position (default 0.5 = midpoint)
Returns:
(a21, b1, b2): RK coefficients.
"""
def get_phi(j: int, neg_h: float) -> float:
cache_key = (j, neg_h)
if cache_key in phi_cache:
return phi_cache[cache_key]
result = phi(j, neg_h)
phi_cache[cache_key] = result
return result
neg_h_c2 = -h * c2
phi_1_c2 = get_phi(1, neg_h_c2)
a21 = c2 * phi_1_c2
neg_h_full = -h
phi_2_full = get_phi(2, neg_h_full)
b2 = phi_2_full / c2
phi_1_full = get_phi(1, neg_h_full)
b1 = phi_1_full - b2
return a21, b1, b2
# ---------------------------------------------------------------------------
# SDE noise injection
# ---------------------------------------------------------------------------
def get_sde_coeff(
sigma_next: float,
) -> tuple[float, float, float]:
"""Compute SDE coefficients for variance-preserving noise injection.
Uses sigma_up = sigma_next * 0.5 (hardcoded in PyTorch Res2sDiffusionStep).
Returns:
(alpha_ratio, sigma_down, sigma_up)
"""
sigma_up = sigma_next * 0.5
# Clamp sigma_up to avoid sqrt(negative)
sigma_up = min(sigma_up, sigma_next * 0.9999)
sigma_signal = 1.0 - sigma_next # sigma_max=1
sigma_residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0))
alpha_ratio = sigma_signal + sigma_residual
if alpha_ratio == 0:
sigma_down = sigma_next
else:
sigma_down = sigma_residual / alpha_ratio
# Handle NaN edge cases
if math.isnan(sigma_up):
sigma_up = 0.0
if math.isnan(sigma_down):
sigma_down = sigma_next
if math.isnan(alpha_ratio):
alpha_ratio = 1.0
return alpha_ratio, sigma_down, sigma_up
def sde_noise_step(
sample: mx.array,
denoised_sample: mx.array,
sigma: float,
sigma_next: float,
noise: mx.array,
) -> mx.array:
"""Apply SDE noise injection step.
Advances sample from sigma to sigma_next with stochastic noise injection.
Args:
sample: Current sample (anchor point)
denoised_sample: Denoised prediction at this step
sigma: Current noise level
sigma_next: Next noise level
noise: Pre-generated noise tensor (channel-wise normalized)
Returns:
Noised sample at sigma_next
"""
alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)
if sigma_up == 0 or sigma_next == 0:
return denoised_sample
# Float32 arithmetic
sample_f32 = sample.astype(mx.float32)
denoised_f32 = denoised_sample.astype(mx.float32)
noise_f32 = noise.astype(mx.float32)
# Extract epsilon prediction
eps_next = (sample_f32 - denoised_f32) / (sigma - sigma_next)
denoised_next = sample_f32 - sigma * eps_next
# Mix deterministic and stochastic components
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
return x_noised
# ---------------------------------------------------------------------------
# Noise generation
# ---------------------------------------------------------------------------
def channelwise_normalize(x: mx.array) -> mx.array:
"""Normalize each channel to zero mean and unit variance over spatial dims.
Operates on the last 2 dimensions (spatial H, W or time, freq).
"""
mean = mx.mean(x, axis=(-2, -1), keepdims=True)
x = x - mean
std = mx.sqrt(mx.mean(x * x, axis=(-2, -1), keepdims=True) + 1e-8)
x = x / std
return x
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
"""Generate channel-wise normalized Gaussian noise.
PyTorch uses float64; we use float32 (MLX doesn't support float64).
The channel-wise normalization is the key quality-affecting step.
Args:
shape: Shape of the noise tensor
key: MLX random key for deterministic generation
Returns:
Channel-wise normalized noise in float32
"""
noise = mx.random.normal(shape, dtype=mx.float32, key=key)
# Global normalization
noise = (noise - mx.mean(noise)) / (mx.sqrt(mx.mean(noise * noise)) + 1e-8)
# Channel-wise normalization
noise = channelwise_normalize(noise)
return noise