Add custom spatial upscaling support to LTX-2 video generation; introduce spatial_upscaler parameter and enhance resolution handling for two-stage pipelines
This commit is contained in:
@@ -1461,6 +1461,7 @@ def generate_video(
|
||||
lora_strength_stage_2: Optional[float] = None,
|
||||
audio_file: Optional[str] = None,
|
||||
audio_start_time: float = 0.0,
|
||||
spatial_upscaler: Optional[str] = None,
|
||||
):
|
||||
"""Generate video using LTX-2 models.
|
||||
|
||||
@@ -1557,10 +1558,35 @@ def generate_video(
|
||||
model_path = get_model_path(model_repo)
|
||||
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||
|
||||
# Resolve spatial upscaler path for two-stage pipelines
|
||||
upscaler_path = None
|
||||
upscaler_scale = 2.0
|
||||
if is_two_stage:
|
||||
if spatial_upscaler is not None:
|
||||
# User-specified upscaler file
|
||||
upscaler_path = model_path / spatial_upscaler if not Path(spatial_upscaler).is_absolute() else Path(spatial_upscaler)
|
||||
if not upscaler_path.exists():
|
||||
# Try as a filename within model_path
|
||||
upscaler_path = model_path / spatial_upscaler
|
||||
# Detect scale from filename
|
||||
if "x1.5" in str(upscaler_path):
|
||||
upscaler_scale = 1.5
|
||||
elif "x2" in str(upscaler_path):
|
||||
upscaler_scale = 2.0
|
||||
else:
|
||||
# Auto-detect: prefer x2 upscaler
|
||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
||||
if upscaler_files:
|
||||
upscaler_path = upscaler_files[0]
|
||||
upscaler_scale = 2.0
|
||||
|
||||
# Calculate latent dimensions
|
||||
if is_two_stage:
|
||||
# Stage 1 always at half resolution (matches PyTorch)
|
||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||
stage2_h, stage2_w = height // 32, width // 32
|
||||
# Stage 2 resolution = stage 1 * upscaler scale
|
||||
stage2_h = int(stage1_h * upscaler_scale)
|
||||
stage2_w = int(stage1_w * upscaler_scale)
|
||||
else:
|
||||
latent_h, latent_w = height // 32, width // 32
|
||||
latent_frames = 1 + (num_frames - 1) // 8
|
||||
@@ -1697,13 +1723,15 @@ def generate_video(
|
||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
s1_h, s1_w = stage1_h * 32, stage1_w * 32
|
||||
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
s2_h, s2_w = stage2_h * 32, stage2_w * 32
|
||||
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
|
||||
@@ -1712,7 +1740,7 @@ def generate_video(
|
||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||
|
||||
# Stage 1
|
||||
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)")
|
||||
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {stage1_w*32}x{stage1_h*32} (8 steps)")
|
||||
mx.random.seed(seed)
|
||||
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
@@ -1757,11 +1785,10 @@ def generate_video(
|
||||
)
|
||||
|
||||
# Upsample latents
|
||||
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
||||
if not upscaler_files:
|
||||
with console.status(f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
|
||||
if upscaler_path is None or not upscaler_path.exists():
|
||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
||||
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
@@ -1774,7 +1801,7 @@ def generate_video(
|
||||
console.print("[green]✓[/] Latents upsampled")
|
||||
|
||||
# Stage 2
|
||||
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)")
|
||||
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {stage2_w*32}x{stage2_h*32} (3 steps)")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
@@ -1916,13 +1943,15 @@ def generate_video(
|
||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
s1_h, s1_w = stage1_h * 32, stage1_w * 32
|
||||
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
s2_h, s2_w = stage2_h * 32, stage2_w * 32
|
||||
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
|
||||
@@ -1930,12 +1959,12 @@ def generate_video(
|
||||
mx.clear_cache()
|
||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||
|
||||
# Stage 1: Dev denoising at half resolution with CFG
|
||||
# Stage 1: Dev denoising at reduced resolution with CFG
|
||||
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
||||
mx.eval(sigmas)
|
||||
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||
|
||||
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {width//2}x{height//2} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
|
||||
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {stage1_w*32}x{stage1_h*32} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
|
||||
mx.random.seed(seed)
|
||||
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
@@ -1989,12 +2018,11 @@ def generate_video(
|
||||
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Upsample latents 2x
|
||||
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
||||
if not upscaler_files:
|
||||
# Upsample latents
|
||||
with console.status(f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
|
||||
if upscaler_path is None or not upscaler_path.exists():
|
||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
||||
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
@@ -2091,13 +2119,15 @@ def generate_video(
|
||||
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||
|
||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||
s1_h, s1_w = stage1_h * 32, stage1_w * 32
|
||||
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
|
||||
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
|
||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||
mx.eval(stage1_image_latent)
|
||||
|
||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||
s2_h, s2_w = stage2_h * 32, stage2_w * 32
|
||||
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
|
||||
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
|
||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||
mx.eval(stage2_image_latent)
|
||||
|
||||
@@ -2118,14 +2148,14 @@ def generate_video(
|
||||
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
|
||||
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
|
||||
|
||||
# Stage 1: res_2s denoising at half resolution with CFG
|
||||
# Stage 1: res_2s denoising at reduced resolution with CFG
|
||||
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
|
||||
num_tokens = latent_frames * stage1_h * stage1_w
|
||||
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
|
||||
mx.eval(sigmas)
|
||||
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
|
||||
|
||||
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
|
||||
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {stage1_w*32}x{stage1_h*32} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
|
||||
mx.random.seed(seed)
|
||||
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
@@ -2179,12 +2209,11 @@ def generate_video(
|
||||
|
||||
mx.eval(audio_latents)
|
||||
|
||||
# Upsample latents 2x
|
||||
with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"):
|
||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
||||
if not upscaler_files:
|
||||
# Upsample latents
|
||||
with console.status(f"[magenta]Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
|
||||
if upscaler_path is None or not upscaler_path.exists():
|
||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
||||
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
@@ -2204,7 +2233,7 @@ def generate_video(
|
||||
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
|
||||
|
||||
# Stage 2: res_2s refinement at full resolution (no CFG)
|
||||
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)")
|
||||
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {stage2_w*32}x{stage2_h*32} (3 steps, no CFG)")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
@@ -2509,6 +2538,9 @@ Examples:
|
||||
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
|
||||
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
|
||||
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
|
||||
parser.add_argument("--spatial-upscaler", type=str, default=None,
|
||||
help="Spatial upscaler filename (e.g. ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors). "
|
||||
"Auto-detects x2 by default. Use this to select x1.5 or a specific version.")
|
||||
args = parser.parse_args()
|
||||
|
||||
pipeline_map = {
|
||||
@@ -2559,6 +2591,7 @@ Examples:
|
||||
lora_strength_stage_2=args.lora_strength_stage_2,
|
||||
audio_file=args.audio_file,
|
||||
audio_start_time=args.audio_start_time,
|
||||
spatial_upscaler=args.spatial_upscaler,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user