Add Dev Two-Stage HQ pipeline mode

This commit is contained in:
Prince Canuma
2026-03-16 00:34:13 +01:00
parent df81bc852f
commit f53b9e0807
3 changed files with 739 additions and 12 deletions

View File

@@ -38,6 +38,7 @@ class PipelineType(Enum):
DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG
DEV = "dev" # Single-stage, dynamic sigmas, CFG
DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res)
DEV_TWO_STAGE_HQ = "dev-two-stage-hq" # Two-stage: res_2s sampler, LoRA both stages
# Distilled model sigma schedules
@@ -1012,6 +1013,329 @@ def denoise_dev_av(
return video_latents, audio_latents
def denoise_res2s_av(
video_latents: mx.array,
audio_latents: mx.array,
video_positions: mx.array,
audio_positions: mx.array,
video_embeddings_pos: mx.array,
video_embeddings_neg: mx.array,
audio_embeddings_pos: mx.array,
audio_embeddings_neg: mx.array,
transformer: LTXModel,
sigmas: mx.array,
cfg_scale: float = 3.0,
audio_cfg_scale: float = 7.0,
cfg_rescale: float = 0.45,
audio_cfg_rescale: Optional[float] = None,
verbose: bool = True,
video_state: Optional[LatentState] = None,
stg_scale: float = 0.0,
stg_video_blocks: Optional[list] = None,
stg_audio_blocks: Optional[list] = None,
modality_scale: float = 1.0,
noise_seed: int = 42,
bongmath: bool = True,
bongmath_max_iter: int = 100,
) -> tuple[mx.array, mx.array]:
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
Two model evaluations per step (current point + midpoint), with SDE noise
injection and optional bong iteration for anchor refinement.
Args:
audio_cfg_rescale: Separate rescale for audio. If None, uses cfg_rescale.
noise_seed: Seed for SDE noise generators.
bongmath: Enable iterative anchor refinement.
bongmath_max_iter: Max bong iterations per step.
"""
from mlx_video.models.ltx.rope import precompute_freqs_cis
from mlx_video.samplers import get_res2s_coefficients, sde_noise_step, get_new_noise
if audio_cfg_rescale is None:
audio_cfg_rescale = cfg_rescale
dtype = video_latents.dtype
if video_state is not None:
video_latents = video_state.latent
video_latents = video_latents.astype(mx.float32)
audio_latents = audio_latents.astype(mx.float32)
sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0
use_stg = stg_scale != 0.0 and stg_video_blocks is not None
use_modality = modality_scale != 1.0
n_full_steps = len(sigmas_list) - 1
# Pad sigmas if last is 0 (avoid division by zero in RK steps)
if sigmas_list[-1] == 0:
sigmas_list = sigmas_list[:-1] + [0.0011, 0.0]
# Compute step sizes in log-space for the main loop steps only.
# After padding, sigmas_list may have an extra [0.0011, 0.0] tail;
# we only need hs for the n_full_steps pairs the loop actually uses.
hs = [-math.log(sigmas_list[i + 1] / sigmas_list[i]) for i in range(n_full_steps)]
# Precompute RoPE
precomputed_video_rope = precompute_freqs_cis(
video_positions,
dim=transformer.inner_dim,
theta=transformer.positional_embedding_theta,
max_pos=transformer.positional_embedding_max_pos,
use_middle_indices_grid=transformer.use_middle_indices_grid,
num_attention_heads=transformer.num_attention_heads,
rope_type=transformer.rope_type,
double_precision=transformer.config.double_precision_rope,
)
precomputed_audio_rope = precompute_freqs_cis(
audio_positions,
dim=transformer.audio_inner_dim,
theta=transformer.positional_embedding_theta,
max_pos=transformer.audio_positional_embedding_max_pos,
use_middle_indices_grid=transformer.use_middle_indices_grid,
num_attention_heads=transformer.audio_num_attention_heads,
rope_type=transformer.rope_type,
double_precision=transformer.config.double_precision_rope,
)
mx.eval(precomputed_video_rope, precomputed_audio_rope)
phi_cache = {}
c2 = 0.5
# Noise key management: step noise and substep noise use different keys
step_noise_key = mx.random.key(noise_seed)
substep_noise_key = mx.random.key(noise_seed + 10000)
def _eval_guided_denoise(v_latents, a_latents, sigma):
"""Run all guidance passes and return (video_denoised, audio_denoised) in float32 spatial format."""
b, c, f, h, w = v_latents.shape
num_video_tokens = f * h * w
video_flat = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
ab, ac, at, af = a_latents.shape
audio_flat = mx.transpose(a_latents, (0, 2, 1, 3))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
# Timesteps
if video_state is not None:
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
sigma_array = mx.full((b,), sigma, dtype=dtype)
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
# Pass 1: Positive conditioning
video_modality_pos = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_pos, context_mask=None, enabled=True,
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
)
audio_modality_pos = Modality(
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
context=audio_embeddings_pos, context_mask=None, enabled=True,
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
)
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
mx.eval(video_vel_pos, audio_vel_pos)
# Convert velocity to x0
video_flat_f32 = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1))
audio_flat_f32 = mx.reshape(mx.transpose(a_latents, (0, 2, 1, 3)), (ab, at, ac * af))
video_ts_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
audio_ts_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1)
video_x0_pos = video_flat_f32 - video_ts_f32 * video_vel_pos.astype(mx.float32)
audio_x0_pos = audio_flat_f32 - audio_ts_f32 * audio_vel_pos.astype(mx.float32)
video_x0_guided = video_x0_pos
audio_x0_guided = audio_x0_pos
# Pass 2: CFG
if use_cfg:
video_modality_neg = Modality(
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
context=video_embeddings_neg, context_mask=None, enabled=True,
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
)
audio_modality_neg = Modality(
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
context=audio_embeddings_neg, context_mask=None, enabled=True,
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
)
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
mx.eval(video_vel_neg, audio_vel_neg)
video_x0_neg = video_flat_f32 - video_ts_f32 * video_vel_neg.astype(mx.float32)
audio_x0_neg = audio_flat_f32 - audio_ts_f32 * audio_vel_neg.astype(mx.float32)
video_x0_guided = video_x0_pos + (cfg_scale - 1.0) * (video_x0_pos - video_x0_neg)
audio_x0_guided = audio_x0_pos + (audio_cfg_scale - 1.0) * (audio_x0_pos - audio_x0_neg)
# Pass 3: STG
if use_stg:
video_vel_ptb, audio_vel_ptb = transformer(
video=video_modality_pos, audio=audio_modality_pos,
stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks,
)
mx.eval(video_vel_ptb, audio_vel_ptb)
video_x0_ptb = video_flat_f32 - video_ts_f32 * video_vel_ptb.astype(mx.float32)
audio_x0_ptb = audio_flat_f32 - audio_ts_f32 * audio_vel_ptb.astype(mx.float32)
video_x0_guided = video_x0_guided + stg_scale * (video_x0_pos - video_x0_ptb)
audio_x0_guided = audio_x0_guided + stg_scale * (audio_x0_pos - audio_x0_ptb)
# Pass 4: Modality isolation
if use_modality:
video_vel_iso, audio_vel_iso = transformer(
video=video_modality_pos, audio=audio_modality_pos,
skip_cross_modal=True,
)
mx.eval(video_vel_iso, audio_vel_iso)
video_x0_iso = video_flat_f32 - video_ts_f32 * video_vel_iso.astype(mx.float32)
audio_x0_iso = audio_flat_f32 - audio_ts_f32 * audio_vel_iso.astype(mx.float32)
video_x0_guided = video_x0_guided + (modality_scale - 1.0) * (video_x0_pos - video_x0_iso)
audio_x0_guided = audio_x0_guided + (modality_scale - 1.0) * (audio_x0_pos - audio_x0_iso)
# Rescale (separate factors for video and audio)
if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
v_factor = video_x0_pos.std() / (video_x0_guided.std() + 1e-8)
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
video_x0_guided = video_x0_guided * v_factor
if audio_cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
a_factor = audio_x0_pos.std() / (audio_x0_guided.std() + 1e-8)
a_factor = audio_cfg_rescale * a_factor + (1.0 - audio_cfg_rescale)
audio_x0_guided = audio_x0_guided * a_factor
# Reshape to spatial
video_denoised = mx.reshape(mx.transpose(video_x0_guided, (0, 2, 1)), (b, c, f, h, w))
audio_denoised = mx.reshape(audio_x0_guided, (ab, at, ac, af))
audio_denoised = mx.transpose(audio_denoised, (0, 2, 1, 3))
# Post-process with mask
if video_state is not None:
clean_f32 = video_state.clean_latent.astype(mx.float32)
mask_f32 = video_state.denoise_mask.astype(mx.float32)
video_denoised = video_denoised * mask_f32 + clean_f32 * (1.0 - mask_f32)
mx.eval(video_denoised, audio_denoised)
return video_denoised, audio_denoised
# Main res_2s loop
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeRemainingColumn(),
console=console,
disable=not verbose,
) as progress:
passes = ["res2s"]
if use_cfg: passes.append("CFG")
if use_stg: passes.append("STG")
if use_modality: passes.append("Mod")
label = "+".join(passes)
task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=n_full_steps)
for step_idx in range(n_full_steps):
sigma = sigmas_list[step_idx]
sigma_next = sigmas_list[step_idx + 1]
h = hs[step_idx]
# Initialize anchor
x_anchor_video = video_latents
x_anchor_audio = audio_latents
# ============================================================
# Stage 1: Evaluate denoiser at current sigma
# ============================================================
denoised_video_1, denoised_audio_1 = _eval_guided_denoise(
video_latents, audio_latents, sigma
)
# RK coefficients
a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2)
# Substep sigma (geometric midpoint for c2=0.5)
sub_sigma = math.sqrt(sigma * sigma_next)
# Compute midpoint
eps_1_video = denoised_video_1 - x_anchor_video
eps_1_audio = denoised_audio_1 - x_anchor_audio
x_mid_video = x_anchor_video + h * a21 * eps_1_video
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
# SDE noise injection at substep
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
substep_noise_v = get_new_noise(video_latents.shape, key1)
substep_noise_a = get_new_noise(audio_latents.shape, key2)
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
mx.eval(x_mid_video, x_mid_audio)
# ============================================================
# Bong iteration: refine anchor (pure arithmetic, no model calls)
# ============================================================
if bongmath and h < 0.5 and sigma > 0.03:
for _ in range(bongmath_max_iter):
x_anchor_video = x_mid_video - h * a21 * eps_1_video
eps_1_video = denoised_video_1 - x_anchor_video
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
eps_1_audio = denoised_audio_1 - x_anchor_audio
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
# ============================================================
# Stage 2: Evaluate denoiser at midpoint sigma
# ============================================================
denoised_video_2, denoised_audio_2 = _eval_guided_denoise(
x_mid_video.astype(mx.float32), x_mid_audio.astype(mx.float32), sub_sigma
)
# ============================================================
# Final combination with RK coefficients
# ============================================================
eps_2_video = denoised_video_2 - x_anchor_video
eps_2_audio = denoised_audio_2 - x_anchor_audio
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
# SDE noise injection at step level
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
step_noise_v = get_new_noise(video_latents.shape, key1)
step_noise_a = get_new_noise(audio_latents.shape, key2)
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
video_latents = x_next_video.astype(mx.float32)
audio_latents = x_next_audio.astype(mx.float32)
mx.eval(video_latents, audio_latents)
progress.advance(task)
# Final clean step if original schedule ended at 0
if sigmas.tolist()[-1] == 0:
denoised_video, denoised_audio = _eval_guided_denoise(
video_latents, audio_latents, sigmas_list[n_full_steps]
)
video_latents = denoised_video
audio_latents = denoised_audio
mx.eval(video_latents, audio_latents)
return video_latents, audio_latents
# =============================================================================
# Audio Loading and Processing
# =============================================================================
@@ -1117,13 +1441,16 @@ def generate_video(
modality_scale: float = 1.0,
lora_path: Optional[str] = None,
lora_strength: float = 1.0,
lora_strength_stage_1: Optional[float] = None,
lora_strength_stage_2: Optional[float] = None,
):
"""Generate video using LTX-2 models.
Supports three pipelines:
Supports four pipelines:
- DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG
- DEV: Single-stage generation with dynamic sigmas and CFG
- DEV_TWO_STAGE: Stage 1 dev (half res, CFG) + upsample + stage 2 distilled with LoRA (full res, no CFG)
- DEV_TWO_STAGE_HQ: res_2s sampler, LoRA both stages (0.25/0.5), lower rescale
Args:
model_repo: Model repository ID
@@ -1158,7 +1485,7 @@ def generate_video(
start_time = time.time()
# Validate dimensions
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE)
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ)
divisor = 64 if is_two_stage else 32
assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}"
assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}"
@@ -1177,13 +1504,14 @@ def generate_video(
PipelineType.DISTILLED: "DISTILLED",
PipelineType.DEV: "DEV",
PipelineType.DEV_TWO_STAGE: "DEV-TWO-STAGE",
PipelineType.DEV_TWO_STAGE_HQ: "DEV-TWO-STAGE-HQ",
}
pipeline_name = pipeline_names[pipeline]
header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height}{num_frames} frames[/]"
console.print(Panel(header, expand=False))
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else ""
stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else ""
mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else ""
@@ -1237,14 +1565,14 @@ def generate_video(
# Encode prompts - always get audio embeddings since the model was trained
# with joint audio-video processing (PyTorch unconditionally generates audio)
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
model_dtype = video_embeddings_pos.dtype
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
if pipeline == PipelineType.DEV_TWO_STAGE:
if pipeline in (PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
text_embeddings = video_embeddings_pos
else:
# Distilled pipeline - single embedding
@@ -1655,6 +1983,190 @@ def generate_video(
audio_embeddings=audio_embeddings_pos,
)
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
# ======================================================================
# DEV TWO-STAGE HQ PIPELINE:
# Stage 1: res_2s denoising at half resolution with CFG + LoRA@0.25
# Upsample: 2x spatial via LatentUpsampler
# Stage 2: res_2s refinement at full resolution with LoRA@0.5, no CFG
# ======================================================================
# HQ defaults
hq_lora_strength_s1 = lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25
hq_lora_strength_s2 = lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5
hq_cfg_rescale = cfg_rescale if cfg_rescale != 0.7 else 0.45 # Override default 0.7 → 0.45
hq_steps = num_inference_steps if num_inference_steps != 30 else 15 # Override default 30 → 15
# Load VAE encoder for I2V
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
del vae_encoder
mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Auto-detect and merge LoRA for stage 1 (strength 0.25)
if lora_path is None:
lora_files = sorted(model_path.glob("*distilled-lora*.safetensors"))
if lora_files:
lora_path = str(lora_files[0])
console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]")
else:
console.print("[yellow]Warning: No LoRA file found. HQ pipeline works best with distilled LoRA.[/]")
if lora_path is not None:
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
# Stage 1: res_2s denoising at half resolution with CFG
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
num_tokens = latent_frames * stage1_h * stage1_w
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
mx.eval(sigmas)
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
mx.random.seed(seed)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
audio_positions = create_audio_position_grid(1, audio_frames)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
mx.eval(audio_positions, audio_latents)
# Apply I2V conditioning for stage 1
state1 = None
stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w)
if is_i2v and stage1_image_latent is not None:
state1 = LatentState(
latent=mx.zeros(stage1_shape, dtype=model_dtype),
clean_latent=mx.zeros(stage1_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength)
state1 = apply_conditioning(state1, [conditioning])
noise = mx.random.normal(stage1_shape, dtype=model_dtype)
noise_scale = sigmas[0]
scaled_mask = state1.denoise_mask * noise_scale
state1 = LatentState(
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state1.clean_latent,
denoise_mask=state1.denoise_mask,
)
latents = state1.latent
mx.eval(latents)
else:
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
mx.eval(latents)
# Stage 1: res_2s with CFG (STG disabled for HQ by default)
latents, audio_latents = denoise_res2s_av(
latents, audio_latents,
positions, audio_positions,
video_embeddings_pos, video_embeddings_neg,
audio_embeddings_pos, audio_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
audio_cfg_scale=audio_cfg_scale,
cfg_rescale=hq_cfg_rescale, audio_cfg_rescale=1.0,
verbose=verbose, video_state=state1,
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
noise_seed=seed,
)
mx.eval(audio_latents)
# Upsample latents 2x
with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"):
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if not upscaler_files:
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
upsampler = load_upsampler(str(upscaler_files[0]))
mx.eval(upsampler.parameters())
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
mx.eval(latents)
del upsampler
mx.clear_cache()
console.print("[green]✓[/] Latents upsampled")
# Merge additional LoRA for stage 2 (additive: 0.25 + 0.25 = 0.5 total)
if lora_path is not None:
additional_strength = hq_lora_strength_s2 - hq_lora_strength_s1
if additional_strength > 0:
with console.status(f"[blue]Adjusting LoRA (stage 2, total={hq_lora_strength_s2})...[/]", spinner="dots"):
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
# Stage 2: res_2s refinement at full resolution (no CFG)
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
state2 = None
if is_i2v and stage2_image_latent is not None:
state2 = LatentState(
latent=latents,
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength)
state2 = apply_conditioning(state2, [conditioning])
noise = mx.random.normal(latents.shape).astype(model_dtype)
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = state2.denoise_mask * noise_scale
state2 = LatentState(
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state2.clean_latent,
denoise_mask=state2.denoise_mask,
)
latents = state2.latent
mx.eval(latents)
else:
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
noise = mx.random.normal(latents.shape).astype(model_dtype)
latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents)
# Re-noise audio at sigma=0.909375 for joint refinement
if audio_latents is not None:
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
mx.eval(audio_latents)
# Stage 2: res_2s with no CFG (positive embeddings only)
stage2_sigmas = mx.array(STAGE_2_SIGMAS, dtype=mx.float32)
latents, audio_latents = denoise_res2s_av(
latents, audio_latents,
positions, audio_positions,
video_embeddings_pos, video_embeddings_pos, # both pos (no neg for stage 2)
audio_embeddings_pos, audio_embeddings_pos,
transformer, stage2_sigmas, cfg_scale=1.0, # no CFG
audio_cfg_scale=1.0,
cfg_rescale=0.0, verbose=verbose, video_state=state2,
noise_seed=seed + 1,
)
del transformer
mx.clear_cache()
@@ -1857,8 +2369,8 @@ Examples:
)
parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate")
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage"],
help="Pipeline type: distilled (two-stage, fast), dev (single-stage, CFG), or dev-two-stage (dev + LoRA refinement)")
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage", "dev-two-stage-hq"],
help="Pipeline type: distilled (fast), dev (CFG), dev-two-stage (dev + LoRA), dev-two-stage-hq (res_2s + LoRA both stages)")
parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
help="Negative prompt for CFG (dev pipeline only)")
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
@@ -1895,12 +2407,15 @@ Examples:
parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)")
parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)")
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
args = parser.parse_args()
pipeline_map = {
"distilled": PipelineType.DISTILLED,
"dev": PipelineType.DEV,
"dev-two-stage": PipelineType.DEV_TWO_STAGE,
"dev-two-stage-hq": PipelineType.DEV_TWO_STAGE_HQ,
}
pipeline = pipeline_map[args.pipeline]
@@ -1940,6 +2455,8 @@ Examples:
modality_scale=args.modality_scale,
lora_path=args.lora_path,
lora_strength=args.lora_strength,
lora_strength_stage_1=args.lora_strength_stage_1,
lora_strength_stage_2=args.lora_strength_stage_2,
)