2567 lines
117 KiB
Python
2567 lines
117 KiB
Python
"""Unified video and audio-video generation pipeline for LTX-2.
|
|
|
|
Supports both distilled (two-stage with upsampling) and dev (single-stage with CFG) pipelines.
|
|
"""
|
|
|
|
import argparse
|
|
import math
|
|
import time
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import mlx.core as mx
|
|
import numpy as np
|
|
from PIL import Image
|
|
from rich.console import Console
|
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
|
|
from rich.panel import Panel
|
|
|
|
# Rich console for styled output
|
|
console = Console()
|
|
|
|
|
|
from mlx_video.models.ltx_2.ltx import LTXModel
|
|
from mlx_video.models.ltx_2.transformer import Modality
|
|
|
|
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding, get_model_path
|
|
from mlx_video.models.ltx_2.video_vae.decoder import VideoDecoder
|
|
from mlx_video.models.ltx_2.video_vae import VideoEncoder
|
|
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig
|
|
from mlx_video.models.ltx_2.upsampler import load_upsampler, upsample_latents
|
|
from mlx_video.models.ltx_2.conditioning import VideoConditionByLatentIndex, apply_conditioning
|
|
from mlx_video.models.ltx_2.conditioning.latent import LatentState, apply_denoise_mask
|
|
|
|
|
|
class PipelineType(Enum):
|
|
"""Pipeline type selector."""
|
|
DISTILLED = "distilled" # Two-stage with upsampling, fixed sigmas, no CFG
|
|
DEV = "dev" # Single-stage, dynamic sigmas, CFG
|
|
DEV_TWO_STAGE = "dev-two-stage" # Two-stage: dev (half res, CFG) + distilled LoRA (full res)
|
|
DEV_TWO_STAGE_HQ = "dev-two-stage-hq" # Two-stage: res_2s sampler, LoRA both stages
|
|
|
|
|
|
# Distilled model sigma schedules
|
|
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
|
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
|
|
|
# Dev model scheduling constants
|
|
BASE_SHIFT_ANCHOR = 1024
|
|
MAX_SHIFT_ANCHOR = 4096
|
|
|
|
# Audio constants
|
|
AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate
|
|
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate
|
|
AUDIO_HOP_LENGTH = 160
|
|
AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4
|
|
AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying
|
|
AUDIO_MEL_BINS = 16
|
|
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
|
|
|
|
# Default negative prompt for CFG (dev pipeline)
|
|
# Matches PyTorch LTX-2 reference DEFAULT_NEGATIVE_PROMPT from constants.py
|
|
DEFAULT_NEGATIVE_PROMPT = (
|
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
|
)
|
|
|
|
|
|
def load_and_merge_lora(
|
|
model: LTXModel,
|
|
lora_path: str,
|
|
strength: float = 1.0,
|
|
) -> None:
|
|
"""Load LoRA weights and merge them into the transformer model in-place.
|
|
|
|
Supports two formats:
|
|
- Raw PyTorch: keys like diffusion_model.{module}.lora_A.weight (needs sanitization)
|
|
- Pre-converted MLX: keys like {module}.lora_A.weight (already sanitized)
|
|
|
|
Merge formula: weight += (lora_B * strength) @ lora_A
|
|
|
|
Args:
|
|
model: The LTXModel transformer to merge into
|
|
lora_path: Path to the LoRA safetensors file or directory containing one
|
|
strength: LoRA strength/coefficient (default 1.0)
|
|
"""
|
|
# Resolve path: local file/dir or HuggingFace repo
|
|
lora_file = Path(lora_path)
|
|
if lora_file.is_file():
|
|
pass # direct file path
|
|
elif lora_file.is_dir():
|
|
# Local directory: find safetensors inside
|
|
candidates = sorted(lora_file.glob("*.safetensors"))
|
|
if not candidates:
|
|
raise FileNotFoundError(f"No .safetensors files found in {lora_path}")
|
|
# Prefer distilled-lora files over full model weights
|
|
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
|
|
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
|
|
console.print(f"[dim]Using LoRA file: {lora_file.name}[/]")
|
|
else:
|
|
# Treat as HuggingFace repo ID
|
|
lora_dir = get_model_path(lora_path)
|
|
candidates = sorted(lora_dir.glob("*.safetensors"))
|
|
if not candidates:
|
|
raise FileNotFoundError(f"No .safetensors files found in {lora_dir}")
|
|
# Prefer distilled-lora files over full model weights
|
|
lora_candidates = [c for c in candidates if "distilled-lora" in c.name]
|
|
lora_file = lora_candidates[0] if lora_candidates else candidates[0]
|
|
console.print(f"[dim]Using LoRA from repo: {lora_path} ({lora_file.name})[/]")
|
|
|
|
# Load LoRA weights
|
|
lora_weights = mx.load(str(lora_file))
|
|
|
|
# Detect format: raw PyTorch has 'diffusion_model.' prefix
|
|
has_prefix = any(k.startswith("diffusion_model.") for k in lora_weights)
|
|
|
|
# Group into A/B pairs by module name
|
|
lora_pairs = {}
|
|
for key in lora_weights:
|
|
module_key = key
|
|
if has_prefix:
|
|
if not key.startswith("diffusion_model."):
|
|
continue
|
|
module_key = key.replace("diffusion_model.", "")
|
|
|
|
if module_key.endswith(".lora_A.weight"):
|
|
base_key = module_key.replace(".lora_A.weight", "")
|
|
lora_pairs.setdefault(base_key, {})["A"] = lora_weights[key]
|
|
elif module_key.endswith(".lora_B.weight"):
|
|
base_key = module_key.replace(".lora_B.weight", "")
|
|
lora_pairs.setdefault(base_key, {})["B"] = lora_weights[key]
|
|
|
|
# Apply key sanitization only for raw PyTorch format
|
|
# Replacements handle both mid-string and end-of-string positions
|
|
# since LoRA base keys end at the module name without trailing dot
|
|
_LORA_KEY_REPLACEMENTS = [
|
|
(".to_out.0", ".to_out"),
|
|
(".ff.net.0.proj", ".ff.proj_in"),
|
|
(".ff.net.2", ".ff.proj_out"),
|
|
(".audio_ff.net.0.proj", ".audio_ff.proj_in"),
|
|
(".audio_ff.net.2", ".audio_ff.proj_out"),
|
|
(".linear_1", ".linear1"),
|
|
(".linear_2", ".linear2"),
|
|
]
|
|
if has_prefix:
|
|
sanitized_pairs = {}
|
|
for key, pair in lora_pairs.items():
|
|
new_key = key
|
|
for old, new in _LORA_KEY_REPLACEMENTS:
|
|
if new_key.endswith(old):
|
|
new_key = new_key[:-len(old)] + new
|
|
else:
|
|
new_key = new_key.replace(old + ".", new + ".")
|
|
sanitized_pairs[new_key] = pair
|
|
else:
|
|
sanitized_pairs = lora_pairs
|
|
|
|
# Get current model weights as a flat dict (references, not copies)
|
|
def flatten_params(params, prefix=""):
|
|
flat = {}
|
|
for k, v in params.items():
|
|
full_key = f"{prefix}.{k}" if prefix else k
|
|
if isinstance(v, dict):
|
|
flat.update(flatten_params(v, full_key))
|
|
else:
|
|
flat[full_key] = v
|
|
return flat
|
|
|
|
flat_weights = flatten_params(dict(model.parameters()))
|
|
|
|
# Merge LoRA deltas in batches to avoid doubling memory
|
|
merged_count = 0
|
|
batch = []
|
|
batch_size = 100 # merge 100 weights at a time, then eval to free intermediates
|
|
|
|
for module_key, pair in sanitized_pairs.items():
|
|
if "A" not in pair or "B" not in pair:
|
|
continue
|
|
|
|
weight_key = f"{module_key}.weight"
|
|
if weight_key not in flat_weights:
|
|
continue
|
|
|
|
lora_a = pair["A"].astype(mx.float32) # (rank, in_features)
|
|
lora_b = pair["B"].astype(mx.float32) # (out_features, rank)
|
|
|
|
# delta = (lora_B * strength) @ lora_A
|
|
delta = (lora_b * strength) @ lora_a
|
|
|
|
base_weight = flat_weights.pop(weight_key)
|
|
merged_weight = (base_weight.astype(mx.float32) + delta).astype(base_weight.dtype)
|
|
batch.append((weight_key, merged_weight))
|
|
del base_weight
|
|
merged_count += 1
|
|
|
|
if len(batch) >= batch_size:
|
|
model.load_weights(batch, strict=False)
|
|
mx.eval(model.parameters())
|
|
batch.clear()
|
|
|
|
if batch:
|
|
model.load_weights(batch, strict=False)
|
|
mx.eval(model.parameters())
|
|
batch.clear()
|
|
|
|
del flat_weights, lora_weights
|
|
mx.clear_cache()
|
|
console.print(f"[green]✓[/] Merged {merged_count} LoRA pairs (strength={strength})")
|
|
|
|
|
|
def cfg_delta(cond: mx.array, uncond: mx.array, scale: float) -> mx.array:
|
|
"""Compute CFG delta for classifier-free guidance.
|
|
|
|
Args:
|
|
cond: Conditional prediction
|
|
uncond: Unconditional prediction
|
|
scale: CFG guidance scale
|
|
|
|
Returns:
|
|
Delta to add to unconditional for CFG: (scale - 1) * (cond - uncond)
|
|
"""
|
|
return (scale - 1.0) * (cond - uncond)
|
|
|
|
|
|
def apg_delta(
|
|
cond: mx.array,
|
|
uncond: mx.array,
|
|
scale: float,
|
|
eta: float = 1.0,
|
|
norm_threshold: float = 0.0,
|
|
) -> mx.array:
|
|
"""Compute APG (Adaptive Projected Guidance) delta.
|
|
|
|
Decomposes guidance into parallel and orthogonal components relative to
|
|
the conditional prediction, providing more stable guidance for I2V.
|
|
|
|
Based on: https://arxiv.org/abs/2407.12173
|
|
|
|
Args:
|
|
cond: Conditional prediction (x0_pos)
|
|
uncond: Unconditional prediction (x0_neg)
|
|
scale: Guidance strength (same as CFG scale)
|
|
eta: Weight for parallel component (1.0 = keep full parallel)
|
|
norm_threshold: Clamp guidance norm to this value (0 = no clamping)
|
|
|
|
Returns:
|
|
Delta to add to unconditional for APG guidance
|
|
"""
|
|
guidance = cond - uncond
|
|
|
|
# Optionally clamp guidance norm for stability
|
|
if norm_threshold > 0:
|
|
guidance_norm = mx.sqrt(mx.sum(guidance ** 2, axis=(-1, -2, -3), keepdims=True) + 1e-8)
|
|
scale_factor = mx.minimum(mx.ones_like(guidance_norm), norm_threshold / guidance_norm)
|
|
guidance = guidance * scale_factor
|
|
|
|
# Project guidance onto cond direction
|
|
batch_size = cond.shape[0]
|
|
cond_flat = mx.reshape(cond, (batch_size, -1))
|
|
guidance_flat = mx.reshape(guidance, (batch_size, -1))
|
|
|
|
# Projection coefficient: (guidance · cond) / (cond · cond)
|
|
dot_product = mx.sum(guidance_flat * cond_flat, axis=1, keepdims=True)
|
|
squared_norm = mx.sum(cond_flat ** 2, axis=1, keepdims=True) + 1e-8
|
|
proj_coeff = dot_product / squared_norm
|
|
|
|
# Reshape back and compute parallel/orthogonal components
|
|
proj_coeff = mx.reshape(proj_coeff, (batch_size,) + (1,) * (cond.ndim - 1))
|
|
g_parallel = proj_coeff * cond
|
|
g_orth = guidance - g_parallel
|
|
|
|
# Combine with eta weighting parallel component
|
|
g_apg = g_parallel * eta + g_orth
|
|
|
|
return g_apg * (scale - 1.0)
|
|
|
|
|
|
def ltx2_scheduler(
|
|
steps: int,
|
|
num_tokens: Optional[int] = None,
|
|
max_shift: float = 2.05,
|
|
base_shift: float = 0.95,
|
|
stretch: bool = True,
|
|
terminal: float = 0.1,
|
|
) -> mx.array:
|
|
"""LTX-2 scheduler for sigma generation (dev model).
|
|
|
|
Generates a sigma schedule with token-count-dependent shifting and optional
|
|
stretching to a terminal value.
|
|
|
|
Args:
|
|
steps: Number of inference steps
|
|
num_tokens: Number of latent tokens (F*H*W). If None, uses MAX_SHIFT_ANCHOR
|
|
max_shift: Maximum shift factor
|
|
base_shift: Base shift factor
|
|
stretch: Whether to stretch sigmas to terminal value
|
|
terminal: Terminal sigma value for stretching
|
|
|
|
Returns:
|
|
Array of sigma values of shape (steps + 1,)
|
|
"""
|
|
tokens = num_tokens if num_tokens is not None else MAX_SHIFT_ANCHOR
|
|
sigmas = np.linspace(1.0, 0.0, steps + 1)
|
|
|
|
# Compute shift based on token count
|
|
x1 = BASE_SHIFT_ANCHOR
|
|
x2 = MAX_SHIFT_ANCHOR
|
|
mm = (max_shift - base_shift) / (x2 - x1)
|
|
b = base_shift - mm * x1
|
|
sigma_shift = tokens * mm + b
|
|
|
|
# Apply shift transformation
|
|
power = 1
|
|
with np.errstate(divide='ignore', invalid='ignore'):
|
|
sigmas = np.where(
|
|
sigmas != 0,
|
|
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
|
|
0,
|
|
)
|
|
|
|
# Stretch sigmas to terminal value
|
|
if stretch:
|
|
non_zero_mask = sigmas != 0
|
|
non_zero_sigmas = sigmas[non_zero_mask]
|
|
one_minus_z = 1.0 - non_zero_sigmas
|
|
scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
|
stretched = 1.0 - (one_minus_z / scale_factor)
|
|
sigmas[non_zero_mask] = stretched
|
|
|
|
return mx.array(sigmas, dtype=mx.float32)
|
|
|
|
|
|
def create_position_grid(
|
|
batch_size: int,
|
|
num_frames: int,
|
|
height: int,
|
|
width: int,
|
|
temporal_scale: int = 8,
|
|
spatial_scale: int = 32,
|
|
fps: float = 24.0,
|
|
causal_fix: bool = True,
|
|
) -> mx.array:
|
|
"""Create position grid for RoPE in pixel space.
|
|
|
|
Args:
|
|
batch_size: Batch size
|
|
num_frames: Number of frames (latent)
|
|
height: Height (latent)
|
|
width: Width (latent)
|
|
temporal_scale: VAE temporal scale factor (default 8)
|
|
spatial_scale: VAE spatial scale factor (default 32)
|
|
fps: Frames per second (default 24.0)
|
|
causal_fix: Apply causal fix for first frame (default True)
|
|
|
|
Returns:
|
|
Position grid of shape (B, 3, num_patches, 2) in pixel space
|
|
where dim 2 is [start, end) bounds for each patch
|
|
"""
|
|
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
|
|
|
|
t_coords = np.arange(0, num_frames, patch_size_t)
|
|
h_coords = np.arange(0, height, patch_size_h)
|
|
w_coords = np.arange(0, width, patch_size_w)
|
|
|
|
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
|
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
|
|
|
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
|
|
patch_ends = patch_starts + patch_size_delta
|
|
|
|
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
|
|
num_patches = num_frames * height * width
|
|
latent_coords = latent_coords.reshape(3, num_patches, 2)
|
|
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
|
|
|
|
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
|
|
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
|
|
|
|
if causal_fix:
|
|
pixel_coords[:, 0, :, :] = np.clip(
|
|
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
|
|
a_min=0,
|
|
a_max=None
|
|
)
|
|
|
|
# Divide temporal coords by fps
|
|
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
|
|
|
# Cast entire position grid through bfloat16 to match PyTorch's behavior.
|
|
# PyTorch does: positions = positions.to(bfloat16) on ALL coordinates before
|
|
# passing to the transformer/RoPE. This quantization is what the model was
|
|
# trained with, so we must replicate it for numerical fidelity.
|
|
positions_bf16 = mx.array(pixel_coords, dtype=mx.bfloat16)
|
|
mx.eval(positions_bf16)
|
|
return positions_bf16.astype(mx.float32)
|
|
|
|
|
|
def create_audio_position_grid(
|
|
batch_size: int,
|
|
audio_frames: int,
|
|
sample_rate: int = AUDIO_LATENT_SAMPLE_RATE,
|
|
hop_length: int = AUDIO_HOP_LENGTH,
|
|
downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR,
|
|
is_causal: bool = True,
|
|
) -> mx.array:
|
|
"""Create temporal position grid for audio RoPE."""
|
|
def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray:
|
|
latent_frame = np.arange(start_idx, end_idx, dtype=np.float32)
|
|
mel_frame = latent_frame * downsample_factor
|
|
if is_causal:
|
|
mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None)
|
|
return mel_frame * hop_length / sample_rate
|
|
|
|
start_times = get_audio_latent_time_in_sec(0, audio_frames)
|
|
end_times = get_audio_latent_time_in_sec(1, audio_frames + 1)
|
|
|
|
positions = np.stack([start_times, end_times], axis=-1)
|
|
positions = positions[np.newaxis, np.newaxis, :, :]
|
|
positions = np.tile(positions, (batch_size, 1, 1, 1))
|
|
|
|
# Cast through bfloat16 to match PyTorch's precision behavior
|
|
positions_bf16 = mx.array(positions, dtype=mx.bfloat16)
|
|
mx.eval(positions_bf16)
|
|
return positions_bf16.astype(mx.float32)
|
|
|
|
|
|
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
|
|
"""Compute number of audio latent frames given video duration."""
|
|
duration = num_video_frames / fps
|
|
return round(duration * AUDIO_LATENTS_PER_SECOND)
|
|
|
|
|
|
# =============================================================================
|
|
# Distilled Pipeline Denoising (no CFG, fixed sigmas)
|
|
# =============================================================================
|
|
|
|
def denoise_distilled(
|
|
latents: mx.array,
|
|
positions: mx.array,
|
|
text_embeddings: mx.array,
|
|
transformer: LTXModel,
|
|
sigmas: list,
|
|
verbose: bool = True,
|
|
state: Optional[LatentState] = None,
|
|
audio_latents: Optional[mx.array] = None,
|
|
audio_positions: Optional[mx.array] = None,
|
|
audio_embeddings: Optional[mx.array] = None,
|
|
audio_frozen: bool = False,
|
|
) -> tuple[mx.array, Optional[mx.array]]:
|
|
"""Run denoising loop for distilled pipeline (no CFG)."""
|
|
dtype = latents.dtype
|
|
enable_audio = audio_latents is not None
|
|
|
|
if state is not None:
|
|
latents = state.latent
|
|
|
|
# Keep latents in float32 throughout to avoid quantization noise accumulation.
|
|
latents = latents.astype(mx.float32)
|
|
if enable_audio:
|
|
audio_latents = audio_latents.astype(mx.float32)
|
|
|
|
desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]"
|
|
num_steps = len(sigmas) - 1
|
|
|
|
with Progress(
|
|
SpinnerColumn(),
|
|
TextColumn("[progress.description]{task.description}"),
|
|
BarColumn(),
|
|
TaskProgressColumn(),
|
|
TimeRemainingColumn(),
|
|
console=console,
|
|
disable=not verbose,
|
|
) as progress:
|
|
task = progress.add_task(desc, total=num_steps)
|
|
|
|
for i in range(num_steps):
|
|
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
|
|
|
b, c, f, h, w = latents.shape
|
|
num_tokens = f * h * w
|
|
# Cast to model dtype for transformer input
|
|
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
|
|
|
if state is not None:
|
|
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
|
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
|
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
|
|
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
|
else:
|
|
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
|
|
|
video_modality = Modality(
|
|
latent=latents_flat,
|
|
timesteps=timesteps,
|
|
positions=positions,
|
|
context=text_embeddings,
|
|
context_mask=None,
|
|
enabled=True,
|
|
sigma=mx.full((b,), sigma, dtype=dtype),
|
|
)
|
|
|
|
audio_modality = None
|
|
if enable_audio:
|
|
ab, ac, at, af = audio_latents.shape
|
|
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
|
|
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
|
|
|
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
|
|
a_ts = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
|
a_sig = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
|
audio_modality = Modality(
|
|
latent=audio_flat,
|
|
timesteps=a_ts,
|
|
positions=audio_positions,
|
|
context=audio_embeddings,
|
|
context_mask=None,
|
|
enabled=True,
|
|
sigma=a_sig,
|
|
)
|
|
|
|
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
|
|
mx.eval(velocity)
|
|
if audio_velocity is not None:
|
|
mx.eval(audio_velocity)
|
|
|
|
# Compute denoised (x0) using per-token timesteps in float32
|
|
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
|
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
|
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
|
|
x0_f32 = latents_flat_f32 - timesteps_f32 * velocity.astype(mx.float32)
|
|
denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w))
|
|
|
|
audio_denoised = None
|
|
if enable_audio and audio_velocity is not None and not audio_frozen:
|
|
ab, ac, at, af = audio_latents.shape
|
|
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
|
|
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
|
|
audio_denoised = audio_latents - sigma_f32 * audio_velocity.astype(mx.float32)
|
|
|
|
if state is not None:
|
|
denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask)
|
|
|
|
mx.eval(denoised)
|
|
if audio_denoised is not None:
|
|
mx.eval(audio_denoised)
|
|
|
|
# Euler step in float32
|
|
if sigma_next > 0:
|
|
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
|
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
|
|
if enable_audio and audio_denoised is not None and not audio_frozen:
|
|
audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32
|
|
else:
|
|
latents = denoised
|
|
if enable_audio and audio_denoised is not None and not audio_frozen:
|
|
audio_latents = audio_denoised
|
|
|
|
mx.eval(latents)
|
|
if enable_audio:
|
|
mx.eval(audio_latents)
|
|
|
|
progress.advance(task)
|
|
|
|
return latents.astype(dtype), audio_latents.astype(dtype) if enable_audio else None
|
|
|
|
|
|
# =============================================================================
|
|
# Dev Pipeline Denoising (with CFG, dynamic sigmas)
|
|
# =============================================================================
|
|
|
|
def denoise_dev(
|
|
latents: mx.array,
|
|
positions: mx.array,
|
|
text_embeddings_pos: mx.array,
|
|
text_embeddings_neg: mx.array,
|
|
transformer: LTXModel,
|
|
sigmas: mx.array,
|
|
cfg_scale: float = 4.0,
|
|
cfg_rescale: float = 0.0,
|
|
verbose: bool = True,
|
|
state: Optional[LatentState] = None,
|
|
use_apg: bool = False,
|
|
apg_eta: float = 1.0,
|
|
apg_norm_threshold: float = 0.0,
|
|
stg_scale: float = 0.0,
|
|
stg_blocks: Optional[list] = None,
|
|
) -> mx.array:
|
|
"""Run denoising loop for dev pipeline with CFG/APG and optional STG guidance.
|
|
|
|
Args:
|
|
cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction
|
|
variance relative to conditional prediction to reduce over-saturation.
|
|
PyTorch default is 0.7. Set to 0.0 to disable.
|
|
use_apg: Use Adaptive Projected Guidance instead of standard CFG.
|
|
APG decomposes guidance into parallel/orthogonal components
|
|
for more stable I2V generation.
|
|
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
|
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
|
stg_scale: STG (Spatiotemporal Guidance) scale. 0.0 = disabled.
|
|
stg_blocks: Transformer block indices for STG perturbation.
|
|
"""
|
|
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
|
|
|
dtype = latents.dtype
|
|
if state is not None:
|
|
latents = state.latent
|
|
|
|
# Keep latents in float32 throughout the denoising loop to avoid
|
|
# quantization noise accumulation over many steps.
|
|
# Model input is cast to model dtype; all denoising math stays in float32.
|
|
latents = latents.astype(mx.float32)
|
|
|
|
sigmas_list = sigmas.tolist()
|
|
use_cfg = cfg_scale != 1.0
|
|
use_stg = stg_scale != 0.0 and stg_blocks is not None
|
|
num_steps = len(sigmas_list) - 1
|
|
|
|
# Precompute RoPE once
|
|
precomputed_rope = precompute_freqs_cis(
|
|
positions,
|
|
dim=transformer.inner_dim,
|
|
theta=transformer.positional_embedding_theta,
|
|
max_pos=transformer.positional_embedding_max_pos,
|
|
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
|
num_attention_heads=transformer.num_attention_heads,
|
|
rope_type=transformer.rope_type,
|
|
double_precision=transformer.config.double_precision_rope,
|
|
)
|
|
mx.eval(precomputed_rope)
|
|
|
|
with Progress(
|
|
SpinnerColumn(),
|
|
TextColumn("[progress.description]{task.description}"),
|
|
BarColumn(),
|
|
TaskProgressColumn(),
|
|
TimeRemainingColumn(),
|
|
console=console,
|
|
disable=not verbose,
|
|
) as progress:
|
|
passes = ["CFG"] if use_cfg else []
|
|
if use_stg: passes.append("STG")
|
|
label = "+".join(passes) if passes else "uncond"
|
|
task = progress.add_task(f"[cyan]Denoising ({label})[/]", total=num_steps)
|
|
|
|
for i in range(num_steps):
|
|
sigma = sigmas_list[i]
|
|
sigma_next = sigmas_list[i + 1]
|
|
|
|
b, c, f, h, w = latents.shape
|
|
num_tokens = f * h * w
|
|
# Cast to model dtype for transformer input
|
|
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
|
|
|
if state is not None:
|
|
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
|
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
|
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
|
|
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
|
else:
|
|
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
|
|
|
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
|
|
|
# Positive conditioning pass
|
|
video_modality_pos = Modality(
|
|
latent=latents_flat,
|
|
timesteps=timesteps,
|
|
positions=positions,
|
|
context=text_embeddings_pos,
|
|
context_mask=None,
|
|
enabled=True,
|
|
positional_embeddings=precomputed_rope,
|
|
sigma=sigma_array,
|
|
)
|
|
velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
|
|
|
|
# Convert velocity to x0 (denoised) using per-token timesteps
|
|
# Matches PyTorch's X0Model: x0 = latent - timestep * velocity
|
|
# For conditioned tokens (timestep=0): x0 = latent (correct regardless of velocity)
|
|
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
|
|
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
|
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
|
|
x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32)
|
|
|
|
# Start with positive prediction
|
|
x0_guided_f32 = x0_pos_f32
|
|
|
|
if use_cfg:
|
|
# Negative conditioning pass
|
|
video_modality_neg = Modality(
|
|
latent=latents_flat,
|
|
timesteps=timesteps,
|
|
positions=positions,
|
|
context=text_embeddings_neg,
|
|
context_mask=None,
|
|
enabled=True,
|
|
positional_embeddings=precomputed_rope,
|
|
sigma=sigma_array,
|
|
)
|
|
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
|
|
|
|
# Convert negative velocity to x0 using per-token timesteps
|
|
x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32)
|
|
|
|
# Apply guidance to x0 predictions
|
|
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0
|
|
if use_apg:
|
|
# APG: decompose into parallel/orthogonal components for stability
|
|
x0_guided_f32 = x0_pos_f32 + apg_delta(
|
|
x0_pos_f32, x0_neg_f32, cfg_scale,
|
|
eta=apg_eta, norm_threshold=apg_norm_threshold
|
|
)
|
|
else:
|
|
# Standard CFG
|
|
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
|
|
|
# STG pass: skip self-attention at specified blocks
|
|
if use_stg:
|
|
velocity_ptb, _ = transformer(
|
|
video=video_modality_pos, audio=None,
|
|
stg_video_blocks=stg_blocks,
|
|
)
|
|
mx.eval(velocity_ptb)
|
|
|
|
x0_ptb_f32 = latents_flat_f32 - timesteps_f32 * velocity_ptb.astype(mx.float32)
|
|
x0_guided_f32 = x0_guided_f32 + stg_scale * (x0_pos_f32 - x0_ptb_f32)
|
|
|
|
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation)
|
|
# factor = rescale * (cond_std / pred_std) + (1 - rescale)
|
|
# pred = pred * factor
|
|
if cfg_rescale > 0.0 and (use_cfg or use_stg):
|
|
v_factor = x0_pos_f32.std() / (x0_guided_f32.std() + 1e-8)
|
|
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
|
x0_guided_f32 = x0_guided_f32 * v_factor
|
|
|
|
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
|
|
denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
|
|
|
|
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
|
|
|
if state is not None:
|
|
denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask)
|
|
|
|
# Euler step in float32 (latents stay in float32)
|
|
if sigma_next > 0:
|
|
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
|
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
|
|
else:
|
|
latents = denoised
|
|
|
|
mx.eval(latents)
|
|
progress.advance(task)
|
|
|
|
return latents.astype(dtype)
|
|
|
|
|
|
def denoise_dev_av(
|
|
video_latents: mx.array,
|
|
audio_latents: mx.array,
|
|
video_positions: mx.array,
|
|
audio_positions: mx.array,
|
|
video_embeddings_pos: mx.array,
|
|
video_embeddings_neg: mx.array,
|
|
audio_embeddings_pos: mx.array,
|
|
audio_embeddings_neg: mx.array,
|
|
transformer: LTXModel,
|
|
sigmas: mx.array,
|
|
cfg_scale: float = 4.0,
|
|
audio_cfg_scale: float = 7.0,
|
|
cfg_rescale: float = 0.0,
|
|
verbose: bool = True,
|
|
video_state: Optional[LatentState] = None,
|
|
use_apg: bool = False,
|
|
apg_eta: float = 1.0,
|
|
apg_norm_threshold: float = 0.0,
|
|
stg_scale: float = 0.0,
|
|
stg_video_blocks: Optional[list] = None,
|
|
stg_audio_blocks: Optional[list] = None,
|
|
modality_scale: float = 1.0,
|
|
audio_frozen: bool = False,
|
|
) -> tuple[mx.array, mx.array]:
|
|
"""Run denoising loop for dev pipeline with CFG/APG, STG, modality guidance, and audio.
|
|
|
|
Args:
|
|
audio_cfg_scale: Separate CFG scale for audio (PyTorch default: 7.0).
|
|
cfg_rescale: Rescale factor for CFG (0.0-1.0). Normalizes guided prediction
|
|
variance to reduce artifacts. Default 0.0 means no rescaling.
|
|
use_apg: Use Adaptive Projected Guidance instead of standard CFG for video.
|
|
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
|
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
|
stg_scale: STG (Spatiotemporal Guidance) scale. 0.0 = disabled.
|
|
stg_video_blocks: Transformer block indices for video STG perturbation.
|
|
stg_audio_blocks: Transformer block indices for audio STG perturbation.
|
|
modality_scale: Cross-modal guidance scale. 1.0 = disabled.
|
|
"""
|
|
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
|
|
|
dtype = video_latents.dtype
|
|
if video_state is not None:
|
|
video_latents = video_state.latent
|
|
|
|
# Keep latents in float32 throughout the denoising loop for precision.
|
|
video_latents = video_latents.astype(mx.float32)
|
|
audio_latents = audio_latents.astype(mx.float32)
|
|
|
|
sigmas_list = sigmas.tolist()
|
|
use_cfg = cfg_scale != 1.0
|
|
use_stg = stg_scale != 0.0 and stg_video_blocks is not None
|
|
use_modality = modality_scale != 1.0
|
|
num_steps = len(sigmas_list) - 1
|
|
|
|
# Precompute video RoPE
|
|
precomputed_video_rope = precompute_freqs_cis(
|
|
video_positions,
|
|
dim=transformer.inner_dim,
|
|
theta=transformer.positional_embedding_theta,
|
|
max_pos=transformer.positional_embedding_max_pos,
|
|
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
|
num_attention_heads=transformer.num_attention_heads,
|
|
rope_type=transformer.rope_type,
|
|
double_precision=transformer.config.double_precision_rope,
|
|
)
|
|
|
|
# Precompute audio RoPE
|
|
precomputed_audio_rope = precompute_freqs_cis(
|
|
audio_positions,
|
|
dim=transformer.audio_inner_dim,
|
|
theta=transformer.positional_embedding_theta,
|
|
max_pos=transformer.audio_positional_embedding_max_pos,
|
|
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
|
num_attention_heads=transformer.audio_num_attention_heads,
|
|
rope_type=transformer.rope_type,
|
|
double_precision=transformer.config.double_precision_rope,
|
|
)
|
|
mx.eval(precomputed_video_rope, precomputed_audio_rope)
|
|
|
|
with Progress(
|
|
SpinnerColumn(),
|
|
TextColumn("[progress.description]{task.description}"),
|
|
BarColumn(),
|
|
TaskProgressColumn(),
|
|
TimeRemainingColumn(),
|
|
console=console,
|
|
disable=not verbose,
|
|
) as progress:
|
|
passes = ["CFG"] if use_cfg else []
|
|
if use_stg: passes.append("STG")
|
|
if use_modality: passes.append("Mod")
|
|
label = "+".join(passes) if passes else "uncond"
|
|
task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=num_steps)
|
|
|
|
for i in range(num_steps):
|
|
sigma = sigmas_list[i]
|
|
sigma_next = sigmas_list[i + 1]
|
|
|
|
# Flatten video latents (cast to model dtype for transformer input)
|
|
b, c, f, h, w = video_latents.shape
|
|
num_video_tokens = f * h * w
|
|
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
|
|
|
# Flatten audio latents (cast to model dtype for transformer input)
|
|
ab, ac, at, af = audio_latents.shape
|
|
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
|
|
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
|
|
|
# Compute timesteps
|
|
if video_state is not None:
|
|
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
|
|
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
|
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
|
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
|
else:
|
|
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
|
|
|
# A2V: frozen audio uses timesteps=0 (tells model audio is clean)
|
|
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
|
|
|
# Positive conditioning pass
|
|
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
|
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
|
video_modality_pos = Modality(
|
|
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
|
context=video_embeddings_pos, context_mask=None, enabled=True,
|
|
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
|
)
|
|
audio_modality_pos = Modality(
|
|
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
|
context=audio_embeddings_pos, context_mask=None, enabled=True,
|
|
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
|
)
|
|
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
|
mx.eval(video_vel_pos, audio_vel_pos)
|
|
|
|
# Convert velocity to denoised (x0) using per-token timesteps
|
|
# This matches PyTorch's X0ModelWrapper: x0 = latent - timestep * velocity
|
|
# For conditioned tokens (timestep=0): x0 = latent (velocity is irrelevant)
|
|
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
|
|
video_flat_f32 = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
|
|
audio_flat_f32 = mx.reshape(mx.transpose(audio_latents, (0, 2, 1, 3)), (ab, at, ac * af))
|
|
video_timesteps_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
|
|
audio_timesteps_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1)
|
|
|
|
video_x0_pos_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_pos.astype(mx.float32)
|
|
audio_x0_pos_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_pos.astype(mx.float32)
|
|
|
|
# Start with positive prediction
|
|
video_x0_guided_f32 = video_x0_pos_f32
|
|
audio_x0_guided_f32 = audio_x0_pos_f32
|
|
|
|
# Pass 2: CFG (negative conditioning)
|
|
if use_cfg:
|
|
video_modality_neg = Modality(
|
|
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
|
context=video_embeddings_neg, context_mask=None, enabled=True,
|
|
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
|
)
|
|
audio_modality_neg = Modality(
|
|
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
|
context=audio_embeddings_neg, context_mask=None, enabled=True,
|
|
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
|
)
|
|
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
|
mx.eval(video_vel_neg, audio_vel_neg)
|
|
|
|
video_x0_neg_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_neg.astype(mx.float32)
|
|
audio_x0_neg_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_neg.astype(mx.float32)
|
|
|
|
if use_apg:
|
|
video_x0_guided_f32 = video_x0_pos_f32 + apg_delta(
|
|
video_x0_pos_f32, video_x0_neg_f32, cfg_scale,
|
|
eta=apg_eta, norm_threshold=apg_norm_threshold
|
|
)
|
|
else:
|
|
video_x0_guided_f32 = video_x0_pos_f32 + (cfg_scale - 1.0) * (video_x0_pos_f32 - video_x0_neg_f32)
|
|
audio_x0_guided_f32 = audio_x0_pos_f32 + (audio_cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
|
|
|
# Pass 3: STG (self-attention perturbation at specified blocks)
|
|
if use_stg:
|
|
video_vel_ptb, audio_vel_ptb = transformer(
|
|
video=video_modality_pos, audio=audio_modality_pos,
|
|
stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks,
|
|
)
|
|
mx.eval(video_vel_ptb, audio_vel_ptb)
|
|
|
|
video_x0_ptb_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_ptb.astype(mx.float32)
|
|
audio_x0_ptb_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_ptb.astype(mx.float32)
|
|
|
|
video_x0_guided_f32 = video_x0_guided_f32 + stg_scale * (video_x0_pos_f32 - video_x0_ptb_f32)
|
|
audio_x0_guided_f32 = audio_x0_guided_f32 + stg_scale * (audio_x0_pos_f32 - audio_x0_ptb_f32)
|
|
|
|
# Pass 4: Modality isolation (skip all cross-modal attention)
|
|
if use_modality:
|
|
video_vel_iso, audio_vel_iso = transformer(
|
|
video=video_modality_pos, audio=audio_modality_pos,
|
|
skip_cross_modal=True,
|
|
)
|
|
mx.eval(video_vel_iso, audio_vel_iso)
|
|
|
|
video_x0_iso_f32 = video_flat_f32 - video_timesteps_f32 * video_vel_iso.astype(mx.float32)
|
|
audio_x0_iso_f32 = audio_flat_f32 - audio_timesteps_f32 * audio_vel_iso.astype(mx.float32)
|
|
|
|
video_x0_guided_f32 = video_x0_guided_f32 + (modality_scale - 1.0) * (video_x0_pos_f32 - video_x0_iso_f32)
|
|
audio_x0_guided_f32 = audio_x0_guided_f32 + (modality_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_iso_f32)
|
|
|
|
# Apply CFG rescale (std-ratio rescaling to reduce over-saturation)
|
|
if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
|
|
v_factor = video_x0_pos_f32.std() / (video_x0_guided_f32.std() + 1e-8)
|
|
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
|
video_x0_guided_f32 = video_x0_guided_f32 * v_factor
|
|
a_factor = audio_x0_pos_f32.std() / (audio_x0_guided_f32.std() + 1e-8)
|
|
a_factor = cfg_rescale * a_factor + (1.0 - cfg_rescale)
|
|
audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor
|
|
|
|
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
|
|
video_denoised_f32 = mx.reshape(mx.transpose(video_x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
|
|
audio_denoised_f32 = mx.reshape(audio_x0_guided_f32, (ab, at, ac, af))
|
|
audio_denoised_f32 = mx.transpose(audio_denoised_f32, (0, 2, 1, 3))
|
|
|
|
# Post-process: blend denoised with clean latent using mask
|
|
# Matches PyTorch's post_process_latent: denoised * mask + clean * (1 - mask)
|
|
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
|
|
|
if video_state is not None:
|
|
clean_f32 = video_state.clean_latent.astype(mx.float32)
|
|
mask_f32 = video_state.denoise_mask.astype(mx.float32)
|
|
video_denoised_f32 = video_denoised_f32 * mask_f32 + clean_f32 * (1.0 - mask_f32)
|
|
|
|
mx.eval(video_denoised_f32, audio_denoised_f32)
|
|
|
|
# Euler step: sample + velocity * dt (float32)
|
|
if sigma_next > 0:
|
|
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
|
dt_f32 = sigma_next_f32 - sigma_f32
|
|
|
|
video_velocity_f32 = (video_latents - video_denoised_f32) / sigma_f32
|
|
video_latents = video_latents + video_velocity_f32 * dt_f32
|
|
|
|
if not audio_frozen:
|
|
audio_velocity_f32 = (audio_latents - audio_denoised_f32) / sigma_f32
|
|
audio_latents = audio_latents + audio_velocity_f32 * dt_f32
|
|
else:
|
|
video_latents = video_denoised_f32
|
|
if not audio_frozen:
|
|
audio_latents = audio_denoised_f32
|
|
|
|
mx.eval(video_latents, audio_latents)
|
|
progress.advance(task)
|
|
|
|
return video_latents, audio_latents
|
|
|
|
|
|
def denoise_res2s_av(
|
|
video_latents: mx.array,
|
|
audio_latents: mx.array,
|
|
video_positions: mx.array,
|
|
audio_positions: mx.array,
|
|
video_embeddings_pos: mx.array,
|
|
video_embeddings_neg: mx.array,
|
|
audio_embeddings_pos: mx.array,
|
|
audio_embeddings_neg: mx.array,
|
|
transformer: LTXModel,
|
|
sigmas: mx.array,
|
|
cfg_scale: float = 3.0,
|
|
audio_cfg_scale: float = 7.0,
|
|
cfg_rescale: float = 0.45,
|
|
audio_cfg_rescale: Optional[float] = None,
|
|
verbose: bool = True,
|
|
video_state: Optional[LatentState] = None,
|
|
stg_scale: float = 0.0,
|
|
stg_video_blocks: Optional[list] = None,
|
|
stg_audio_blocks: Optional[list] = None,
|
|
modality_scale: float = 1.0,
|
|
noise_seed: int = 42,
|
|
bongmath: bool = True,
|
|
bongmath_max_iter: int = 100,
|
|
audio_frozen: bool = False,
|
|
) -> tuple[mx.array, mx.array]:
|
|
"""Run res_2s second-order denoising loop with CFG/STG/modality guidance.
|
|
|
|
Two model evaluations per step (current point + midpoint), with SDE noise
|
|
injection and optional bong iteration for anchor refinement.
|
|
|
|
Args:
|
|
audio_cfg_rescale: Separate rescale for audio. If None, uses cfg_rescale.
|
|
noise_seed: Seed for SDE noise generators.
|
|
bongmath: Enable iterative anchor refinement.
|
|
bongmath_max_iter: Max bong iterations per step.
|
|
"""
|
|
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
|
from mlx_video.models.ltx_2.samplers import get_res2s_coefficients, sde_noise_step, get_new_noise
|
|
|
|
if audio_cfg_rescale is None:
|
|
audio_cfg_rescale = cfg_rescale
|
|
|
|
dtype = video_latents.dtype
|
|
if video_state is not None:
|
|
video_latents = video_state.latent
|
|
|
|
video_latents = video_latents.astype(mx.float32)
|
|
audio_latents = audio_latents.astype(mx.float32)
|
|
|
|
sigmas_list = sigmas.tolist()
|
|
use_cfg = cfg_scale != 1.0
|
|
use_stg = stg_scale != 0.0 and stg_video_blocks is not None
|
|
use_modality = modality_scale != 1.0
|
|
n_full_steps = len(sigmas_list) - 1
|
|
|
|
# Pad sigmas if last is 0 (avoid division by zero in RK steps)
|
|
if sigmas_list[-1] == 0:
|
|
sigmas_list = sigmas_list[:-1] + [0.0011, 0.0]
|
|
|
|
# Compute step sizes in log-space for the main loop steps only.
|
|
# After padding, sigmas_list may have an extra [0.0011, 0.0] tail;
|
|
# we only need hs for the n_full_steps pairs the loop actually uses.
|
|
hs = [-math.log(sigmas_list[i + 1] / sigmas_list[i]) for i in range(n_full_steps)]
|
|
|
|
# Precompute RoPE
|
|
precomputed_video_rope = precompute_freqs_cis(
|
|
video_positions,
|
|
dim=transformer.inner_dim,
|
|
theta=transformer.positional_embedding_theta,
|
|
max_pos=transformer.positional_embedding_max_pos,
|
|
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
|
num_attention_heads=transformer.num_attention_heads,
|
|
rope_type=transformer.rope_type,
|
|
double_precision=transformer.config.double_precision_rope,
|
|
)
|
|
precomputed_audio_rope = precompute_freqs_cis(
|
|
audio_positions,
|
|
dim=transformer.audio_inner_dim,
|
|
theta=transformer.positional_embedding_theta,
|
|
max_pos=transformer.audio_positional_embedding_max_pos,
|
|
use_middle_indices_grid=transformer.use_middle_indices_grid,
|
|
num_attention_heads=transformer.audio_num_attention_heads,
|
|
rope_type=transformer.rope_type,
|
|
double_precision=transformer.config.double_precision_rope,
|
|
)
|
|
mx.eval(precomputed_video_rope, precomputed_audio_rope)
|
|
|
|
phi_cache = {}
|
|
c2 = 0.5
|
|
|
|
# Noise key management: step noise and substep noise use different keys
|
|
step_noise_key = mx.random.key(noise_seed)
|
|
substep_noise_key = mx.random.key(noise_seed + 10000)
|
|
|
|
def _eval_guided_denoise(v_latents, a_latents, sigma):
|
|
"""Run all guidance passes and return (video_denoised, audio_denoised) in float32 spatial format."""
|
|
b, c, f, h, w = v_latents.shape
|
|
num_video_tokens = f * h * w
|
|
video_flat = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
|
|
|
ab, ac, at, af = a_latents.shape
|
|
audio_flat = mx.transpose(a_latents, (0, 2, 1, 3))
|
|
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
|
|
|
# Timesteps
|
|
if video_state is not None:
|
|
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
|
|
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
|
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
|
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
|
else:
|
|
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
|
audio_timesteps = mx.zeros((ab, at), dtype=dtype) if audio_frozen else mx.full((ab, at), sigma, dtype=dtype)
|
|
|
|
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
|
audio_sigma_array = mx.zeros((ab,), dtype=dtype) if audio_frozen else mx.full((ab,), sigma, dtype=dtype)
|
|
|
|
# Pass 1: Positive conditioning
|
|
video_modality_pos = Modality(
|
|
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
|
context=video_embeddings_pos, context_mask=None, enabled=True,
|
|
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
|
)
|
|
audio_modality_pos = Modality(
|
|
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
|
context=audio_embeddings_pos, context_mask=None, enabled=True,
|
|
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
|
)
|
|
video_vel_pos, audio_vel_pos = transformer(video=video_modality_pos, audio=audio_modality_pos)
|
|
mx.eval(video_vel_pos, audio_vel_pos)
|
|
|
|
# Convert velocity to x0
|
|
video_flat_f32 = mx.transpose(mx.reshape(v_latents, (b, c, -1)), (0, 2, 1))
|
|
audio_flat_f32 = mx.reshape(mx.transpose(a_latents, (0, 2, 1, 3)), (ab, at, ac * af))
|
|
video_ts_f32 = mx.expand_dims(video_timesteps.astype(mx.float32), axis=-1)
|
|
audio_ts_f32 = mx.expand_dims(audio_timesteps.astype(mx.float32), axis=-1)
|
|
|
|
video_x0_pos = video_flat_f32 - video_ts_f32 * video_vel_pos.astype(mx.float32)
|
|
audio_x0_pos = audio_flat_f32 - audio_ts_f32 * audio_vel_pos.astype(mx.float32)
|
|
|
|
video_x0_guided = video_x0_pos
|
|
audio_x0_guided = audio_x0_pos
|
|
|
|
# Pass 2: CFG
|
|
if use_cfg:
|
|
video_modality_neg = Modality(
|
|
latent=video_flat, timesteps=video_timesteps, positions=video_positions,
|
|
context=video_embeddings_neg, context_mask=None, enabled=True,
|
|
positional_embeddings=precomputed_video_rope, sigma=sigma_array,
|
|
)
|
|
audio_modality_neg = Modality(
|
|
latent=audio_flat, timesteps=audio_timesteps, positions=audio_positions,
|
|
context=audio_embeddings_neg, context_mask=None, enabled=True,
|
|
positional_embeddings=precomputed_audio_rope, sigma=audio_sigma_array,
|
|
)
|
|
video_vel_neg, audio_vel_neg = transformer(video=video_modality_neg, audio=audio_modality_neg)
|
|
mx.eval(video_vel_neg, audio_vel_neg)
|
|
|
|
video_x0_neg = video_flat_f32 - video_ts_f32 * video_vel_neg.astype(mx.float32)
|
|
audio_x0_neg = audio_flat_f32 - audio_ts_f32 * audio_vel_neg.astype(mx.float32)
|
|
|
|
video_x0_guided = video_x0_pos + (cfg_scale - 1.0) * (video_x0_pos - video_x0_neg)
|
|
audio_x0_guided = audio_x0_pos + (audio_cfg_scale - 1.0) * (audio_x0_pos - audio_x0_neg)
|
|
|
|
# Pass 3: STG
|
|
if use_stg:
|
|
video_vel_ptb, audio_vel_ptb = transformer(
|
|
video=video_modality_pos, audio=audio_modality_pos,
|
|
stg_video_blocks=stg_video_blocks, stg_audio_blocks=stg_audio_blocks,
|
|
)
|
|
mx.eval(video_vel_ptb, audio_vel_ptb)
|
|
|
|
video_x0_ptb = video_flat_f32 - video_ts_f32 * video_vel_ptb.astype(mx.float32)
|
|
audio_x0_ptb = audio_flat_f32 - audio_ts_f32 * audio_vel_ptb.astype(mx.float32)
|
|
|
|
video_x0_guided = video_x0_guided + stg_scale * (video_x0_pos - video_x0_ptb)
|
|
audio_x0_guided = audio_x0_guided + stg_scale * (audio_x0_pos - audio_x0_ptb)
|
|
|
|
# Pass 4: Modality isolation
|
|
if use_modality:
|
|
video_vel_iso, audio_vel_iso = transformer(
|
|
video=video_modality_pos, audio=audio_modality_pos,
|
|
skip_cross_modal=True,
|
|
)
|
|
mx.eval(video_vel_iso, audio_vel_iso)
|
|
|
|
video_x0_iso = video_flat_f32 - video_ts_f32 * video_vel_iso.astype(mx.float32)
|
|
audio_x0_iso = audio_flat_f32 - audio_ts_f32 * audio_vel_iso.astype(mx.float32)
|
|
|
|
video_x0_guided = video_x0_guided + (modality_scale - 1.0) * (video_x0_pos - video_x0_iso)
|
|
audio_x0_guided = audio_x0_guided + (modality_scale - 1.0) * (audio_x0_pos - audio_x0_iso)
|
|
|
|
# Rescale (separate factors for video and audio)
|
|
if cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
|
|
v_factor = video_x0_pos.std() / (video_x0_guided.std() + 1e-8)
|
|
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
|
video_x0_guided = video_x0_guided * v_factor
|
|
if audio_cfg_rescale > 0.0 and (use_cfg or use_stg or use_modality):
|
|
a_factor = audio_x0_pos.std() / (audio_x0_guided.std() + 1e-8)
|
|
a_factor = audio_cfg_rescale * a_factor + (1.0 - audio_cfg_rescale)
|
|
audio_x0_guided = audio_x0_guided * a_factor
|
|
|
|
# Reshape to spatial
|
|
video_denoised = mx.reshape(mx.transpose(video_x0_guided, (0, 2, 1)), (b, c, f, h, w))
|
|
audio_denoised = mx.reshape(audio_x0_guided, (ab, at, ac, af))
|
|
audio_denoised = mx.transpose(audio_denoised, (0, 2, 1, 3))
|
|
|
|
# Post-process with mask
|
|
if video_state is not None:
|
|
clean_f32 = video_state.clean_latent.astype(mx.float32)
|
|
mask_f32 = video_state.denoise_mask.astype(mx.float32)
|
|
video_denoised = video_denoised * mask_f32 + clean_f32 * (1.0 - mask_f32)
|
|
|
|
mx.eval(video_denoised, audio_denoised)
|
|
return video_denoised, audio_denoised
|
|
|
|
# Main res_2s loop
|
|
with Progress(
|
|
SpinnerColumn(),
|
|
TextColumn("[progress.description]{task.description}"),
|
|
BarColumn(),
|
|
TaskProgressColumn(),
|
|
TimeRemainingColumn(),
|
|
console=console,
|
|
disable=not verbose,
|
|
) as progress:
|
|
passes = ["res2s"]
|
|
if use_cfg: passes.append("CFG")
|
|
if use_stg: passes.append("STG")
|
|
if use_modality: passes.append("Mod")
|
|
label = "+".join(passes)
|
|
task = progress.add_task(f"[cyan]Denoising A/V ({label})[/]", total=n_full_steps)
|
|
|
|
for step_idx in range(n_full_steps):
|
|
sigma = sigmas_list[step_idx]
|
|
sigma_next = sigmas_list[step_idx + 1]
|
|
h = hs[step_idx]
|
|
|
|
# Initialize anchor
|
|
x_anchor_video = video_latents
|
|
x_anchor_audio = audio_latents
|
|
|
|
# ============================================================
|
|
# Stage 1: Evaluate denoiser at current sigma
|
|
# ============================================================
|
|
denoised_video_1, denoised_audio_1 = _eval_guided_denoise(
|
|
video_latents, audio_latents, sigma
|
|
)
|
|
|
|
# RK coefficients
|
|
a21, b1, b2 = get_res2s_coefficients(h, phi_cache, c2)
|
|
|
|
# Substep sigma (geometric midpoint for c2=0.5)
|
|
sub_sigma = math.sqrt(sigma * sigma_next)
|
|
|
|
# Compute midpoint
|
|
eps_1_video = denoised_video_1 - x_anchor_video
|
|
x_mid_video = x_anchor_video + h * a21 * eps_1_video
|
|
|
|
if not audio_frozen:
|
|
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
|
x_mid_audio = x_anchor_audio + h * a21 * eps_1_audio
|
|
else:
|
|
eps_1_audio = None
|
|
x_mid_audio = audio_latents # frozen: pass through unchanged
|
|
|
|
# SDE noise injection at substep
|
|
substep_noise_key, key1, key2 = mx.random.split(substep_noise_key, 3)
|
|
substep_noise_v = get_new_noise(video_latents.shape, key1)
|
|
|
|
x_mid_video = sde_noise_step(x_anchor_video, x_mid_video, sigma, sub_sigma, substep_noise_v)
|
|
if not audio_frozen:
|
|
substep_noise_a = get_new_noise(audio_latents.shape, key2)
|
|
x_mid_audio = sde_noise_step(x_anchor_audio, x_mid_audio, sigma, sub_sigma, substep_noise_a)
|
|
mx.eval(x_mid_video, x_mid_audio)
|
|
|
|
# ============================================================
|
|
# Bong iteration: refine anchor (pure arithmetic, no model calls)
|
|
# ============================================================
|
|
if bongmath and h < 0.5 and sigma > 0.03:
|
|
for _ in range(bongmath_max_iter):
|
|
x_anchor_video = x_mid_video - h * a21 * eps_1_video
|
|
eps_1_video = denoised_video_1 - x_anchor_video
|
|
if not audio_frozen:
|
|
x_anchor_audio = x_mid_audio - h * a21 * eps_1_audio
|
|
eps_1_audio = denoised_audio_1 - x_anchor_audio
|
|
if audio_frozen:
|
|
mx.eval(x_anchor_video, eps_1_video)
|
|
else:
|
|
mx.eval(x_anchor_video, x_anchor_audio, eps_1_video, eps_1_audio)
|
|
|
|
# ============================================================
|
|
# Stage 2: Evaluate denoiser at midpoint sigma
|
|
# ============================================================
|
|
denoised_video_2, denoised_audio_2 = _eval_guided_denoise(
|
|
x_mid_video.astype(mx.float32), x_mid_audio.astype(mx.float32), sub_sigma
|
|
)
|
|
|
|
# ============================================================
|
|
# Final combination with RK coefficients
|
|
# ============================================================
|
|
eps_2_video = denoised_video_2 - x_anchor_video
|
|
x_next_video = x_anchor_video + h * (b1 * eps_1_video + b2 * eps_2_video)
|
|
|
|
# SDE noise injection at step level
|
|
step_noise_key, key1, key2 = mx.random.split(step_noise_key, 3)
|
|
step_noise_v = get_new_noise(video_latents.shape, key1)
|
|
x_next_video = sde_noise_step(x_anchor_video, x_next_video, sigma, sigma_next, step_noise_v)
|
|
|
|
video_latents = x_next_video.astype(mx.float32)
|
|
if not audio_frozen:
|
|
eps_2_audio = denoised_audio_2 - x_anchor_audio
|
|
x_next_audio = x_anchor_audio + h * (b1 * eps_1_audio + b2 * eps_2_audio)
|
|
step_noise_a = get_new_noise(audio_latents.shape, key2)
|
|
x_next_audio = sde_noise_step(x_anchor_audio, x_next_audio, sigma, sigma_next, step_noise_a)
|
|
audio_latents = x_next_audio.astype(mx.float32)
|
|
|
|
mx.eval(video_latents, audio_latents)
|
|
progress.advance(task)
|
|
|
|
# Final clean step if original schedule ended at 0
|
|
if sigmas.tolist()[-1] == 0:
|
|
denoised_video, denoised_audio = _eval_guided_denoise(
|
|
video_latents, audio_latents, sigmas_list[n_full_steps]
|
|
)
|
|
video_latents = denoised_video
|
|
if not audio_frozen:
|
|
audio_latents = denoised_audio
|
|
mx.eval(video_latents, audio_latents)
|
|
|
|
return video_latents, audio_latents
|
|
|
|
|
|
# =============================================================================
|
|
# Audio Loading and Processing
|
|
# =============================================================================
|
|
|
|
def load_audio_decoder(model_path: Path, pipeline: PipelineType):
|
|
"""Load audio VAE decoder."""
|
|
from mlx_video.models.ltx_2.audio_vae import AudioDecoder
|
|
|
|
decoder = AudioDecoder.from_pretrained(model_path / "audio_vae" / "decoder")
|
|
|
|
return decoder
|
|
|
|
|
|
def load_vocoder_model(model_path: Path, pipeline: PipelineType):
|
|
"""Load vocoder for mel to waveform conversion.
|
|
|
|
Automatically detects HiFi-GAN (LTX-2) or BigVGAN+BWE (LTX-2.3).
|
|
"""
|
|
from mlx_video.models.ltx_2.audio_vae.vocoder import load_vocoder as _load_vocoder
|
|
|
|
return _load_vocoder(model_path / "vocoder")
|
|
|
|
|
|
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
|
|
"""Save audio to WAV file."""
|
|
import wave
|
|
|
|
if audio.ndim == 2:
|
|
audio = audio.T
|
|
|
|
audio = np.clip(audio, -1.0, 1.0)
|
|
audio_int16 = (audio * 32767).astype(np.int16)
|
|
|
|
with wave.open(str(path), 'wb') as wf:
|
|
wf.setnchannels(2 if audio_int16.ndim == 2 else 1)
|
|
wf.setsampwidth(2)
|
|
wf.setframerate(sample_rate)
|
|
wf.writeframes(audio_int16.tobytes())
|
|
|
|
|
|
def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path):
|
|
"""Combine video and audio into final output using ffmpeg."""
|
|
import subprocess
|
|
|
|
cmd = [
|
|
"ffmpeg", "-y",
|
|
"-i", str(video_path),
|
|
"-i", str(audio_path),
|
|
"-c:v", "copy",
|
|
"-c:a", "aac",
|
|
"-shortest",
|
|
str(output_path)
|
|
]
|
|
|
|
try:
|
|
subprocess.run(cmd, check=True, capture_output=True)
|
|
return True
|
|
except subprocess.CalledProcessError as e:
|
|
console.print(f"[red]FFmpeg error: {e.stderr.decode()}[/]")
|
|
return False
|
|
except FileNotFoundError:
|
|
console.print("[red]FFmpeg not found. Please install ffmpeg.[/]")
|
|
return False
|
|
|
|
|
|
# =============================================================================
|
|
# Unified Generate Function
|
|
# =============================================================================
|
|
|
|
def generate_video(
|
|
model_repo: str,
|
|
text_encoder_repo: str,
|
|
prompt: str,
|
|
pipeline: PipelineType = PipelineType.DISTILLED,
|
|
negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
|
|
height: int = 512,
|
|
width: int = 512,
|
|
num_frames: int = 33,
|
|
num_inference_steps: int = 40,
|
|
cfg_scale: float = 4.0,
|
|
audio_cfg_scale: float = 7.0,
|
|
cfg_rescale: float = 0.0,
|
|
seed: int = 42,
|
|
fps: int = 24,
|
|
output_path: str = "output.mp4",
|
|
save_frames: bool = False,
|
|
verbose: bool = True,
|
|
enhance_prompt: bool = False,
|
|
max_tokens: int = 512,
|
|
temperature: float = 0.7,
|
|
image: Optional[str] = None,
|
|
image_strength: float = 1.0,
|
|
image_frame_idx: int = 0,
|
|
tiling: str = "auto",
|
|
stream: bool = False,
|
|
audio: bool = False,
|
|
output_audio_path: Optional[str] = None,
|
|
use_apg: bool = False,
|
|
apg_eta: float = 1.0,
|
|
apg_norm_threshold: float = 0.0,
|
|
stg_scale: float = 0.0,
|
|
stg_blocks: Optional[list] = None,
|
|
modality_scale: float = 1.0,
|
|
lora_path: Optional[str] = None,
|
|
lora_strength: float = 1.0,
|
|
lora_strength_stage_1: Optional[float] = None,
|
|
lora_strength_stage_2: Optional[float] = None,
|
|
audio_file: Optional[str] = None,
|
|
audio_start_time: float = 0.0,
|
|
):
|
|
"""Generate video using LTX-2 models.
|
|
|
|
Supports four pipelines:
|
|
- DISTILLED: Two-stage generation with upsampling, fixed sigma schedules, no CFG
|
|
- DEV: Single-stage generation with dynamic sigmas and CFG
|
|
- DEV_TWO_STAGE: Stage 1 dev (half res, CFG) + upsample + stage 2 distilled with LoRA (full res, no CFG)
|
|
- DEV_TWO_STAGE_HQ: res_2s sampler, LoRA both stages (0.25/0.5), lower rescale
|
|
|
|
Args:
|
|
model_repo: Model repository ID
|
|
text_encoder_repo: Text encoder repository ID
|
|
prompt: Text description of the video to generate
|
|
pipeline: Pipeline type (DISTILLED or DEV)
|
|
negative_prompt: Negative prompt for CFG (dev pipeline only)
|
|
height: Output video height (must be divisible by 32/64)
|
|
width: Output video width (must be divisible by 32/64)
|
|
num_frames: Number of frames (must be 1 + 8*k)
|
|
num_inference_steps: Number of denoising steps (dev pipeline only)
|
|
cfg_scale: Guidance scale for CFG (dev pipeline only)
|
|
seed: Random seed for reproducibility
|
|
fps: Frames per second for output video
|
|
output_path: Path to save the output video
|
|
save_frames: Whether to save individual frames as images
|
|
verbose: Whether to print progress
|
|
enhance_prompt: Whether to enhance prompt using Gemma
|
|
max_tokens: Max tokens for prompt enhancement
|
|
temperature: Temperature for prompt enhancement
|
|
image: Path to conditioning image for I2V
|
|
image_strength: Conditioning strength for I2V
|
|
image_frame_idx: Frame index to condition for I2V
|
|
tiling: Tiling mode for VAE decoding
|
|
stream: Stream frames to output as they're decoded
|
|
audio: Enable synchronized audio generation
|
|
output_audio_path: Path to save audio file
|
|
use_apg: Use Adaptive Projected Guidance instead of CFG (more stable for I2V)
|
|
apg_eta: APG parallel component weight (1.0 = keep full parallel)
|
|
apg_norm_threshold: APG guidance norm clamp (0 = no clamping)
|
|
"""
|
|
start_time = time.time()
|
|
|
|
# Validate dimensions
|
|
is_two_stage = pipeline in (PipelineType.DISTILLED, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ)
|
|
divisor = 64 if is_two_stage else 32
|
|
assert height % divisor == 0, f"Height must be divisible by {divisor}, got {height}"
|
|
assert width % divisor == 0, f"Width must be divisible by {divisor}, got {width}"
|
|
|
|
if num_frames % 8 != 1:
|
|
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
|
|
console.print(f"[yellow]⚠️ Number of frames must be 1 + 8*k. Using: {adjusted_num_frames}[/]")
|
|
num_frames = adjusted_num_frames
|
|
|
|
is_i2v = image is not None
|
|
is_a2v = audio_file is not None
|
|
if is_a2v and audio:
|
|
raise ValueError("Cannot use both --audio-file (A2V) and --audio (generate audio). Choose one.")
|
|
# A2V implicitly enables audio path through the transformer
|
|
if is_a2v:
|
|
audio = True
|
|
mode_str = "I2V" if is_i2v else "T2V"
|
|
if is_a2v:
|
|
mode_str = "A2V" + ("+I2V" if is_i2v else "")
|
|
elif audio:
|
|
mode_str += "+Audio"
|
|
|
|
pipeline_names = {
|
|
PipelineType.DISTILLED: "DISTILLED",
|
|
PipelineType.DEV: "DEV",
|
|
PipelineType.DEV_TWO_STAGE: "DEV-TWO-STAGE",
|
|
PipelineType.DEV_TWO_STAGE_HQ: "DEV-TWO-STAGE-HQ",
|
|
}
|
|
pipeline_name = pipeline_names[pipeline]
|
|
header = f"[bold cyan]🎬 [{pipeline_name}] [{mode_str}] {width}x{height} • {num_frames} frames[/]"
|
|
console.print(Panel(header, expand=False))
|
|
console.print(f"[dim]Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}[/]")
|
|
|
|
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
|
audio_cfg_info = f", Audio CFG: {audio_cfg_scale}" if audio else ""
|
|
stg_info = f", STG: {stg_scale} blocks={stg_blocks}" if stg_scale != 0.0 else ""
|
|
mod_info = f", Modality: {modality_scale}" if modality_scale != 1.0 else ""
|
|
console.print(f"[dim]Steps: {num_inference_steps}, CFG: {cfg_scale}{audio_cfg_info}, Rescale: {cfg_rescale}{stg_info}{mod_info}[/]")
|
|
|
|
if is_i2v:
|
|
console.print(f"[dim]Image: {image} (strength={image_strength}, frame={image_frame_idx})[/]")
|
|
|
|
# Always compute audio frames - PyTorch distilled pipeline unconditionally
|
|
# generates audio alongside video (model was trained with joint audio-video).
|
|
# The --audio flag only controls whether audio is decoded and saved to output.
|
|
audio_frames = compute_audio_frames(num_frames, fps)
|
|
if audio:
|
|
console.print(f"[dim]Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz[/]")
|
|
|
|
# Get model path
|
|
model_path = get_model_path(model_repo)
|
|
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
|
|
|
# Calculate latent dimensions
|
|
if is_two_stage:
|
|
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
|
stage2_h, stage2_w = height // 32, width // 32
|
|
else:
|
|
latent_h, latent_w = height // 32, width // 32
|
|
latent_frames = 1 + (num_frames - 1) // 8
|
|
|
|
mx.random.seed(seed)
|
|
|
|
# Read transformer config to detect model version
|
|
import json
|
|
transformer_config_path = model_path / "transformer" / "config.json"
|
|
has_prompt_adaln = False
|
|
if transformer_config_path.exists():
|
|
with open(transformer_config_path) as f:
|
|
has_prompt_adaln = json.load(f).get("has_prompt_adaln", False)
|
|
|
|
# Load text encoder
|
|
with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"):
|
|
from mlx_video.models.ltx_2.text_encoder import LTX2TextEncoder
|
|
text_encoder = LTX2TextEncoder(has_prompt_adaln=has_prompt_adaln)
|
|
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
|
mx.eval(text_encoder.parameters())
|
|
console.print("[green]✓[/] Text encoder loaded")
|
|
|
|
# Optionally enhance the prompt
|
|
if enhance_prompt:
|
|
console.print("[bold magenta]✨ Enhancing prompt[/]")
|
|
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
|
console.print(f"[dim]Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}[/]")
|
|
|
|
# Encode prompts - always get audio embeddings since the model was trained
|
|
# with joint audio-video processing (PyTorch unconditionally generates audio)
|
|
if pipeline in (PipelineType.DEV, PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
|
# Dev/dev-two-stage pipelines need positive and negative embeddings for CFG
|
|
video_embeddings_pos, audio_embeddings_pos = text_encoder(prompt, return_audio_embeddings=True)
|
|
video_embeddings_neg, audio_embeddings_neg = text_encoder(negative_prompt, return_audio_embeddings=True)
|
|
model_dtype = video_embeddings_pos.dtype
|
|
mx.eval(video_embeddings_pos, video_embeddings_neg, audio_embeddings_pos, audio_embeddings_neg)
|
|
# For dev-two-stage, stage 2 uses single positive embedding (no CFG)
|
|
if pipeline in (PipelineType.DEV_TWO_STAGE, PipelineType.DEV_TWO_STAGE_HQ):
|
|
text_embeddings = video_embeddings_pos
|
|
else:
|
|
# Distilled pipeline - single embedding
|
|
text_embeddings, audio_embeddings = text_encoder(prompt, return_audio_embeddings=True)
|
|
mx.eval(text_embeddings, audio_embeddings)
|
|
model_dtype = text_embeddings.dtype
|
|
|
|
del text_encoder
|
|
mx.clear_cache()
|
|
|
|
# Load transformer
|
|
transformer_desc = f"🤖 Loading {pipeline_name.lower()} transformer{' (A/V mode)' if audio else ''}..."
|
|
with console.status(f"[blue]{transformer_desc}[/]", spinner="dots"):
|
|
transformer = LTXModel.from_pretrained(model_path=model_path / "transformer", strict=True)
|
|
|
|
console.print("[green]✓[/] Transformer loaded")
|
|
|
|
# Auto-detect stg_blocks from transformer config if not explicitly provided.
|
|
# LTX-2.3 (has_prompt_adaln=True) uses block 28; LTX-2 uses block 29.
|
|
if stg_blocks is None and stg_scale != 0.0:
|
|
if transformer.config.has_prompt_adaln:
|
|
stg_blocks = [28]
|
|
else:
|
|
stg_blocks = [29]
|
|
console.print(f"[dim]Auto-detected STG blocks: {stg_blocks} (model={'2.3' if transformer.config.has_prompt_adaln else '2'})[/]")
|
|
|
|
# ==========================================================================
|
|
# A2V: Encode input audio to frozen latents
|
|
# ==========================================================================
|
|
a2v_audio_latents = None
|
|
a2v_waveform = None
|
|
a2v_sr = None
|
|
if is_a2v:
|
|
from mlx_video.models.ltx_2.audio_vae.audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
|
from mlx_video.convert import convert_audio_encoder
|
|
from mlx_video.models.ltx_2.audio_vae import AudioEncoder
|
|
|
|
with console.status("[blue]Loading and encoding input audio (A2V)...[/]", spinner="dots"):
|
|
video_duration = num_frames / fps
|
|
|
|
# Load audio
|
|
waveform, sr = load_audio(
|
|
audio_file,
|
|
target_sr=AUDIO_LATENT_SAMPLE_RATE,
|
|
start_time=audio_start_time,
|
|
max_duration=video_duration,
|
|
)
|
|
waveform = ensure_stereo(waveform)
|
|
a2v_waveform = waveform.copy()
|
|
a2v_sr = sr
|
|
|
|
# Compute mel-spectrogram
|
|
mel = waveform_to_mel(waveform, sample_rate=sr, n_fft=1024, hop_length=AUDIO_HOP_LENGTH, n_mels=64)
|
|
|
|
# Convert audio encoder weights if needed, then load
|
|
encoder_dir = convert_audio_encoder(model_path, source_repo="Lightricks/LTX-2")
|
|
audio_encoder = AudioEncoder.from_pretrained(encoder_dir)
|
|
mx.eval(audio_encoder.parameters())
|
|
|
|
# Encode: (1, 2, time, 64) -> normalized latents
|
|
encoded = audio_encoder(mel)
|
|
mx.eval(encoded)
|
|
|
|
# encoded is in MLX format (B, T', mel_bins', z_channels) = (1, T', 16, 8)
|
|
# Convert to PyTorch-style format for consistency: (B, C, T, mel_bins)
|
|
a2v_audio_latents = mx.transpose(encoded, (0, 3, 1, 2)).astype(model_dtype)
|
|
|
|
# Trim/pad to match expected audio_frames
|
|
t_encoded = a2v_audio_latents.shape[2]
|
|
if t_encoded > audio_frames:
|
|
a2v_audio_latents = a2v_audio_latents[:, :, :audio_frames, :]
|
|
elif t_encoded < audio_frames:
|
|
pad_size = audio_frames - t_encoded
|
|
padding = mx.zeros((1, AUDIO_LATENT_CHANNELS, pad_size, AUDIO_MEL_BINS), dtype=model_dtype)
|
|
a2v_audio_latents = mx.concatenate([a2v_audio_latents, padding], axis=2)
|
|
mx.eval(a2v_audio_latents)
|
|
|
|
del audio_encoder
|
|
mx.clear_cache()
|
|
|
|
console.print(f"[green]✓[/] Audio encoded ({a2v_audio_latents.shape[2]} frames from {audio_file})")
|
|
|
|
# ==========================================================================
|
|
# Pipeline-specific generation logic
|
|
# ==========================================================================
|
|
|
|
if pipeline == PipelineType.DISTILLED:
|
|
# ======================================================================
|
|
# DISTILLED PIPELINE: Two-stage with upsampling
|
|
# ======================================================================
|
|
|
|
# Load VAE encoder for I2V
|
|
stage1_image_latent = None
|
|
stage2_image_latent = None
|
|
if is_i2v:
|
|
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
|
|
|
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
|
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
|
mx.eval(stage1_image_latent)
|
|
|
|
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
|
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
|
mx.eval(stage2_image_latent)
|
|
|
|
del vae_encoder
|
|
mx.clear_cache()
|
|
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
|
|
|
# Stage 1
|
|
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)")
|
|
mx.random.seed(seed)
|
|
|
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
|
mx.eval(positions)
|
|
|
|
# Init audio latents/positions: use encoded A2V latents or random
|
|
audio_positions = create_audio_position_grid(1, audio_frames)
|
|
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
|
mx.eval(audio_positions, audio_latents)
|
|
|
|
# Apply I2V conditioning
|
|
state1 = None
|
|
if is_i2v and stage1_image_latent is not None:
|
|
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
|
state1 = LatentState(
|
|
latent=mx.zeros(latent_shape, dtype=model_dtype),
|
|
clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
|
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
|
)
|
|
conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
|
state1 = apply_conditioning(state1, [conditioning])
|
|
|
|
noise = mx.random.normal(latent_shape, dtype=model_dtype)
|
|
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype)
|
|
scaled_mask = state1.denoise_mask * noise_scale
|
|
state1 = LatentState(
|
|
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
|
clean_latent=state1.clean_latent,
|
|
denoise_mask=state1.denoise_mask,
|
|
)
|
|
latents = state1.latent
|
|
mx.eval(latents)
|
|
else:
|
|
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype)
|
|
mx.eval(latents)
|
|
|
|
latents, audio_latents = denoise_distilled(
|
|
latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS,
|
|
verbose=verbose, state=state1,
|
|
audio_latents=audio_latents, audio_positions=audio_positions, audio_embeddings=audio_embeddings,
|
|
audio_frozen=is_a2v,
|
|
)
|
|
|
|
# Upsample latents
|
|
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
|
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
|
if not upscaler_files:
|
|
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
|
upsampler = load_upsampler(str(upscaler_files[0]))
|
|
mx.eval(upsampler.parameters())
|
|
|
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
|
|
|
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
|
mx.eval(latents)
|
|
|
|
del upsampler
|
|
mx.clear_cache()
|
|
console.print("[green]✓[/] Latents upsampled")
|
|
|
|
# Stage 2
|
|
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)")
|
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
|
mx.eval(positions)
|
|
|
|
state2 = None
|
|
if is_i2v and stage2_image_latent is not None:
|
|
state2 = LatentState(
|
|
latent=latents,
|
|
clean_latent=mx.zeros_like(latents),
|
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
|
)
|
|
conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
|
state2 = apply_conditioning(state2, [conditioning])
|
|
|
|
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
scaled_mask = state2.denoise_mask * noise_scale
|
|
state2 = LatentState(
|
|
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
|
clean_latent=state2.clean_latent,
|
|
denoise_mask=state2.denoise_mask,
|
|
)
|
|
latents = state2.latent
|
|
mx.eval(latents)
|
|
else:
|
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
|
latents = noise * noise_scale + latents * one_minus_scale
|
|
mx.eval(latents)
|
|
|
|
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
|
|
if audio_latents is not None and not is_a2v:
|
|
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
|
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
|
mx.eval(audio_latents)
|
|
|
|
# Joint video + audio refinement (no CFG, positive embeddings only)
|
|
latents, audio_latents = denoise_distilled(
|
|
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
|
verbose=verbose, state=state2,
|
|
audio_latents=audio_latents, audio_positions=audio_positions,
|
|
audio_embeddings=audio_embeddings,
|
|
audio_frozen=is_a2v,
|
|
)
|
|
|
|
elif pipeline == PipelineType.DEV:
|
|
# ======================================================================
|
|
# DEV PIPELINE: Single-stage with CFG
|
|
# ======================================================================
|
|
|
|
# Load VAE encoder for I2V
|
|
image_latent = None
|
|
if is_i2v:
|
|
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
|
|
|
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
|
image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
|
image_latent = vae_encoder(image_tensor)
|
|
mx.eval(image_latent)
|
|
|
|
del vae_encoder
|
|
mx.clear_cache()
|
|
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
|
|
|
# Generate sigma schedule with token-count-dependent shifting
|
|
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
|
mx.eval(sigmas)
|
|
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
|
|
|
console.print(f"\n[bold yellow]⚡ Generating:[/] {width}x{height} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
|
|
mx.random.seed(seed)
|
|
|
|
video_positions = create_position_grid(1, latent_frames, latent_h, latent_w)
|
|
mx.eval(video_positions)
|
|
|
|
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
|
audio_positions = create_audio_position_grid(1, audio_frames)
|
|
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
|
mx.eval(audio_positions, audio_latents)
|
|
|
|
# Initialize latents with optional I2V conditioning
|
|
video_state = None
|
|
video_latent_shape = (1, 128, latent_frames, latent_h, latent_w)
|
|
if is_i2v and image_latent is not None:
|
|
video_state = LatentState(
|
|
latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
|
clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
|
)
|
|
conditioning = VideoConditionByLatentIndex(latent=image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
|
video_state = apply_conditioning(video_state, [conditioning])
|
|
|
|
noise = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
|
noise_scale = sigmas[0]
|
|
scaled_mask = video_state.denoise_mask * noise_scale
|
|
video_state = LatentState(
|
|
latent=noise * scaled_mask + video_state.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
|
clean_latent=video_state.clean_latent,
|
|
denoise_mask=video_state.denoise_mask,
|
|
)
|
|
latents = video_state.latent
|
|
mx.eval(latents)
|
|
else:
|
|
latents = mx.random.normal(video_latent_shape, dtype=model_dtype)
|
|
mx.eval(latents)
|
|
|
|
# Always use A/V denoising - PyTorch always processes audio+video jointly
|
|
latents, audio_latents = denoise_dev_av(
|
|
latents, audio_latents,
|
|
video_positions, audio_positions,
|
|
video_embeddings_pos, video_embeddings_neg,
|
|
audio_embeddings_pos, audio_embeddings_neg,
|
|
transformer, sigmas, cfg_scale=cfg_scale,
|
|
audio_cfg_scale=audio_cfg_scale,
|
|
cfg_rescale=cfg_rescale, verbose=verbose, video_state=video_state,
|
|
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
|
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
|
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
|
audio_frozen=is_a2v,
|
|
)
|
|
|
|
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
|
|
|
elif pipeline == PipelineType.DEV_TWO_STAGE:
|
|
# ======================================================================
|
|
# DEV TWO-STAGE PIPELINE:
|
|
# Stage 1: Dev denoising at half resolution with CFG
|
|
# Upsample: 2x spatial via LatentUpsampler
|
|
# Stage 2: Distilled denoising at full resolution with LoRA, no CFG
|
|
# ======================================================================
|
|
|
|
# Load VAE encoder for I2V
|
|
stage1_image_latent = None
|
|
stage2_image_latent = None
|
|
if is_i2v:
|
|
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
|
|
|
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
|
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
|
mx.eval(stage1_image_latent)
|
|
|
|
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
|
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
|
mx.eval(stage2_image_latent)
|
|
|
|
del vae_encoder
|
|
mx.clear_cache()
|
|
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
|
|
|
# Stage 1: Dev denoising at half resolution with CFG
|
|
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
|
mx.eval(sigmas)
|
|
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
|
|
|
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {width//2}x{height//2} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
|
|
mx.random.seed(seed)
|
|
|
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
|
mx.eval(positions)
|
|
|
|
# Always init audio latents/positions - PyTorch unconditionally generates audio
|
|
audio_positions = create_audio_position_grid(1, audio_frames)
|
|
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
|
mx.eval(audio_positions, audio_latents)
|
|
|
|
# Apply I2V conditioning for stage 1
|
|
state1 = None
|
|
stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
|
if is_i2v and stage1_image_latent is not None:
|
|
state1 = LatentState(
|
|
latent=mx.zeros(stage1_shape, dtype=model_dtype),
|
|
clean_latent=mx.zeros(stage1_shape, dtype=model_dtype),
|
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
|
)
|
|
conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
|
state1 = apply_conditioning(state1, [conditioning])
|
|
|
|
noise = mx.random.normal(stage1_shape, dtype=model_dtype)
|
|
noise_scale = sigmas[0]
|
|
scaled_mask = state1.denoise_mask * noise_scale
|
|
state1 = LatentState(
|
|
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
|
clean_latent=state1.clean_latent,
|
|
denoise_mask=state1.denoise_mask,
|
|
)
|
|
latents = state1.latent
|
|
mx.eval(latents)
|
|
else:
|
|
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
|
|
mx.eval(latents)
|
|
|
|
# Stage 1: Always use joint AV denoising (matches PyTorch)
|
|
latents, audio_latents = denoise_dev_av(
|
|
latents, audio_latents,
|
|
positions, audio_positions,
|
|
video_embeddings_pos, video_embeddings_neg,
|
|
audio_embeddings_pos, audio_embeddings_neg,
|
|
transformer, sigmas, cfg_scale=cfg_scale,
|
|
audio_cfg_scale=audio_cfg_scale,
|
|
cfg_rescale=cfg_rescale, verbose=verbose, video_state=state1,
|
|
use_apg=use_apg, apg_eta=apg_eta, apg_norm_threshold=apg_norm_threshold,
|
|
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
|
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
|
audio_frozen=is_a2v,
|
|
)
|
|
|
|
mx.eval(audio_latents)
|
|
|
|
# Upsample latents 2x
|
|
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
|
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
|
if not upscaler_files:
|
|
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
|
upsampler = load_upsampler(str(upscaler_files[0]))
|
|
mx.eval(upsampler.parameters())
|
|
|
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
|
|
|
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
|
mx.eval(latents)
|
|
|
|
del upsampler
|
|
mx.clear_cache()
|
|
console.print("[green]✓[/] Latents upsampled")
|
|
|
|
# Merge LoRA weights for stage 2 (distilled refinement)
|
|
if lora_path is None:
|
|
# Auto-detect LoRA file in model directory
|
|
lora_files = sorted(model_path.glob("*distilled-lora*.safetensors"))
|
|
if lora_files:
|
|
lora_path = str(lora_files[0])
|
|
console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]")
|
|
else:
|
|
console.print("[yellow]⚠️ No LoRA file found. Stage 2 will use base weights.[/]")
|
|
|
|
if lora_path is not None:
|
|
with console.status("[blue]🔧 Merging distilled LoRA weights...[/]", spinner="dots"):
|
|
load_and_merge_lora(transformer, lora_path, strength=lora_strength)
|
|
|
|
# Stage 2: Distilled refinement at full resolution (no CFG)
|
|
# Matches PyTorch: re-noise audio at sigma=0.909375, then jointly refine
|
|
# both video and audio through the distilled schedule using the LoRA-merged model.
|
|
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Distilled refining at {width}x{height} (3 steps, no CFG)")
|
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
|
mx.eval(positions)
|
|
|
|
state2 = None
|
|
if is_i2v and stage2_image_latent is not None:
|
|
state2 = LatentState(
|
|
latent=latents,
|
|
clean_latent=mx.zeros_like(latents),
|
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
|
)
|
|
conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
|
state2 = apply_conditioning(state2, [conditioning])
|
|
|
|
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
scaled_mask = state2.denoise_mask * noise_scale
|
|
state2 = LatentState(
|
|
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
|
clean_latent=state2.clean_latent,
|
|
denoise_mask=state2.denoise_mask,
|
|
)
|
|
latents = state2.latent
|
|
mx.eval(latents)
|
|
else:
|
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
|
latents = noise * noise_scale + latents * one_minus_scale
|
|
mx.eval(latents)
|
|
|
|
# Re-noise audio at sigma=0.909375 for joint refinement (matches PyTorch)
|
|
if audio_latents is not None and not is_a2v:
|
|
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
|
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
|
mx.eval(audio_latents)
|
|
|
|
# Joint video + audio refinement (no CFG, positive embeddings only)
|
|
latents, audio_latents = denoise_distilled(
|
|
latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS,
|
|
verbose=verbose, state=state2,
|
|
audio_latents=audio_latents, audio_positions=audio_positions,
|
|
audio_embeddings=audio_embeddings_pos,
|
|
audio_frozen=is_a2v,
|
|
)
|
|
|
|
elif pipeline == PipelineType.DEV_TWO_STAGE_HQ:
|
|
# ======================================================================
|
|
# DEV TWO-STAGE HQ PIPELINE:
|
|
# Stage 1: res_2s denoising at half resolution with CFG + LoRA@0.25
|
|
# Upsample: 2x spatial via LatentUpsampler
|
|
# Stage 2: res_2s refinement at full resolution with LoRA@0.5, no CFG
|
|
# ======================================================================
|
|
|
|
# HQ defaults
|
|
hq_lora_strength_s1 = lora_strength_stage_1 if lora_strength_stage_1 is not None else 0.25
|
|
hq_lora_strength_s2 = lora_strength_stage_2 if lora_strength_stage_2 is not None else 0.5
|
|
hq_cfg_rescale = cfg_rescale if cfg_rescale != 0.7 else 0.45 # Override default 0.7 → 0.45
|
|
hq_steps = num_inference_steps if num_inference_steps != 30 else 15 # Override default 30 → 15
|
|
|
|
# Load VAE encoder for I2V
|
|
stage1_image_latent = None
|
|
stage2_image_latent = None
|
|
if is_i2v:
|
|
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
|
|
|
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
|
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
|
mx.eval(stage1_image_latent)
|
|
|
|
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
|
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
|
mx.eval(stage2_image_latent)
|
|
|
|
del vae_encoder
|
|
mx.clear_cache()
|
|
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
|
|
|
# Auto-detect and merge LoRA for stage 1 (strength 0.25)
|
|
if lora_path is None:
|
|
lora_files = sorted(model_path.glob("*distilled-lora*.safetensors"))
|
|
if lora_files:
|
|
lora_path = str(lora_files[0])
|
|
console.print(f"[dim]Auto-detected LoRA: {Path(lora_path).name}[/]")
|
|
else:
|
|
console.print("[yellow]Warning: No LoRA file found. HQ pipeline works best with distilled LoRA.[/]")
|
|
|
|
if lora_path is not None:
|
|
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
|
|
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
|
|
|
|
# Stage 1: res_2s denoising at half resolution with CFG
|
|
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
|
|
num_tokens = latent_frames * stage1_h * stage1_w
|
|
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
|
|
mx.eval(sigmas)
|
|
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
|
|
|
|
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
|
|
mx.random.seed(seed)
|
|
|
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
|
mx.eval(positions)
|
|
|
|
audio_positions = create_audio_position_grid(1, audio_frames)
|
|
audio_latents = a2v_audio_latents if is_a2v else mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS), dtype=model_dtype)
|
|
mx.eval(audio_positions, audio_latents)
|
|
|
|
# Apply I2V conditioning for stage 1
|
|
state1 = None
|
|
stage1_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
|
if is_i2v and stage1_image_latent is not None:
|
|
state1 = LatentState(
|
|
latent=mx.zeros(stage1_shape, dtype=model_dtype),
|
|
clean_latent=mx.zeros(stage1_shape, dtype=model_dtype),
|
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
|
)
|
|
conditioning = VideoConditionByLatentIndex(latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
|
state1 = apply_conditioning(state1, [conditioning])
|
|
|
|
noise = mx.random.normal(stage1_shape, dtype=model_dtype)
|
|
noise_scale = sigmas[0]
|
|
scaled_mask = state1.denoise_mask * noise_scale
|
|
state1 = LatentState(
|
|
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
|
clean_latent=state1.clean_latent,
|
|
denoise_mask=state1.denoise_mask,
|
|
)
|
|
latents = state1.latent
|
|
mx.eval(latents)
|
|
else:
|
|
latents = mx.random.normal(stage1_shape, dtype=model_dtype)
|
|
mx.eval(latents)
|
|
|
|
# Stage 1: res_2s with CFG (STG disabled for HQ by default)
|
|
latents, audio_latents = denoise_res2s_av(
|
|
latents, audio_latents,
|
|
positions, audio_positions,
|
|
video_embeddings_pos, video_embeddings_neg,
|
|
audio_embeddings_pos, audio_embeddings_neg,
|
|
transformer, sigmas, cfg_scale=cfg_scale,
|
|
audio_cfg_scale=audio_cfg_scale,
|
|
cfg_rescale=hq_cfg_rescale, audio_cfg_rescale=1.0,
|
|
verbose=verbose, video_state=state1,
|
|
stg_scale=stg_scale, stg_video_blocks=stg_blocks,
|
|
stg_audio_blocks=stg_blocks, modality_scale=modality_scale,
|
|
noise_seed=seed,
|
|
audio_frozen=is_a2v,
|
|
)
|
|
|
|
mx.eval(audio_latents)
|
|
|
|
# Upsample latents 2x
|
|
with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"):
|
|
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
|
if not upscaler_files:
|
|
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
|
upsampler = load_upsampler(str(upscaler_files[0]))
|
|
mx.eval(upsampler.parameters())
|
|
|
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
|
|
|
latents = upsample_latents(latents, upsampler, vae_decoder.per_channel_statistics.mean, vae_decoder.per_channel_statistics.std)
|
|
mx.eval(latents)
|
|
|
|
del upsampler
|
|
mx.clear_cache()
|
|
console.print("[green]✓[/] Latents upsampled")
|
|
|
|
# Merge additional LoRA for stage 2 (additive: 0.25 + 0.25 = 0.5 total)
|
|
if lora_path is not None:
|
|
additional_strength = hq_lora_strength_s2 - hq_lora_strength_s1
|
|
if additional_strength > 0:
|
|
with console.status(f"[blue]Adjusting LoRA (stage 2, total={hq_lora_strength_s2})...[/]", spinner="dots"):
|
|
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
|
|
|
|
# Stage 2: res_2s refinement at full resolution (no CFG)
|
|
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)")
|
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
|
mx.eval(positions)
|
|
|
|
state2 = None
|
|
if is_i2v and stage2_image_latent is not None:
|
|
state2 = LatentState(
|
|
latent=latents,
|
|
clean_latent=mx.zeros_like(latents),
|
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
|
)
|
|
conditioning = VideoConditionByLatentIndex(latent=stage2_image_latent, frame_idx=image_frame_idx, strength=image_strength)
|
|
state2 = apply_conditioning(state2, [conditioning])
|
|
|
|
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
scaled_mask = state2.denoise_mask * noise_scale
|
|
state2 = LatentState(
|
|
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
|
clean_latent=state2.clean_latent,
|
|
denoise_mask=state2.denoise_mask,
|
|
)
|
|
latents = state2.latent
|
|
mx.eval(latents)
|
|
else:
|
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
|
latents = noise * noise_scale + latents * one_minus_scale
|
|
mx.eval(latents)
|
|
|
|
# Re-noise audio at sigma=0.909375 for joint refinement
|
|
if audio_latents is not None and not is_a2v:
|
|
audio_noise = mx.random.normal(audio_latents.shape, dtype=model_dtype)
|
|
audio_noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
|
audio_latents = audio_noise * audio_noise_scale + audio_latents * (mx.array(1.0, dtype=model_dtype) - audio_noise_scale)
|
|
mx.eval(audio_latents)
|
|
|
|
# Stage 2: res_2s with no CFG (positive embeddings only)
|
|
stage2_sigmas = mx.array(STAGE_2_SIGMAS, dtype=mx.float32)
|
|
latents, audio_latents = denoise_res2s_av(
|
|
latents, audio_latents,
|
|
positions, audio_positions,
|
|
video_embeddings_pos, video_embeddings_pos, # both pos (no neg for stage 2)
|
|
audio_embeddings_pos, audio_embeddings_pos,
|
|
transformer, stage2_sigmas, cfg_scale=1.0, # no CFG
|
|
audio_cfg_scale=1.0,
|
|
cfg_rescale=0.0, verbose=verbose, video_state=state2,
|
|
noise_seed=seed + 1,
|
|
audio_frozen=is_a2v,
|
|
)
|
|
|
|
del transformer
|
|
mx.clear_cache()
|
|
|
|
# ==========================================================================
|
|
# Decode and save outputs (common to both pipelines)
|
|
# ==========================================================================
|
|
|
|
console.print("\n[blue]🎞️ Decoding video...[/]")
|
|
|
|
# Select tiling configuration
|
|
if tiling == "none":
|
|
tiling_config = None
|
|
elif tiling == "auto":
|
|
tiling_config = TilingConfig.auto(height, width, num_frames)
|
|
elif tiling == "default":
|
|
tiling_config = TilingConfig.default()
|
|
elif tiling == "aggressive":
|
|
tiling_config = TilingConfig.aggressive()
|
|
elif tiling == "conservative":
|
|
tiling_config = TilingConfig.conservative()
|
|
elif tiling == "spatial":
|
|
tiling_config = TilingConfig.spatial_only()
|
|
elif tiling == "temporal":
|
|
tiling_config = TilingConfig.temporal_only()
|
|
else:
|
|
console.print(f"[yellow] Unknown tiling mode '{tiling}', using auto[/]")
|
|
tiling_config = TilingConfig.auto(height, width, num_frames)
|
|
|
|
output_path = Path(output_path)
|
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Stream mode
|
|
video_writer = None
|
|
stream_progress = None
|
|
|
|
if stream and tiling_config is not None:
|
|
import cv2
|
|
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
|
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
|
|
stream_progress = Progress(
|
|
SpinnerColumn(),
|
|
TextColumn("[progress.description]{task.description}"),
|
|
BarColumn(),
|
|
TaskProgressColumn(),
|
|
console=console,
|
|
)
|
|
stream_progress.start()
|
|
stream_task = stream_progress.add_task("[cyan]Streaming frames[/]", total=num_frames)
|
|
|
|
def on_frames_ready(frames: mx.array, _start_idx: int):
|
|
frames = mx.squeeze(frames, axis=0)
|
|
frames = mx.transpose(frames, (1, 2, 3, 0))
|
|
frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0)
|
|
frames = (frames * 255).astype(mx.uint8)
|
|
frames_np = np.array(frames)
|
|
|
|
for frame in frames_np:
|
|
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
|
stream_progress.advance(stream_task)
|
|
else:
|
|
on_frames_ready = None
|
|
|
|
if tiling_config is not None:
|
|
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
|
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
|
console.print(f"[dim] Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}[/]")
|
|
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
|
|
else:
|
|
console.print("[dim] Tiling: disabled[/]")
|
|
video = vae_decoder(latents)
|
|
mx.eval(video)
|
|
mx.clear_cache()
|
|
|
|
# Close stream writer
|
|
if video_writer is not None:
|
|
video_writer.release()
|
|
if stream_progress is not None:
|
|
stream_progress.stop()
|
|
console.print(f"[green]✅ Streamed video to[/] {output_path}")
|
|
video = mx.squeeze(video, axis=0)
|
|
video = mx.transpose(video, (1, 2, 3, 0))
|
|
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
|
video = (video * 255).astype(mx.uint8)
|
|
video_np = np.array(video)
|
|
else:
|
|
video = mx.squeeze(video, axis=0)
|
|
video = mx.transpose(video, (1, 2, 3, 0))
|
|
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
|
video = (video * 255).astype(mx.uint8)
|
|
video_np = np.array(video)
|
|
|
|
if audio:
|
|
temp_video_path = output_path.with_suffix('.temp.mp4')
|
|
save_path = temp_video_path
|
|
else:
|
|
save_path = output_path
|
|
|
|
try:
|
|
import cv2
|
|
h, w = video_np.shape[1], video_np.shape[2]
|
|
fourcc = cv2.VideoWriter_fourcc(*'avc1')
|
|
out = cv2.VideoWriter(str(save_path), fourcc, fps, (w, h))
|
|
for frame in video_np:
|
|
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
|
out.release()
|
|
if not audio:
|
|
console.print(f"[green]✅ Saved video to[/] {output_path}")
|
|
except Exception as e:
|
|
console.print(f"[red]❌ Could not save video: {e}[/]")
|
|
|
|
# Decode and save audio if enabled
|
|
audio_np = None
|
|
vocoder_sample_rate = AUDIO_SAMPLE_RATE
|
|
if audio and audio_latents is not None:
|
|
if is_a2v and a2v_waveform is not None:
|
|
# A2V: use original input audio waveform (no VAE decoding needed)
|
|
audio_np = a2v_waveform
|
|
if audio_np.ndim == 1:
|
|
audio_np = audio_np[np.newaxis, :]
|
|
vocoder_sample_rate = a2v_sr or AUDIO_LATENT_SAMPLE_RATE
|
|
console.print("[green]✓[/] Using original input audio (A2V)")
|
|
else:
|
|
with console.status("[blue]Decoding audio...[/]", spinner="dots"):
|
|
audio_decoder = load_audio_decoder(model_path, pipeline)
|
|
vocoder = load_vocoder_model(model_path, pipeline)
|
|
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
|
|
|
mel_spectrogram = audio_decoder(audio_latents)
|
|
mx.eval(mel_spectrogram)
|
|
console.print(f"[dim] Mel spectrogram: shape={mel_spectrogram.shape}, std={mel_spectrogram.std().item():.4f}, mean={mel_spectrogram.mean().item():.4f}[/]")
|
|
|
|
audio_waveform = vocoder(mel_spectrogram)
|
|
mx.eval(audio_waveform)
|
|
|
|
audio_np = np.array(audio_waveform.astype(mx.float32))
|
|
if audio_np.ndim == 3:
|
|
audio_np = audio_np[0]
|
|
|
|
# Get sample rate from vocoder (dynamic: 24kHz for LTX-2, 48kHz for LTX-2.3 BWE)
|
|
vocoder_sample_rate = getattr(vocoder, 'output_sampling_rate', AUDIO_SAMPLE_RATE)
|
|
|
|
del audio_decoder, vocoder
|
|
mx.clear_cache()
|
|
console.print("[green]✓[/] Audio decoded")
|
|
|
|
audio_path = Path(output_audio_path) if output_audio_path else output_path.with_suffix('.wav')
|
|
save_audio(audio_np, audio_path, vocoder_sample_rate)
|
|
console.print(f"[green]✅ Saved audio to[/] {audio_path}")
|
|
|
|
with console.status("[blue]🎬 Combining video and audio...[/]", spinner="dots"):
|
|
temp_video_path = output_path.with_suffix('.temp.mp4')
|
|
success = mux_video_audio(temp_video_path, audio_path, output_path)
|
|
if success:
|
|
console.print(f"[green]✅ Saved video with audio to[/] {output_path}")
|
|
temp_video_path.unlink()
|
|
else:
|
|
temp_video_path.rename(output_path)
|
|
console.print(f"[yellow]⚠️ Saved video without audio to[/] {output_path}")
|
|
|
|
del vae_decoder
|
|
mx.clear_cache()
|
|
|
|
if save_frames:
|
|
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
|
frames_dir.mkdir(exist_ok=True)
|
|
for i, frame in enumerate(video_np):
|
|
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
|
|
console.print(f"[green]✅ Saved {len(video_np)} frames to {frames_dir}[/]")
|
|
|
|
elapsed = time.time() - start_time
|
|
minutes, seconds = divmod(elapsed, 60)
|
|
time_str = f"{int(minutes)}m {seconds:.1f}s" if minutes >= 1 else f"{seconds:.1f}s"
|
|
console.print(Panel(
|
|
f"[bold green]🎉 Done![/] Generated in {time_str} ({elapsed/num_frames:.2f}s/frame)\n"
|
|
f"[bold green]✨ Peak memory:[/] {mx.get_peak_memory() / (1024 ** 3):.2f}GB",
|
|
expand=False
|
|
))
|
|
|
|
if audio:
|
|
return video_np, audio_np
|
|
return video_np
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description="Generate videos with MLX LTX-2 (Distilled or Dev pipeline)",
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
epilog="""
|
|
Examples:
|
|
# Distilled pipeline (two-stage, fast, no CFG)
|
|
python -m mlx_video.generate --prompt "A cat walking on grass"
|
|
python -m mlx_video.generate --prompt "Ocean waves" --pipeline distilled
|
|
|
|
# Dev pipeline (single-stage, CFG, higher quality)
|
|
python -m mlx_video.generate --prompt "A cat walking" --pipeline dev --cfg-scale 3.0
|
|
python -m mlx_video.generate --prompt "Ocean waves" --pipeline dev --steps 40
|
|
|
|
# Dev two-stage pipeline (dev + LoRA refinement)
|
|
python -m mlx_video.generate --prompt "A cat walking" --pipeline dev-two-stage --cfg-scale 3.0
|
|
|
|
# Image-to-Video (works with both pipelines)
|
|
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
|
|
python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --pipeline dev
|
|
|
|
# With Audio (works with both pipelines)
|
|
python -m mlx_video.generate --prompt "Ocean waves crashing" --audio
|
|
python -m mlx_video.generate --prompt "A jazz band playing" --audio --pipeline dev
|
|
"""
|
|
)
|
|
|
|
parser.add_argument("--prompt", "-p", type=str, required=True, help="Text description of the video to generate")
|
|
parser.add_argument("--pipeline", type=str, default="distilled", choices=["distilled", "dev", "dev-two-stage", "dev-two-stage-hq"],
|
|
help="Pipeline type: distilled (fast), dev (CFG), dev-two-stage (dev + LoRA), dev-two-stage-hq (res_2s + LoRA both stages)")
|
|
parser.add_argument("--negative-prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
|
|
help="Negative prompt for CFG (dev pipeline only)")
|
|
parser.add_argument("--height", "-H", type=int, default=512, help="Output video height")
|
|
parser.add_argument("--width", "-W", type=int, default=512, help="Output video width")
|
|
parser.add_argument("--num-frames", "-n", type=int, default=33, help="Number of frames")
|
|
parser.add_argument("--steps", type=int, default=30, help="Number of inference steps (dev pipeline only, default 30)")
|
|
parser.add_argument("--cfg-scale", type=float, default=3.0, help="CFG guidance scale for video (dev pipeline only, default 3.0)")
|
|
parser.add_argument("--audio-cfg-scale", type=float, default=7.0, help="CFG guidance scale for audio (default 7.0, PyTorch default)")
|
|
parser.add_argument("--cfg-rescale", type=float, default=0.7, help="CFG rescale factor (0.0-1.0). Normalizes guided prediction variance to reduce artifacts (dev pipeline only, default 0.7)")
|
|
parser.add_argument("--seed", "-s", type=int, default=42, help="Random seed")
|
|
parser.add_argument("--fps", type=int, default=24, help="Frames per second")
|
|
parser.add_argument("--output-path", "-o", type=str, default="output.mp4", help="Output video path")
|
|
parser.add_argument("--save-frames", action="store_true", help="Save individual frames as images")
|
|
parser.add_argument("--model-repo", type=str, default="Lightricks/LTX-2", help="Model repository")
|
|
parser.add_argument("--text-encoder-repo", type=str, default=None, help="Text encoder repository")
|
|
parser.add_argument("--verbose", action="store_true", help="Verbose output")
|
|
parser.add_argument("--enhance-prompt", action="store_true", help="Enhance the prompt using Gemma")
|
|
parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens for prompt enhancement")
|
|
parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for prompt enhancement")
|
|
parser.add_argument("--image", "-i", type=str, default=None, help="Path to conditioning image for I2V")
|
|
parser.add_argument("--image-strength", type=float, default=1.0, help="Conditioning strength for I2V")
|
|
parser.add_argument("--image-frame-idx", type=int, default=0, help="Frame index to condition for I2V")
|
|
parser.add_argument("--tiling", type=str, default="auto",
|
|
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
|
help="Tiling mode for VAE decoding")
|
|
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
|
|
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
|
|
parser.add_argument("--audio-file", type=str, default=None, help="Path to audio file for A2V (audio-to-video) conditioning")
|
|
parser.add_argument("--audio-start-time", type=float, default=0.0, help="Start time in seconds for audio file (default: 0.0)")
|
|
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
|
|
parser.add_argument("--apg", action="store_true", help="Use Adaptive Projected Guidance instead of CFG (more stable for I2V)")
|
|
parser.add_argument("--apg-eta", type=float, default=1.0, help="APG parallel component weight (1.0 = keep full parallel)")
|
|
parser.add_argument("--apg-norm-threshold", type=float, default=0.0, help="APG guidance norm clamp (0 = no clamping)")
|
|
parser.add_argument("--stg-scale", type=float, default=0.0, help="STG (Spatiotemporal Guidance) scale (default 0.0 = disabled, PyTorch default: 1.0)")
|
|
parser.add_argument("--stg-blocks", type=int, nargs="+", default=None, help="Transformer block indices for STG perturbation (default: [29] for LTX-2, [28] for LTX-2.3)")
|
|
parser.add_argument("--modality-scale", type=float, default=1.0, help="Cross-modal guidance scale (default 1.0 = disabled, PyTorch default: 3.0)")
|
|
parser.add_argument("--lora-path", type=str, default=None, help="Path to LoRA safetensors file (dev-two-stage pipeline)")
|
|
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
|
|
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
|
|
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
|
|
args = parser.parse_args()
|
|
|
|
pipeline_map = {
|
|
"distilled": PipelineType.DISTILLED,
|
|
"dev": PipelineType.DEV,
|
|
"dev-two-stage": PipelineType.DEV_TWO_STAGE,
|
|
"dev-two-stage-hq": PipelineType.DEV_TWO_STAGE_HQ,
|
|
}
|
|
pipeline = pipeline_map[args.pipeline]
|
|
|
|
generate_video(
|
|
model_repo=args.model_repo,
|
|
text_encoder_repo=args.text_encoder_repo,
|
|
prompt=args.prompt,
|
|
pipeline=pipeline,
|
|
negative_prompt=args.negative_prompt,
|
|
height=args.height,
|
|
width=args.width,
|
|
num_frames=args.num_frames,
|
|
num_inference_steps=args.steps,
|
|
cfg_scale=args.cfg_scale,
|
|
audio_cfg_scale=args.audio_cfg_scale,
|
|
cfg_rescale=args.cfg_rescale,
|
|
seed=args.seed,
|
|
fps=args.fps,
|
|
output_path=args.output_path,
|
|
save_frames=args.save_frames,
|
|
verbose=args.verbose,
|
|
enhance_prompt=args.enhance_prompt,
|
|
max_tokens=args.max_tokens,
|
|
temperature=args.temperature,
|
|
image=args.image,
|
|
image_strength=args.image_strength,
|
|
image_frame_idx=args.image_frame_idx,
|
|
tiling=args.tiling,
|
|
stream=args.stream,
|
|
audio=args.audio,
|
|
output_audio_path=args.output_audio,
|
|
use_apg=args.apg,
|
|
apg_eta=args.apg_eta,
|
|
apg_norm_threshold=args.apg_norm_threshold,
|
|
stg_scale=args.stg_scale,
|
|
stg_blocks=args.stg_blocks,
|
|
modality_scale=args.modality_scale,
|
|
lora_path=args.lora_path,
|
|
lora_strength=args.lora_strength,
|
|
lora_strength_stage_1=args.lora_strength_stage_1,
|
|
lora_strength_stage_2=args.lora_strength_stage_2,
|
|
audio_file=args.audio_file,
|
|
audio_start_time=args.audio_start_time,
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|