Add custom spatial upscaling support to LTX-2 video generation; introduce spatial_upscaler parameter and enhance resolution handling for two-stage pipelines

This commit is contained in:
Prince Canuma
2026-03-17 02:23:47 +01:00
parent cc302d79b0
commit 57f66bcae2
3 changed files with 234 additions and 98 deletions

View File

@@ -1461,6 +1461,7 @@ def generate_video(
lora_strength_stage_2: Optional[float] = None,
audio_file: Optional[str] = None,
audio_start_time: float = 0.0,
spatial_upscaler: Optional[str] = None,
):
"""Generate video using LTX-2 models.
@@ -1557,10 +1558,35 @@ def generate_video(
model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Resolve spatial upscaler path for two-stage pipelines
upscaler_path = None
upscaler_scale = 2.0
if is_two_stage:
if spatial_upscaler is not None:
# User-specified upscaler file
upscaler_path = model_path / spatial_upscaler if not Path(spatial_upscaler).is_absolute() else Path(spatial_upscaler)
if not upscaler_path.exists():
# Try as a filename within model_path
upscaler_path = model_path / spatial_upscaler
# Detect scale from filename
if "x1.5" in str(upscaler_path):
upscaler_scale = 1.5
elif "x2" in str(upscaler_path):
upscaler_scale = 2.0
else:
# Auto-detect: prefer x2 upscaler
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if upscaler_files:
upscaler_path = upscaler_files[0]
upscaler_scale = 2.0
# Calculate latent dimensions
if is_two_stage:
# Stage 1 always at half resolution (matches PyTorch)
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
# Stage 2 resolution = stage 1 * upscaler scale
stage2_h = int(stage1_h * upscaler_scale)
stage2_w = int(stage1_w * upscaler_scale)
else:
latent_h, latent_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8
@@ -1697,13 +1723,15 @@ def generate_video(
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
s1_h, s1_w = stage1_h * 32, stage1_w * 32
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
s2_h, s2_w = stage2_h * 32, stage2_w * 32
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
@@ -1712,7 +1740,7 @@ def generate_video(
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Stage 1
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)")
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {stage1_w*32}x{stage1_h*32} (8 steps)")
mx.random.seed(seed)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
@@ -1757,11 +1785,10 @@ def generate_video(
)
# Upsample latents
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if not upscaler_files:
with console.status(f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
if upscaler_path is None or not upscaler_path.exists():
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
upsampler = load_upsampler(str(upscaler_files[0]))
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
mx.eval(upsampler.parameters())
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
@@ -1774,7 +1801,7 @@ def generate_video(
console.print("[green]✓[/] Latents upsampled")
# Stage 2
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)")
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {stage2_w*32}x{stage2_h*32} (3 steps)")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
@@ -1916,13 +1943,15 @@ def generate_video(
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
s1_h, s1_w = stage1_h * 32, stage1_w * 32
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
s2_h, s2_w = stage2_h * 32, stage2_w * 32
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
@@ -1930,12 +1959,12 @@ def generate_video(
mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Stage 1: Dev denoising at half resolution with CFG
# Stage 1: Dev denoising at reduced resolution with CFG
sigmas = ltx2_scheduler(steps=num_inference_steps)
mx.eval(sigmas)
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]")
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {width//2}x{height//2} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {stage1_w*32}x{stage1_h*32} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
mx.random.seed(seed)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
@@ -1989,12 +2018,11 @@ def generate_video(
mx.eval(audio_latents)
# Upsample latents 2x
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if not upscaler_files:
# Upsample latents
with console.status(f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
if upscaler_path is None or not upscaler_path.exists():
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
upsampler = load_upsampler(str(upscaler_files[0]))
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
mx.eval(upsampler.parameters())
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
@@ -2091,13 +2119,15 @@ def generate_video(
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
s1_h, s1_w = stage1_h * 32, stage1_w * 32
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
s2_h, s2_w = stage2_h * 32, stage2_w * 32
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
@@ -2118,14 +2148,14 @@ def generate_video(
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
# Stage 1: res_2s denoising at half resolution with CFG
# Stage 1: res_2s denoising at reduced resolution with CFG
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
num_tokens = latent_frames * stage1_h * stage1_w
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
mx.eval(sigmas)
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {stage1_w*32}x{stage1_h*32} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
mx.random.seed(seed)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
@@ -2179,12 +2209,11 @@ def generate_video(
mx.eval(audio_latents)
# Upsample latents 2x
with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"):
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if not upscaler_files:
# Upsample latents
with console.status(f"[magenta]Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
if upscaler_path is None or not upscaler_path.exists():
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
upsampler = load_upsampler(str(upscaler_files[0]))
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
mx.eval(upsampler.parameters())
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
@@ -2204,7 +2233,7 @@ def generate_video(
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
# Stage 2: res_2s refinement at full resolution (no CFG)
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)")
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {stage2_w*32}x{stage2_h*32} (3 steps, no CFG)")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
@@ -2509,6 +2538,9 @@ Examples:
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
parser.add_argument("--spatial-upscaler", type=str, default=None,
help="Spatial upscaler filename (e.g. ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors). "
"Auto-detects x2 by default. Use this to select x1.5 or a specific version.")
args = parser.parse_args()
pipeline_map = {
@@ -2559,6 +2591,7 @@ Examples:
lora_strength_stage_2=args.lora_strength_stage_2,
audio_file=args.audio_file,
audio_start_time=args.audio_start_time,
spatial_upscaler=args.spatial_upscaler,
)