Add Dev Two-Stage HQ pipeline mode
This commit is contained in:
@@ -38,6 +38,7 @@ class PipelineType(Enum):
|
||||
DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG
|
||||
DEV = "dev" # Single-stage, dynamic sigmas, CFG
|
||||
DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res)
|
||||
DEV_TWO_STAGE_HQ = "dev-two-stage-hq" # Two-stage: res_2s sampler, LoRA both stages
|
||||
|
||||
|
||||
# Distilled model sigma schedules
|
||||
@@ -1012,6 +1013,329 @@ def denoise_dev_av(
|
||||
return video_latents, audio_latents
|
||||
|
||||
|
||||
def denoise_res2s_av(
|
||||
video_latents: mx.array,
|
||||
audio_latents: mx.array,
|
||||
video_positions: mx.array,
|
||||
audio_positions: mx.array,
|
||||
video_embeddings_pos: mx.array,
|
||||
video_embeddings_neg: mx.array,
|
||||
audio_embeddings_pos: mx.array,
|
||||
audio_embeddings_neg: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigmas: mx.array,
|
||||
cfg_scale: float = 3.0,
|
||||
audio_cfg_scale: float = 7.0,
|
||||
cfg_rescale: float = 0.45,
|
||||
audio_cfg_rescale: Optional[float] = None,
|
||||
verbose: bool = True,
|
||||
video_state: Optional[LatentState] = None,
|
||||
stg_scale: float = 0.0,
|
||||
stg_video_blocks: Optional[list] = None,
|
||||
stg_audio_blocks: Optional[list] = None,
|
||||
modality_scale: float = 1.0,
|
||||
noise_seed: int = 42,
|
||||
bongmath: bool = True,
|
||||
bongmath_max_iter: int = 100,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
|
||||
|
||||
Two model evaluations per step (current point + midpoint), with SDE noise
|
||||
injection and optional bong iteration for anchor refinement.
|
||||
|
||||
Args:
|
||||
audio_cfg_rescale: Separate rescale for audio. If None, uses cfg_rescale.
|
||||
noise_seed: Seed for SDE noise generators.
|
||||
bongmath: Enable iterative anchor refinement.
|
||||
bongmath_max_iter: Max bong iterations per step.
|
||||
"""
|
||||
from mlx_video.models.ltx.rope import precompute_freqs_cis
|
||||
from mlx_video.samplers import get_res2s_coefficients, sde_noise_step, get_new_noise
|
||||
|
||||
if audio_cfg_rescale is None:
|
||||
audio_cfg_rescale = cfg_rescale
|
||||
|
||||
dtype = video_latents.dtype
|
||||
if video_state is not None:
|
||||
video_latents = video_state.latent
|
||||
|
||||
video_latents = video_latents.astype(mx.float32)
|
||||
audio_latents = audio_latents.astype(mx.float32)
|
||||
|
||||
sigmas_list = sigmas.tolist()
|
||||
use_cfg = cfg_scale != 1.0
|
||||
use_stg = stg_scale != 0.0 and stg_video_blocks is not None
|
||||
use_modality = modality_scale != 1.0
|
||||
n_full_steps = len(sigmas_list) - 1
|
||||
|
||||
# Pad sigmas if last is 0 (avoid division by zero in RK steps)
|
||||
if sigmas_list[-1] == 0:
|
||||
sigmas_list = sigmas_list[:-1] + [0.0011, 0.0]
|
||||
|
||||
# Compute step sizes in log-space for the main loop steps only.
|
||||
# After padding, sigmas_list may have an extra [0.0011, 0.0] tail;
|
||||
# we only need hs for the n_full_steps pairs the loop actually uses.
|
||||
hs = [-math.log(sigmas_list[i + 1] / sigmas_list[i]) for i in range(n_full_steps)]
|
||||
|
||||
# Precompute RoPE
|
||||
precomputed_video_rope = precompute_freqs_cis(
|
||||
video_positions,
|
||||
dim=transformer.inner_dim,
|
||||
theta=transformer.positional_embedding_theta,
|
||||
max_pos=transformer.positional_embedding_max_pos,
|
||||
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
||||
num_attention_heads=transformer.num_attention_heads,
|
||||
rope_type=transformer.rope_type,
|
||||
double_precision=transformer.config.double_precision_rope,
|
||||
)
|
||||
precomputed_audio_rope = precompute_freqs_cis(
|
||||
audio_positions,
|
||||
dim=transformer.audio_inner_dim,
|
||||
theta=transformer.positional_embedding_theta,
|
||||
max_pos=transformer.audio_positional_embedding_max_pos,
|
||||
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
||||
num_attention_heads=transformer.audio_num_attention_heads,
|
||||
rope_type=transformer.rope_type,
|
||||
double_precision=transformer.config.double_precision_rope,
|
||||
)
|
||||
mx.eval(precomputed_video_rope, precomputed_audio_rope)
|
||||
|
||||
phi_cache = {}
|
||||
c2 = 0.5
|
||||
|
||||
# Noise key management: step noise and substep noise use different keys
|
||||
step_noise_key = mx.random.key(noise_seed)
|
||||
substep_noise_key = mx.random.key(noise_seed + 10000)
|
||||
|
||||
def _eval_guided_denoise(v_latents, a_latents, sigma):
|
||||
"""Run all guidance passes and return (video_denoised, audio_denoised) in float32 spatial format."""
|
||||
b, c, f, h, w = v_latents.shape
|
||||
num_video_tokens = f * h * w
|
||||
video_flat = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
||||
|
||||
ab, ac, at, af = a_latents.shape
|
||||
audio_flat = mx.transpose(a_latents, (0, 2, 1, 3))
|
||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
||||
|
||||
# Timesteps
|
||||
if video_state is not None:
|
||||
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
|
||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
||||
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||
else:
|
||||
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||
audio_timesteps = mx.full((ab, at), sigma, dtype=dtype)
|
||||
|
||||
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||
audio_sigma_array = mx.full((ab,), sigma, dtype=dtype)
|
||||
|
||||
# Pass 1: Positive conditioning
|
||||
video_modality_pos = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
context=video_embeddings_pos, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
||||
)
|
||||
audio_modality_pos = Modality(
|
||||
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
||||
context=audio_embeddings_pos, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
||||
)
|
||||
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
||||
mx.eval(video_vel_pos, audio_vel_pos)
|
||||
|
||||
# Convert velocity to x0
|
||||
video_flat_f32 = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1))
|
||||
audio_flat_f32 = mx.reshape(mx.transpose(a_latents, (0, 2, 1, 3)), (ab, at, ac * af))
|
||||
video_ts_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
|
||||
audio_ts_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1)
|
||||
|
||||
video_x0_pos = video_flat_f32 - video_ts_f32 * video_vel_pos.astype(mx.float32)
|
||||
audio_x0_pos = audio_flat_f32 - audio_ts_f32 * audio_vel_pos.astype(mx.float32)
|
||||
|
||||
video_x0_guided = video_x0_pos
|
||||
audio_x0_guided = audio_x0_pos
|
||||
|
||||
# Pass 2: CFG
|
||||
if use_cfg:
|
||||
video_modality_neg = Modality(
|
||||
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
||||
context=video_embeddings_neg, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
||||
)
|
||||
audio_modality_neg = Modality(
|
||||
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
||||
context=audio_embeddings_neg, context_mask=None, enabled=True,
|
||||
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
||||
)
|
||||
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
||||
mx.eval(video_vel_neg, audio_vel_neg)
|
||||
|
||||
video_x0_neg = video_flat_f32 - video_ts_f32 * video_vel_neg.astype(mx.float32)
|
||||
audio_x0_neg = audio_flat_f32 - audio_ts_f32 * audio_vel_neg.astype(mx.float32)
|
||||
|
||||
video_x0_guided = video_x0_pos + (cfg_scale - 1.0) * (video_x0_pos - video_x0_neg)
|
||||
audio_x0_guided = audio_x0_pos + (audio_cfg_scale - 1.0) * (audio_x0_pos - audio_x0_neg)
|
||||
|
||||
# Pass 3: STG
|
||||
if use_stg:
|
||||
video_vel_ptb, audio_vel_ptb = transformer(
|
||||
video=video_modality_pos, audio=audio_modality_pos,
|
||||
stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks,
|
||||
)
|
||||
mx.eval(video_vel_ptb, audio_vel_ptb)
|
||||
|
||||
video_x0_ptb = video_flat_f32 - video_ts_f32 * video_vel_ptb.astype(mx.float32)
|
||||
audio_x0_ptb = audio_flat_f32 - audio_ts_f32 * audio_vel_ptb.astype(mx.float32)
|
||||
|
||||
video_x0_guided = video_x0_guided + stg_scale * (video_x0_pos - video_x0_ptb)
|
||||
audio_x0_guided = audio_x0_guided + stg_scale * (audio_x0_pos - audio_x0_ptb)
|
||||
|
||||
# Pass 4: Modality isolation
|
||||
if use_modality:
|
||||
video_vel_iso, audio_vel_iso = transformer(
|
||||
video=video_modality_pos, audio=audio_modality_pos,
|
||||
skip_cross_modal=True,
|
||||
)
|
||||
mx.eval(video_vel_iso, audio_vel_iso)
|
||||
|
||||
video_x0_iso = video_flat_f32 - video_ts_f32 * video_vel_iso.astype(mx.float32)
|
||||
audio_x0_iso = audio_flat_f32 - audio_ts_f32 * audio_vel_iso.astype(mx.float32)
|
||||
|
||||
video_x0_guided = video_x0_guided + (modality_scale - 1.0) * (video_x0_pos - video_x0_iso)
|
||||
audio_x0_guided = audio_x0_guided + (modality_scale - 1.0) * (audio_x0_pos - audio_x0_iso)
|
||||
|
||||
# Rescale (separate factors for video and audio)
|
||||
if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
|
||||
v_factor = video_x0_pos.std() / (video_x0_guided.std() + 1e-8)
|
||||
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
||||
video_x0_guided = video_x0_guided * v_factor
|
||||
if audio_cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
|
||||
a_factor = audio_x0_pos.std() / (audio_x0_guided.std() + 1e-8)
|
||||
a_factor = audio_cfg_rescale * a_factor + (1.0 - audio_cfg_rescale)
|
||||
audio_x0_guided = audio_x0_guided * a_factor
|
||||
|
||||
# Reshape to spatial
|
||||
video_denoised = mx.reshape(mx.transpose(video_x0_guided, (0, 2, 1)), (b, c, f, h, w))
|
||||
audio_denoised = mx.reshape(audio_x0_guided, (ab, at, ac, af))
|
||||
audio_denoised = mx.transpose(audio_denoised, (0, 2, 1, 3))
|
||||
|
||||
# Post-process with mask
|
||||
if video_state is not None:
|
||||
clean_f32 = video_state.clean_latent.astype(mx.float32)
|
||||
mask_f32 = video_state.denoise_mask.astype(mx.float32)
|
||||
video_denoised = video_denoised * mask_f32 + clean_f32 * (1.0 - mask_f32)
|
||||
|
||||
mx.eval(video_denoised, audio_denoised)
|
||||
return video_denoised, audio_denoised
|
||||
|
||||
# Main res_2s loop
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
TimeRemainingColumn(),
|
||||
console=console,
|
||||
disable=not verbose,
|
||||
) as progress:
|
||||
passes = ["res2s"]
|
||||
if use_cfg: passes.append("CFG")
|
||||
if use_stg: passes.append("STG")
|
||||
if use_modality: passes.append("Mod")
|
||||
label = "+".join(passes)
|
||||
task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=n_full_steps)
|
||||
|
||||
for step_idx in range(n_full_steps):
|
||||
sigma = sigmas_list[step_idx]
|
||||
sigma_next = sigmas_list[step_idx + 1]
|
||||
h = hs[step_idx]
|
||||
|
||||
# Initialize anchor
|
||||
x_anchor_video = video_latents
|
||||
x_anchor_audio = audio_latents
|
||||
|
||||
# ============================================================
|
||||
# Stage 1: Evaluate denoiser at current sigma
|
||||
# ============================================================
|
||||
denoised_video_1, denoised_audio_1 = _eval_guided_denoise(
|
||||
video_latents, audio_latents, sigma
|
||||
)
|
||||
|
||||
# RK coefficients
|
||||
a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2)
|
||||
|
||||
# Substep sigma (geometric midpoint for c2=0.5)
|
||||
sub_sigma = math.sqrt(sigma * sigma_next)
|
||||
|
||||
# Compute midpoint
|
||||
eps_1_video = denoised_video_1 - x_anchor_video
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
|
||||
x_mid_video = x_anchor_video + h * a21 * eps_1_video
|
||||
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
|
||||
|
||||
# SDE noise injection at substep
|
||||
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
|
||||
substep_noise_v = get_new_noise(video_latents.shape, key1)
|
||||
substep_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
|
||||
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
|
||||
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
|
||||
mx.eval(x_mid_video, x_mid_audio)
|
||||
|
||||
# ============================================================
|
||||
# Bong iteration: refine anchor (pure arithmetic, no model calls)
|
||||
# ============================================================
|
||||
if bongmath and h < 0.5 and sigma > 0.03:
|
||||
for _ in range(bongmath_max_iter):
|
||||
x_anchor_video = x_mid_video - h * a21 * eps_1_video
|
||||
eps_1_video = denoised_video_1 - x_anchor_video
|
||||
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
|
||||
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
||||
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
|
||||
|
||||
# ============================================================
|
||||
# Stage 2: Evaluate denoiser at midpoint sigma
|
||||
# ============================================================
|
||||
denoised_video_2, denoised_audio_2 = _eval_guided_denoise(
|
||||
x_mid_video.astype(mx.float32), x_mid_audio.astype(mx.float32), sub_sigma
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Final combination with RK coefficients
|
||||
# ============================================================
|
||||
eps_2_video = denoised_video_2 - x_anchor_video
|
||||
eps_2_audio = denoised_audio_2 - x_anchor_audio
|
||||
|
||||
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
|
||||
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
|
||||
|
||||
# SDE noise injection at step level
|
||||
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
|
||||
step_noise_v = get_new_noise(video_latents.shape, key1)
|
||||
step_noise_a = get_new_noise(audio_latents.shape, key2)
|
||||
|
||||
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
|
||||
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
|
||||
|
||||
video_latents = x_next_video.astype(mx.float32)
|
||||
audio_latents = x_next_audio.astype(mx.float32)
|
||||
mx.eval(video_latents, audio_latents)
|
||||
progress.advance(task)
|
||||
|
||||
# Final clean step if original schedule ended at 0
|
||||
if sigmas.tolist()[-1] == 0:
|
||||
denoised_video, denoised_audio = _eval_guided_denoise(
|
||||
video_latents, audio_latents, sigmas_list[n_full_steps]
|
||||
)
|
||||
video_latents = denoised_video
|
||||
audio_latents = denoised_audio
|
||||
mx.eval(video_latents, audio_latents)
|
||||
|
||||
return video_latents, audio_latents
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Audio Loading and Processing
|
||||
# =============================================================================
|
||||
@@ -1117,13 +1441,16 @@ def generate_video(
|
||||
modality_scale: float = 1.0,
|
||||
lora_path: Optional[str] = None,
|
||||
lora_strength: float = 1.0,
|
||||
lora_strength_stage_1: Optional[float] = None,
|
||||
lora_strength_stage_2: Optional[float] = None,
|
||||
):
|
||||
"""Generate video using LTX-2 models.
|
||||
|
||||
Supports three pipelines:
|
||||
Supports four pipelines:
|
||||
- DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG
|
||||
- DEV: Single-stage generation with dynamic sigmas and CFG
|
||||
- DEV_TWO_STAGE: Stage 1 dev (half res, CFG) + upsample + stage 2 distilled with LoRA (full res, no CFG)
|
||||
- DEV_TWO_STAGE_HQ: res_2s sampler, LoRA both stages (0.25/0.5), lower rescale
|
||||
|
||||
Args:
|
||||
model_repo: Model repository ID
|
||||
@@ -1158,7 +1485,7 @@ def generate_video(
|
||||
start_time = time.time()
|
||||
|
||||
# Validate dimensions
|
||||
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE)
|
||||
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ)
|
||||
divisor = 64 if is_two_stage else 32
|
||||
assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}"
|
||||
assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}"
|
||||
@@ -1177,13 +1504,14 @@ def generate_video(
|
||||
PipelineType.DISTILLED: "DISTILLED",
|
||||
PipelineType.DEV: "DEV",
|
||||
PipelineType.DEV_TWO_STAGE: "DEV-TWO-STAGE",
|
||||
PipelineType.DEV_TWO_STAGE_HQ: "DEV-TWO-STAGE-HQ",
|
||||
}
|
||||
pipeline_name = pipeline_names[pipeline]
|
||||
header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height} • {num_frames} frames[/]"
|
||||
console.print(Panel(header, expand=False))
|
||||
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
|
||||
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
||||
audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else ""
|
||||
stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else ""
|
||||
mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else ""
|
||||
@@ -1237,14 +1565,14 @@ def generate_video(
|
||||
|
||||
# Encode prompts - always get audio embeddings since the model was trained
|
||||
# with joint audio-video processing (PyTorch unconditionally generates audio)
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE):
|
||||
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
||||
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
|
||||
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
|
||||
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
|
||||
model_dtype = video_embeddings_pos.dtype
|
||||
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
|
||||
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
|
||||
if pipeline == PipelineType.DEV_TWO_STAGE:
|
||||
if pipeline in (PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
||||
text_embeddings = video_embeddings_pos
|
||||
else:
|
||||
# Distilled pipeline - single embedding
|
||||
@@ -1655,6 +1983,190 @@ def generate_video(
|
||||
audio_embeddings=audio_embeddings_pos,
|
||||
)
|
||||
|
||||
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
|
||||
# ======================================================================
|
||||
# DEV TWO-STAGE HQ PIPELINE:
|
||||
# Stage 1: res_2s denoising at half resolution with CFG + LoRA@0.25
|
||||
# Upsample: 2x spatial via LatentUpsampler
|
||||
# Stage 2: res_2s refinement at full resolution with LoRA@0.5, no CFG
|
||||
# ======================================================================
|
||||
|
||||
# HQ defaults
|
||||
hq_lora_strength_s1 = lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25
|
||||
hq_lora_strength_s2 = lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5
|
||||
hq_cfg_rescale = cfg_rescale if cfg_rescale != 0.7 else 0.45 # Override default 0.7 → 0.45
|
||||
hq_steps = num_inference_steps if num_inference_steps != 30 else 15 # Override default 30 → 15
|
||||
|
||||
# Load VAE encoder for I2V
|
||||
stage1_image_latent = None
|
||||
stage2_image_latent = None
|
||||
if is_i2v:
|
||||
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
|
||||
del vae_encoder
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||
|
||||
# Auto-detect and merge LoRA for stage 1 (strength 0.25)
|
||||
if lora_path is None:
|
||||
lora_files = sorted(model_path.glob("*distilled-lora*.safetensors"))
|
||||
if lora_files:
|
||||
lora_path = str(lora_files[0])
|
||||
console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]")
|
||||
else:
|
||||
console.print("[yellow]Warning: No LoRA file found. HQ pipeline works best with distilled LoRA.[/]")
|
||||
|
||||
if lora_path is not None:
|
||||
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
|
||||
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
|
||||
|
||||
# Stage 1: res_2s denoising at half resolution with CFG
|
||||
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
|
||||
num_tokens = latent_frames * stage1_h * stage1_w
|
||||
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
|
||||
mx.eval(sigmas)
|
||||
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
|
||||
|
||||
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
|
||||
mx.random.seed(seed)
|
||||
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
||||
mx.eval(audio_positions, audio_latents)
|
||||
|
||||
# Apply I2V conditioning for stage 1
|
||||
state1 = None
|
||||
stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
||||
if is_i2v and stage1_image_latent is not None:
|
||||
state1 = LatentState(
|
||||
latent=mx.zeros(stage1_shape, dtype=model_dtype),
|
||||
clean_latent=mx.zeros(stage1_shape, dtype=model_dtype),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
||||
state1 = apply_conditioning(state1, [conditioning])
|
||||
|
||||
noise = mx.random.normal(stage1_shape, dtype=model_dtype)
|
||||
noise_scale = sigmas[0]
|
||||
scaled_mask = state1.denoise_mask * noise_scale
|
||||
state1 = LatentState(
|
||||
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=state1.clean_latent,
|
||||
denoise_mask=state1.denoise_mask,
|
||||
)
|
||||
latents = state1.latent
|
||||
mx.eval(latents)
|
||||
else:
|
||||
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
|
||||
mx.eval(latents)
|
||||
|
||||
# Stage 1: res_2s with CFG (STG disabled for HQ by default)
|
||||
latents, audio_latents = denoise_res2s_av(
|
||||
latents, audio_latents,
|
||||
positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_neg,
|
||||
audio_embeddings_pos, audio_embeddings_neg,
|
||||
transformer, sigmas, cfg_scale=cfg_scale,
|
||||
audio_cfg_scale=audio_cfg_scale,
|
||||
cfg_rescale=hq_cfg_rescale, audio_cfg_rescale=1.0,
|
||||
verbose=verbose, video_state=state1,
|
||||
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
||||
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
||||
noise_seed=seed,
|
||||
)
|
||||
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Upsample latents 2x
|
||||
with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"):
|
||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
||||
if not upscaler_files:
|
||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
||||
mx.eval(latents)
|
||||
|
||||
del upsampler
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] Latents upsampled")
|
||||
|
||||
# Merge additional LoRA for stage 2 (additive: 0.25 + 0.25 = 0.5 total)
|
||||
if lora_path is not None:
|
||||
additional_strength = hq_lora_strength_s2 - hq_lora_strength_s1
|
||||
if additional_strength > 0:
|
||||
with console.status(f"[blue]Adjusting LoRA (stage 2, total={hq_lora_strength_s2})...[/]", spinner="dots"):
|
||||
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
|
||||
|
||||
# Stage 2: res_2s refinement at full resolution (no CFG)
|
||||
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
state2 = None
|
||||
if is_i2v and stage2_image_latent is not None:
|
||||
state2 = LatentState(
|
||||
latent=latents,
|
||||
clean_latent=mx.zeros_like(latents),
|
||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||
)
|
||||
conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
||||
state2 = apply_conditioning(state2, [conditioning])
|
||||
|
||||
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
scaled_mask = state2.denoise_mask * noise_scale
|
||||
state2 = LatentState(
|
||||
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||
clean_latent=state2.clean_latent,
|
||||
denoise_mask=state2.denoise_mask,
|
||||
)
|
||||
latents = state2.latent
|
||||
mx.eval(latents)
|
||||
else:
|
||||
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
||||
latents = noise * noise_scale + latents * one_minus_scale
|
||||
mx.eval(latents)
|
||||
|
||||
# Re-noise audio at sigma=0.909375 for joint refinement
|
||||
if audio_latents is not None:
|
||||
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
||||
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Stage 2: res_2s with no CFG (positive embeddings only)
|
||||
stage2_sigmas = mx.array(STAGE_2_SIGMAS, dtype=mx.float32)
|
||||
latents, audio_latents = denoise_res2s_av(
|
||||
latents, audio_latents,
|
||||
positions, audio_positions,
|
||||
video_embeddings_pos, video_embeddings_pos, # both pos (no neg for stage 2)
|
||||
audio_embeddings_pos, audio_embeddings_pos,
|
||||
transformer, stage2_sigmas, cfg_scale=1.0, # no CFG
|
||||
audio_cfg_scale=1.0,
|
||||
cfg_rescale=0.0, verbose=verbose, video_state=state2,
|
||||
noise_seed=seed + 1,
|
||||
)
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
@@ -1857,8 +2369,8 @@ Examples:
|
||||
)
|
||||
|
||||
parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate")
|
||||
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage"],
|
||||
help="Pipeline type: distilled (two-stage, fast), dev (single-stage, CFG), or dev-two-stage (dev + LoRA refinement)")
|
||||
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage", "dev-two-stage-hq"],
|
||||
help="Pipeline type: distilled (fast), dev (CFG), dev-two-stage (dev + LoRA), dev-two-stage-hq (res_2s + LoRA both stages)")
|
||||
parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
|
||||
help="Negative prompt for CFG (dev pipeline only)")
|
||||
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
|
||||
@@ -1895,12 +2407,15 @@ Examples:
|
||||
parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)")
|
||||
parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)")
|
||||
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
|
||||
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
|
||||
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
|
||||
args = parser.parse_args()
|
||||
|
||||
pipeline_map = {
|
||||
"distilled": PipelineType.DISTILLED,
|
||||
"dev": PipelineType.DEV,
|
||||
"dev-two-stage": PipelineType.DEV_TWO_STAGE,
|
||||
"dev-two-stage-hq": PipelineType.DEV_TWO_STAGE_HQ,
|
||||
}
|
||||
pipeline = pipeline_map[args.pipeline]
|
||||
|
||||
@@ -1940,6 +2455,8 @@ Examples:
|
||||
modality_scale=args.modality_scale,
|
||||
lora_path=args.lora_path,
|
||||
lora_strength=args.lora_strength,
|
||||
lora_strength_stage_1=args.lora_strength_stage_1,
|
||||
lora_strength_stage_2=args.lora_strength_stage_2,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user