Add custom spatial upscaling support to LTX-2 video generation; introduce spatial_upscaler parameter and enhance resolution handling for two-stage pipelines
@@ -155,6 +155,32 @@ uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prom
|
|||||||
| `--audio-start-time` | 0.0 | Start time in seconds for audio file |
|
| `--audio-start-time` | 0.0 | Start time in seconds for audio file |
|
||||||
| `--tiling` | `auto` | VAE tiling mode: `auto`, `none`, `aggressive`, `conservative` |
|
| `--tiling` | `auto` | VAE tiling mode: `auto`, `none`, `aggressive`, `conservative` |
|
||||||
| `--stream` | false | Stream frames as they decode |
|
| `--stream` | false | Stream frames as they decode |
|
||||||
|
| `--spatial-upscaler` | auto (x2) | Spatial upscaler file for two-stage pipelines (see below) |
|
||||||
|
|
||||||
|
### Spatial Upscalers (LTX-2.3)
|
||||||
|
|
||||||
|
LTX-2.3 ships with multiple spatial upscaler variants. Use `--spatial-upscaler` to select one:
|
||||||
|
|
||||||
|
| Variant | Scale | Output (from 256x256) | Architecture |
|
||||||
|
|---------|-------|-----------------------|--------------|
|
||||||
|
| `ltx-2.3-spatial-upscaler-x2-1.0.safetensors` (default) | 2.0x | 512x512 | Conv2d + PixelShuffle(2) |
|
||||||
|
| `ltx-2.3-spatial-upscaler-x2-1.1.safetensors` | 2.0x | 512x512 | Same arch, newer weights |
|
||||||
|
| `ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors` | 1.5x | 384x384 | Conv2d + PixelShuffle(3) + BlurDownsample |
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Default (x2-1.0, auto-detected)
|
||||||
|
uv run mlx_video.generate --prompt "A sunset" --model-repo ./LTX-2.3-distilled
|
||||||
|
|
||||||
|
# x2-1.1 (newer weights)
|
||||||
|
uv run mlx_video.generate --prompt "A sunset" --model-repo ./LTX-2.3-distilled \
|
||||||
|
--spatial-upscaler ltx-2.3-spatial-upscaler-x2-1.1.safetensors
|
||||||
|
|
||||||
|
# x1.5 (smaller output, faster)
|
||||||
|
uv run mlx_video.generate --prompt "A sunset" --model-repo ./LTX-2.3-distilled \
|
||||||
|
--spatial-upscaler ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note:** Stage 1 always runs at half the target resolution. With x1.5, the final output is 75% of `--width`/`--height` (e.g., 512 target -> 256 stage 1 -> 384 output). With x2, the output matches the target exactly.
|
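The resolution arithmetic in the note can be checked with a short sketch. It mirrors the `// 2 // 32` latent-size computation that this commit uses in `generate_video`; the helper name `two_stage_resolution` and the constant `LATENT_STRIDE` are illustrative and do not exist in the codebase.

```python
LATENT_STRIDE = 32  # pixels per latent cell, matching the `// 32` in generate_video


def two_stage_resolution(width: int, height: int, scale: float) -> dict:
    """Return the stage-1 and stage-2 pixel sizes for a two-stage run."""
    # Stage 1 runs at half the target size, floored to the latent grid.
    s1_w = width // 2 // LATENT_STRIDE
    s1_h = height // 2 // LATENT_STRIDE
    # Stage 2 is stage 1 times the upscaler factor, truncated to whole cells.
    s2_w = int(s1_w * scale)
    s2_h = int(s1_h * scale)
    return {
        "stage1": (s1_w * LATENT_STRIDE, s1_h * LATENT_STRIDE),
        "stage2": (s2_w * LATENT_STRIDE, s2_h * LATENT_STRIDE),
    }


print(two_stage_resolution(512, 512, 2.0))  # stage2 is (512, 512)
print(two_stage_resolution(512, 512, 1.5))  # stage2 is (384, 384)
```

Because both stages are floored to whole latent cells, targets that are not multiples of 64 may come out slightly smaller than the percentages in the note suggest.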
||||||
|
|
||||||
### Dev / Dev-Two-Stage
|
### Dev / Dev-Two-Stage
|
||||||
|
|
||||||
@@ -189,8 +215,8 @@ HQ defaults: 15 steps (vs 30), `cfg-rescale` 0.45 (vs 0.7), STG disabled. Uses t
|
|||||||
|
|
||||||
### Distilled Pipeline (default)
|
### Distilled Pipeline (default)
|
||||||
1. **Stage 1**: Generate at half resolution with 8 denoising steps (fixed sigmas)
|
1. **Stage 1**: Generate at half resolution with 8 denoising steps (fixed sigmas)
|
||||||
2. **Upsample**: 2x spatial upsampling via LatentUpsampler
|
2. **Upsample**: Spatial upsampling via LatentUpsampler (x2 or x1.5, selectable via `--spatial-upscaler`)
|
||||||
3. **Stage 2**: Refine at full resolution with 3 denoising steps
|
3. **Stage 2**: Refine at upsampled resolution with 3 denoising steps
|
||||||
4. **Decode**: VAE decoder converts latents to RGB video
|
4. **Decode**: VAE decoder converts latents to RGB video
|
||||||
|
|
||||||
### Dev Pipeline
|
### Dev Pipeline
|
||||||
@@ -199,14 +225,14 @@ HQ defaults: 15 steps (vs 30), `cfg-rescale` 0.45 (vs 0.7), STG disabled. Uses t
|
|||||||
|
|
||||||
### Dev Two-Stage Pipeline
|
### Dev Two-Stage Pipeline
|
||||||
1. **Stage 1**: Dev denoising at half resolution with CFG
|
1. **Stage 1**: Dev denoising at half resolution with CFG
|
||||||
2. **Upsample**: 2x spatial upsampling via LatentUpsampler
|
2. **Upsample**: Spatial upsampling via LatentUpsampler (x2 or x1.5)
|
||||||
3. **Stage 2**: Distilled refinement at full resolution with LoRA weights (3 steps, no CFG)
|
3. **Stage 2**: Distilled refinement at upsampled resolution with LoRA weights (3 steps, no CFG)
|
||||||
4. **Decode**: VAE decoder converts latents to RGB video
|
4. **Decode**: VAE decoder converts latents to RGB video
|
||||||
|
|
||||||
### Dev Two-Stage HQ Pipeline
|
### Dev Two-Stage HQ Pipeline
|
||||||
1. **Stage 1**: res_2s denoising at half resolution with CFG + LoRA@0.25 (15 steps, 2 evals/step)
|
1. **Stage 1**: res_2s denoising at half resolution with CFG + LoRA@0.25 (15 steps, 2 evals/step)
|
||||||
2. **Upsample**: 2x spatial upsampling via LatentUpsampler
|
2. **Upsample**: Spatial upsampling via LatentUpsampler (x2 or x1.5)
|
||||||
3. **Stage 2**: res_2s refinement at full resolution with LoRA@0.5 (3 steps, no CFG)
|
3. **Stage 2**: res_2s refinement at upsampled resolution with LoRA@0.5 (3 steps, no CFG)
|
||||||
4. **Decode**: VAE decoder converts latents to RGB video
|
4. **Decode**: VAE decoder converts latents to RGB video
|
||||||
|
|
||||||
The res_2s sampler uses an exponential Rosenbrock-type Runge-Kutta integrator with SDE noise injection, producing higher quality results than Euler at the same compute budget (~30 total model evaluations).
|
The res_2s sampler uses an exponential Rosenbrock-type Runge-Kutta integrator with SDE noise injection, producing higher quality results than Euler at the same compute budget (~30 total model evaluations).
|
||||||
|
|||||||
@@ -1461,6 +1461,7 @@ def generate_video(
|
|||||||
lora_strength_stage_2: Optional[float] = None,
|
lora_strength_stage_2: Optional[float] = None,
|
||||||
audio_file: Optional[str] = None,
|
audio_file: Optional[str] = None,
|
||||||
audio_start_time: float = 0.0,
|
audio_start_time: float = 0.0,
|
||||||
|
spatial_upscaler: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Generate video using LTX-2 models.
|
"""Generate video using LTX-2 models.
|
||||||
|
|
||||||
@@ -1557,10 +1558,35 @@ def generate_video(
|
|||||||
model_path = get_model_path(model_repo)
|
model_path = get_model_path(model_repo)
|
||||||
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||||
|
|
||||||
|
# Resolve spatial upscaler path for two-stage pipelines
|
||||||
|
upscaler_path = None
|
||||||
|
upscaler_scale = 2.0
|
||||||
|
if is_two_stage:
|
||||||
|
if spatial_upscaler is not None:
|
||||||
|
# User-specified upscaler file
|
||||||
|
upscaler_path = model_path / spatial_upscaler if not Path(spatial_upscaler).is_absolute() else Path(spatial_upscaler)
|
||||||
|
if not upscaler_path.exists():
|
||||||
|
# Try as a filename within model_path
|
||||||
|
upscaler_path = model_path / Path(spatial_upscaler).name
|
||||||
|
# Detect scale from filename
|
||||||
|
if "x1.5" in upscaler_path.name:
|
||||||
|
upscaler_scale = 1.5
|
||||||
|
elif "x2" in upscaler_path.name:
|
||||||
|
upscaler_scale = 2.0
|
||||||
|
else:
|
||||||
|
# Auto-detect: prefer x2 upscaler
|
||||||
|
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
||||||
|
if upscaler_files:
|
||||||
|
upscaler_path = upscaler_files[0]
|
||||||
|
upscaler_scale = 2.0
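The resolution logic above can be read as one small function. The following is a minimal sketch, not the code in this commit: `resolve_upscaler` is a hypothetical helper, and it infers the scale from the file name only (rather than the full path) so that directory names cannot influence the result.

```python
import tempfile
from pathlib import Path
from typing import Optional, Tuple


def resolve_upscaler(model_path: Path, spatial_upscaler: Optional[str]) -> Tuple[Optional[Path], float]:
    """Hypothetical helper: return (upscaler_path, scale) for a two-stage run."""
    if spatial_upscaler is None:
        # Auto-detect: first x2 file in sorted order, as in the commit.
        candidates = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
        return (candidates[0] if candidates else None), 2.0

    requested = Path(spatial_upscaler)
    path = requested if requested.is_absolute() else model_path / requested
    # Infer the scale from the file name; anything without "x1.5" is treated as x2.
    scale = 1.5 if "x1.5" in path.name else 2.0
    return path, scale


with tempfile.TemporaryDirectory() as tmp:
    repo = Path(tmp)
    for name in (
        "ltx-2.3-spatial-upscaler-x2-1.0.safetensors",
        "ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
        "ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors",
    ):
        (repo / name).touch()
    print(resolve_upscaler(repo, None))
    print(resolve_upscaler(repo, "ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors"))
```

With sorted auto-detection, `x2-1.0` sorts before `x2-1.1`, which is why the older weights are the default unless `--spatial-upscaler` is given.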
|
||||||
|
|
||||||
# Calculate latent dimensions
|
# Calculate latent dimensions
|
||||||
if is_two_stage:
|
if is_two_stage:
|
||||||
|
# Stage 1 always at half resolution (matches PyTorch)
|
||||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||||
stage2_h, stage2_w = height // 32, width // 32
|
# Stage 2 resolution = stage 1 * upscaler scale
|
||||||
|
stage2_h = int(stage1_h * upscaler_scale)
|
||||||
|
stage2_w = int(stage1_w * upscaler_scale)
|
||||||
else:
|
else:
|
||||||
latent_h, latent_w = height // 32, width // 32
|
latent_h, latent_w = height // 32, width // 32
|
||||||
latent_frames = 1 + (num_frames - 1) // 8
|
latent_frames = 1 + (num_frames - 1) // 8
|
||||||
@@ -1697,13 +1723,15 @@ def generate_video(
|
|||||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||||
|
|
||||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
s1_h, s1_w = stage1_h * 32, stage1_w * 32
|
||||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
|
||||||
|
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
|
||||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||||
mx.eval(stage1_image_latent)
|
mx.eval(stage1_image_latent)
|
||||||
|
|
||||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
s2_h, s2_w = stage2_h * 32, stage2_w * 32
|
||||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
|
||||||
|
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
|
||||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||||
mx.eval(stage2_image_latent)
|
mx.eval(stage2_image_latent)
|
||||||
|
|
||||||
@@ -1712,7 +1740,7 @@ def generate_video(
|
|||||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||||
|
|
||||||
# Stage 1
|
# Stage 1
|
||||||
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {width//2}x{height//2} (8 steps)")
|
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Generating at {stage1_w*32}x{stage1_h*32} (8 steps)")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
@@ -1757,11 +1785,10 @@ def generate_video(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Upsample latents
|
# Upsample latents
|
||||||
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
with console.status(f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
|
||||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
if upscaler_path is None or not upscaler_path.exists():
|
||||||
if not upscaler_files:
|
|
||||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
|
||||||
mx.eval(upsampler.parameters())
|
mx.eval(upsampler.parameters())
|
||||||
|
|
||||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||||
@@ -1774,7 +1801,7 @@ def generate_video(
|
|||||||
console.print("[green]✓[/] Latents upsampled")
|
console.print("[green]✓[/] Latents upsampled")
|
||||||
|
|
||||||
# Stage 2
|
# Stage 2
|
||||||
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {width}x{height} (3 steps)")
|
console.print(f"\n[bold yellow]⚡ Stage 2:[/] Refining at {stage2_w*32}x{stage2_h*32} (3 steps)")
|
||||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
@@ -1916,13 +1943,15 @@ def generate_video(
|
|||||||
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
with console.status("[blue]🖼️ Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||||
|
|
||||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
s1_h, s1_w = stage1_h * 32, stage1_w * 32
|
||||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
|
||||||
|
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
|
||||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||||
mx.eval(stage1_image_latent)
|
mx.eval(stage1_image_latent)
|
||||||
|
|
||||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
s2_h, s2_w = stage2_h * 32, stage2_w * 32
|
||||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
|
||||||
|
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
|
||||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||||
mx.eval(stage2_image_latent)
|
mx.eval(stage2_image_latent)
|
||||||
|
|
||||||
@@ -1930,12 +1959,12 @@ def generate_video(
|
|||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||||
|
|
||||||
# Stage 1: Dev denoising at half resolution with CFG
|
# Stage 1: Dev denoising at reduced resolution with CFG
|
||||||
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
||||||
mx.eval(sigmas)
|
mx.eval(sigmas)
|
||||||
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||||
|
|
||||||
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {width//2}x{height//2} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
|
console.print(f"\n[bold yellow]⚡ Stage 1:[/] Dev generating at {stage1_w*32}x{stage1_h*32} ({num_inference_steps} steps, CFG={cfg_scale}, rescale={cfg_rescale})")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
@@ -1989,12 +2018,11 @@ def generate_video(
|
|||||||
|
|
||||||
mx.eval(audio_latents)
|
mx.eval(audio_latents)
|
||||||
|
|
||||||
# Upsample latents 2x
|
# Upsample latents
|
||||||
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
with console.status(f"[magenta]🔍 Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
|
||||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
if upscaler_path is None or not upscaler_path.exists():
|
||||||
if not upscaler_files:
|
|
||||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
|
||||||
mx.eval(upsampler.parameters())
|
mx.eval(upsampler.parameters())
|
||||||
|
|
||||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||||
@@ -2091,13 +2119,15 @@ def generate_video(
|
|||||||
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
with console.status("[blue]Loading VAE encoder and encoding image...[/]", spinner="dots"):
|
||||||
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
vae_encoder = VideoEncoder.from_pretrained(model_path / "vae" / "encoder")
|
||||||
|
|
||||||
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
s1_h, s1_w = stage1_h * 32, stage1_w * 32
|
||||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
input_image = load_image(image, height=s1_h, width=s1_w, dtype=model_dtype)
|
||||||
|
stage1_image_tensor = prepare_image_for_encoding(input_image, s1_h, s1_w, dtype=model_dtype)
|
||||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||||
mx.eval(stage1_image_latent)
|
mx.eval(stage1_image_latent)
|
||||||
|
|
||||||
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
s2_h, s2_w = stage2_h * 32, stage2_w * 32
|
||||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
input_image = load_image(image, height=s2_h, width=s2_w, dtype=model_dtype)
|
||||||
|
stage2_image_tensor = prepare_image_for_encoding(input_image, s2_h, s2_w, dtype=model_dtype)
|
||||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||||
mx.eval(stage2_image_latent)
|
mx.eval(stage2_image_latent)
|
||||||
|
|
||||||
@@ -2118,14 +2148,14 @@ def generate_video(
|
|||||||
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
|
with console.status(f"[blue]Merging distilled LoRA (stage 1, strength={hq_lora_strength_s1})...[/]", spinner="dots"):
|
||||||
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
|
load_and_merge_lora(transformer, lora_path, strength=hq_lora_strength_s1)
|
||||||
|
|
||||||
# Stage 1: res_2s denoising at half resolution with CFG
|
# Stage 1: res_2s denoising at reduced resolution with CFG
|
||||||
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
|
# HQ passes actual token count to scheduler (unlike regular dev-two-stage)
|
||||||
num_tokens = latent_frames * stage1_h * stage1_w
|
num_tokens = latent_frames * stage1_h * stage1_w
|
||||||
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
|
sigmas = ltx2_scheduler(steps=hq_steps, num_tokens=num_tokens)
|
||||||
mx.eval(sigmas)
|
mx.eval(sigmas)
|
||||||
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
|
console.print(f"[dim]Stage 1 sigma schedule: {sigmas[0].item():.4f} -> {sigmas[-2].item():.4f} -> {sigmas[-1].item():.4f} (tokens={num_tokens})[/]")
|
||||||
|
|
||||||
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {width//2}x{height//2} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
|
console.print(f"\n[bold yellow]Stage 1:[/] res_2s at {stage1_w*32}x{stage1_h*32} ({hq_steps} steps, CFG={cfg_scale}, rescale={hq_cfg_rescale})")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
@@ -2179,12 +2209,11 @@ def generate_video(
|
|||||||
|
|
||||||
mx.eval(audio_latents)
|
mx.eval(audio_latents)
|
||||||
|
|
||||||
# Upsample latents 2x
|
# Upsample latents
|
||||||
with console.status("[magenta]Upsampling latents 2x...[/]", spinner="dots"):
|
with console.status(f"[magenta]Upsampling latents {upscaler_scale}x...[/]", spinner="dots"):
|
||||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
if upscaler_path is None or not upscaler_path.exists():
|
||||||
if not upscaler_files:
|
|
||||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
upsampler, upscaler_scale = load_upsampler(str(upscaler_path))
|
||||||
mx.eval(upsampler.parameters())
|
mx.eval(upsampler.parameters())
|
||||||
|
|
||||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||||
@@ -2204,7 +2233,7 @@ def generate_video(
|
|||||||
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
|
load_and_merge_lora(transformer, lora_path, strength=additional_strength)
|
||||||
|
|
||||||
# Stage 2: res_2s refinement at full resolution (no CFG)
|
# Stage 2: res_2s refinement at upsampled resolution (no CFG)
|
||||||
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {width}x{height} (3 steps, no CFG)")
|
console.print(f"\n[bold yellow]Stage 2:[/] res_2s refining at {stage2_w*32}x{stage2_h*32} (3 steps, no CFG)")
|
||||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
@@ -2509,6 +2538,9 @@ Examples:
|
|||||||
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
|
parser.add_argument("--lora-strength", type=float, default=1.0, help="LoRA merge strength (dev-two-stage pipeline, default 1.0)")
|
||||||
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
|
parser.add_argument("--lora-strength-stage-1", type=float, default=0.25, help="LoRA strength for HQ stage 1 (default 0.25)")
|
||||||
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
|
parser.add_argument("--lora-strength-stage-2", type=float, default=0.5, help="LoRA strength for HQ stage 2 (default 0.5)")
|
||||||
|
parser.add_argument("--spatial-upscaler", type=str, default=None,
|
||||||
|
help="Spatial upscaler filename (e.g. ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors). "
|
||||||
|
"Auto-detects x2 by default. Use this to select x1.5 or a specific version.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
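A quick way to see how the new flag behaves is to parse it in isolation. This is a standalone sketch using only `argparse`; the real parser in this file defines many more options, and the help text here is abbreviated.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--spatial-upscaler",
    type=str,
    default=None,
    help="Spatial upscaler filename; the x2 upscaler is auto-detected when omitted.",
)

# argparse exposes `--spatial-upscaler` as the attribute `spatial_upscaler`.
args = parser.parse_args(["--spatial-upscaler", "ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors"])
print(args.spatial_upscaler)

# Without the flag the value is None, which triggers auto-detection downstream.
default_args = parser.parse_args([])
print(default_args.spatial_upscaler)
```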
|
||||||
|
|
||||||
pipeline_map = {
|
pipeline_map = {
|
||||||
@@ -2559,6 +2591,7 @@ Examples:
|
|||||||
lora_strength_stage_2=args.lora_strength_stage_2,
|
lora_strength_stage_2=args.lora_strength_stage_2,
|
||||||
audio_file=args.audio_file,
|
audio_file=args.audio_file,
|
||||||
audio_start_time=args.audio_start_time,
|
audio_start_time=args.audio_start_time,
|
||||||
|
spatial_upscaler=args.spatial_upscaler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -115,65 +115,135 @@ class GroupNorm3d(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class PixelShuffle2D(nn.Module):
|
class PixelShuffle2D(nn.Module):
|
||||||
"""Pixel shuffle for 2D spatial upsampling."""
|
"""Pixel shuffle for 2D spatial upsampling with per-axis factors."""
|
||||||
|
|
||||||
def __init__(self, upscale_factor: int = 2):
|
def __init__(self, upscale_factor_h: int = 2, upscale_factor_w: int = 2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.upscale_factor = upscale_factor
|
self.rh = upscale_factor_h
|
||||||
|
self.rw = upscale_factor_w
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
# x: (N, H, W, C) where C = out_channels * upscale_factor^2
|
# x: (N, H, W, C) where C = out_channels * rh * rw
|
||||||
n, h, w, c = x.shape
|
n, h, w, c = x.shape
|
||||||
r = self.upscale_factor
|
rh, rw = self.rh, self.rw
|
||||||
out_c = c // (r * r)
|
out_c = c // (rh * rw)
|
||||||
|
|
||||||
# Reshape: (N, H, W, out_c, r, r)
|
# Reshape: (N, H, W, out_c, rh, rw)
|
||||||
x = mx.reshape(x, (n, h, w, out_c, r, r))
|
x = mx.reshape(x, (n, h, w, out_c, rh, rw))
|
||||||
|
|
||||||
# Permute: (N, H, r, W, r, out_c)
|
# Permute: (N, H, rh, W, rw, out_c)
|
||||||
x = mx.transpose(x, (0, 1, 4, 2, 5, 3))
|
x = mx.transpose(x, (0, 1, 4, 2, 5, 3))
|
||||||
|
|
||||||
# Reshape: (N, H*r, W*r, out_c)
|
# Reshape: (N, H*rh, W*rw, out_c)
|
||||||
x = mx.reshape(x, (n, h * r, w * r, out_c))
|
x = mx.reshape(x, (n, h * rh, w * rw, out_c))
|
||||||
|
|
||||||
return x
|
return x
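To make the reshape/transpose sequence concrete, here is the same channels-last pixel shuffle written with NumPy instead of MLX (an assumption made so the sketch runs anywhere; the axis order is the same as in the class above). The function name `pixel_shuffle_2d` is illustrative.

```python
import numpy as np


def pixel_shuffle_2d(x: np.ndarray, rh: int, rw: int) -> np.ndarray:
    """Rearrange (N, H, W, out_c * rh * rw) into (N, H * rh, W * rw, out_c)."""
    n, h, w, c = x.shape
    out_c = c // (rh * rw)
    # Split the channel axis into (out_c, rh, rw).
    x = x.reshape(n, h, w, out_c, rh, rw)
    # Interleave the sub-pixel axes with the spatial axes: (N, H, rh, W, rw, out_c).
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(n, h * rh, w * rw, out_c)


# One input pixel with 6 channels becomes a 2x3 patch with 1 channel.
x = np.arange(6, dtype=np.float32).reshape(1, 1, 1, 6)
y = pixel_shuffle_2d(x, rh=2, rw=3)
print(y.shape)        # (1, 2, 3, 1)
print(y[0, :, :, 0])  # rows [0, 1, 2] and [3, 4, 5]
```

Using different factors per axis in the example shows why the transpose order matters: the `rh` axis must end up next to `H` and the `rw` axis next to `W`.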
|
||||||
|
|
||||||
|
|
||||||
|
class BlurDownsample(nn.Module):
|
||||||
|
"""Anti-aliased downsampling with a fixed 5x5 binomial blur kernel.
|
||||||
|
|
||||||
|
PyTorch source uses a depthwise conv with the binomial kernel.
|
||||||
|
The kernel weight is stored as (1, 1, 5, 5) and loaded via safetensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stride: int = 2):
|
||||||
|
super().__init__()
|
||||||
|
self.stride = stride
|
||||||
|
# 5x5 binomial (1,4,6,4,1) kernel, normalized
|
||||||
|
# This will be overwritten by loaded weights if available
|
||||||
|
k = mx.array([1.0, 4.0, 6.0, 4.0, 1.0])
|
||||||
|
kernel_2d = mx.outer(k, k)
|
||||||
|
kernel_2d = kernel_2d / kernel_2d.sum()
|
||||||
|
# MLX conv2d weight: (O, H, W, I) — we use (1, 5, 5, 1) for per-channel
|
||||||
|
self.kernel = kernel_2d.reshape(1, 5, 5, 1)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
# x: (N, H, W, C) channels-last
|
||||||
|
n, h, w, c = x.shape
|
||||||
|
|
||||||
|
# Pad with edge replication (2 on each side for 5x5 kernel)
|
||||||
|
x = mx.pad(x, [(0, 0), (2, 2), (2, 2), (0, 0)], mode="edge")
|
||||||
|
|
||||||
|
# Apply blur per-channel: reshape so each channel is a separate "batch"
|
||||||
|
# (N, H+4, W+4, C) -> (N*C, H+4, W+4, 1)
|
||||||
|
x = mx.transpose(x, (0, 3, 1, 2)) # (N, C, H+4, W+4)
|
||||||
|
x = mx.reshape(x, (n * c, h + 4, w + 4, 1))
|
||||||
|
|
||||||
|
# Depthwise conv: (N*C, H+4, W+4, 1) * (1, 5, 5, 1) -> (N*C, H_out, W_out, 1)
|
||||||
|
x = mx.conv2d(x, self.kernel, stride=(self.stride, self.stride))
|
||||||
|
|
||||||
|
_, h_out, w_out, _ = x.shape
|
||||||
|
# Reshape back: (N*C, H_out, W_out, 1) -> (N, C, H_out, W_out) -> (N, H_out, W_out, C)
|
||||||
|
x = mx.reshape(x, (n, c, h_out, w_out))
|
||||||
|
x = mx.transpose(x, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
return x
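The blur-then-stride behaviour can be illustrated on a single channel with plain NumPy. This is a sketch under the assumptions visible in the class above (normalized 5x5 binomial kernel, 2 cells of edge padding, no extra padding in the convolution); it uses explicit loops for clarity rather than `mx.conv2d`, and the function names are illustrative.

```python
import numpy as np


def binomial_kernel_5x5() -> np.ndarray:
    """Outer product of (1, 4, 6, 4, 1), normalized to sum to 1."""
    k = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
    kernel = np.outer(k, k)
    return kernel / kernel.sum()


def blur_downsample(img: np.ndarray, stride: int = 2) -> np.ndarray:
    """Blur a single-channel (H, W) image and keep every `stride`-th sample."""
    kernel = binomial_kernel_5x5()
    padded = np.pad(img, 2, mode="edge")  # replicate edges, 2 cells per side
    h, w = img.shape
    out_h = (h - 1) // stride + 1
    out_w = (w - 1) // stride + 1
    out = np.empty((out_h, out_w), dtype=img.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = padded[i * stride:i * stride + 5, j * stride:j * stride + 5]
            out[i, j] = (window * kernel).sum()
    return out


flat = np.full((6, 6), 3.0)
print(blur_downsample(flat).shape)  # (3, 3)
```

Because the kernel sums to 1 and the padding replicates edge values, a constant image stays constant after downsampling, which is a handy sanity check when porting weights.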
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialUpsampler2x(nn.Module):
|
||||||
|
"""Standard 2x spatial upsampler: Conv2d + PixelShuffle(2)."""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels: int = 1024):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = 2.0
|
||||||
|
# Sequential: conv (index 0) + pixel shuffle
|
||||||
|
# Weight key: upsampler.0.weight -> mapped to upsampler.conv.weight in sanitize
|
||||||
|
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.pixel_shuffle = PixelShuffle2D(2, 2)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
# x: (N, D, H, W, C)
|
||||||
|
n, d, h, w, c = x.shape
|
||||||
|
x = mx.reshape(x, (n * d, h, w, c))
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.pixel_shuffle(x)
|
||||||
|
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SpatialRationalResampler(nn.Module):
|
class SpatialRationalResampler(nn.Module):
|
||||||
|
"""Rational spatial resampler for non-integer scale factors (e.g., 1.5x).
|
||||||
|
|
||||||
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
|
For scale=1.5: upsample 3x via PixelShuffle, then downsample 2x via BlurDownsample.
|
||||||
|
Rational fraction: 1.5 = 3/2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels: int = 1024, scale: float = 1.5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
# 2D conv: mid_channels -> 4*mid_channels for pixel shuffle
|
# Rational fraction for 1.5: numerator=3, denominator=2
|
||||||
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
|
num, den = _rational_for_scale(scale)
|
||||||
|
self.num = num
|
||||||
|
self.den = den
|
||||||
|
|
||||||
# Blur kernel for antialiasing
|
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
|
||||||
self.blur_down_kernel = mx.ones((1, 1, 5, 5)) / 25.0
|
self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.pixel_shuffle = PixelShuffle2D(num, num)
|
||||||
self.pixel_shuffle = PixelShuffle2D(2)
|
self.blur_down = BlurDownsample(stride=den)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
# x: (N, D, H, W, C) - channels last 3D format
|
# x: (N, D, H, W, C)
|
||||||
|
|
||||||
n, d, h, w, c = x.shape
|
n, d, h, w, c = x.shape
|
||||||
|
|
||||||
# Process frame by frame
|
|
||||||
# Reshape to (N*D, H, W, C) for 2D operations
|
|
||||||
x = mx.reshape(x, (n * d, h, w, c))
|
x = mx.reshape(x, (n * d, h, w, c))
|
||||||
|
|
||||||
# Apply 2D conv
|
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
|
x = self.pixel_shuffle(x) # H*num, W*num
|
||||||
|
x = self.blur_down(x) # H*num/den, W*num/den
|
||||||
|
|
||||||
# Pixel shuffle for 2x upscaling
|
_, h_out, w_out, _ = x.shape
|
||||||
x = self.pixel_shuffle(x)
|
x = mx.reshape(x, (n, d, h_out, w_out, c))
|
||||||
|
|
||||||
# Reshape back to (N, D, H*2, W*2, C)
|
|
||||||
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
|
|
||||||
|
|
||||||
return x
|
return x
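It can help to trace the size of one spatial axis through the resampler. The sketch below assumes exactly what the code in this commit shows: `PixelShuffle2D(num, num)`, then 2 cells of edge padding per side, then a 5-tap convolution with stride `den` and no additional padding. The helper `resampled_size` is hypothetical.

```python
def resampled_size(size: int, num: int = 3, den: int = 2) -> int:
    """Length of one spatial axis after PixelShuffle(num) + BlurDownsample(stride=den)."""
    upsampled = size * num          # pixel shuffle multiplies the axis by `num`
    padded = upsampled + 4          # 2 cells of edge padding on each side
    return (padded - 5) // den + 1  # unpadded 5-tap convolution with stride `den`


for h in (4, 5, 8):
    print(h, resampled_size(h), int(h * 1.5))
```

For even latent sizes the result equals `int(h * 1.5)`, the value `generate_video` uses for `stage2_h` and `stage2_w`. For odd sizes this arithmetic gives one cell more (8 versus 7 for `h = 5`), so under these assumptions it may be worth verifying odd stage-1 sizes against the actual upsampler output.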
|
||||||
|
|
||||||
|
|
||||||
|
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||||
|
"""Convert a float scale to a rational fraction (numerator, denominator)."""
|
||||||
|
from fractions import Fraction
|
||||||
|
frac = Fraction(scale).limit_denominator(10)
|
||||||
|
return frac.numerator, frac.denominator
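The conversion relies on the standard-library `fractions.Fraction`. A standalone version, with the denominator limit exposed as a parameter (an addition for illustration only), behaves as follows:

```python
from fractions import Fraction
from typing import Tuple


def rational_for_scale(scale: float, max_denominator: int = 10) -> Tuple[int, int]:
    """Approximate `scale` by a fraction whose denominator is at most `max_denominator`."""
    frac = Fraction(scale).limit_denominator(max_denominator)
    return frac.numerator, frac.denominator


print(rational_for_scale(1.5))  # (3, 2): upsample 3x, then downsample 2x
print(rational_for_scale(2.0))  # (2, 1)
```

Capping the denominator keeps values that are not exactly representable as floats, such as 1.3333, from turning into very large numerators and denominators.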
|
||||||
|
|
||||||
|
|
||||||
class ResBlock3D(nn.Module):
|
class ResBlock3D(nn.Module):
|
||||||
|
|
||||||
def __init__(self, channels: int):
|
def __init__(self, channels: int):
|
||||||
@@ -201,17 +271,19 @@ class ResBlock3D(nn.Module):
|
|||||||
|
|
||||||
class LatentUpsampler(nn.Module):
|
class LatentUpsampler(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels: int = 128,
|
in_channels: int = 128,
|
||||||
mid_channels: int = 1024,
|
mid_channels: int = 1024,
|
||||||
num_blocks_per_stage: int = 4,
|
num_blocks_per_stage: int = 4,
|
||||||
|
spatial_scale: float = 2.0,
|
||||||
|
rational_resampler: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.mid_channels = mid_channels
|
self.mid_channels = mid_channels
|
||||||
|
self.spatial_scale = spatial_scale
|
||||||
|
|
||||||
# Initial projection
|
# Initial projection
|
||||||
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
|
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||||
@@ -221,7 +293,10 @@ class LatentUpsampler(nn.Module):
|
|||||||
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||||
|
|
||||||
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
||||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)
|
if rational_resampler:
|
||||||
|
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale)
|
||||||
|
else:
|
||||||
|
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
|
||||||
|
|
||||||
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||||
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||||
@@ -230,14 +305,14 @@ class LatentUpsampler(nn.Module):
|
|||||||
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
|
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
|
||||||
"""Upsample latents by 2x spatially.
|
"""Upsample latents spatially.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
latent: Input tensor of shape (B, C, F, H, W) - channels first
|
latent: Input tensor of shape (B, C, F, H, W) - channels first
|
||||||
debug: If True, print intermediate values for debugging
|
debug: If True, print intermediate values for debugging
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
|
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
|
||||||
"""
|
"""
|
||||||
def debug_stats(name, t):
|
def debug_stats(name, t):
|
||||||
if debug:
|
if debug:
|
||||||
@@ -250,41 +325,27 @@ class LatentUpsampler(nn.Module):
|
|||||||
|
|
||||||
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
|
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
|
||||||
x = mx.transpose(latent, (0, 2, 3, 4, 1))
|
x = mx.transpose(latent, (0, 2, 3, 4, 1))
|
||||||
if debug:
|
|
||||||
debug_stats("After transpose to channels-last", x)
|
|
||||||
|
|
||||||
# Initial conv
|
# Initial conv
|
||||||
x = self.initial_conv(x)
|
x = self.initial_conv(x)
|
||||||
if debug:
|
|
||||||
debug_stats("After initial_conv", x)
|
|
||||||
x = self.initial_norm(x)
|
x = self.initial_norm(x)
|
||||||
if debug:
|
|
||||||
debug_stats("After initial_norm", x)
|
|
||||||
x = nn.silu(x)
|
x = nn.silu(x)
|
||||||
if debug:
|
|
||||||
debug_stats("After silu", x)
|
|
||||||
|
|
||||||
# Pre-upsample blocks
|
# Pre-upsample blocks
|
||||||
for i in sorted(self.res_blocks.keys()):
|
for i in sorted(self.res_blocks.keys()):
|
||||||
x = self.res_blocks[i](x)
|
x = self.res_blocks[i](x)
|
||||||
if debug:
|
|
||||||
debug_stats(f"After res_blocks[{i}]", x)
|
|
||||||
|
|
||||||
# Upsample (2D spatial, frame-by-frame)
|
# Upsample (2D spatial, frame-by-frame)
|
||||||
x = self.upsampler(x)
|
x = self.upsampler(x)
|
||||||
if debug:
|
if debug:
|
||||||
debug_stats("After upsampler (spatial 2x)", x)
|
debug_stats(f"After upsampler (spatial {self.spatial_scale}x)", x)
|
||||||
|
|
||||||
# Post-upsample blocks
|
# Post-upsample blocks
|
||||||
for i in sorted(self.post_upsample_res_blocks.keys()):
|
for i in sorted(self.post_upsample_res_blocks.keys()):
|
||||||
x = self.post_upsample_res_blocks[i](x)
|
x = self.post_upsample_res_blocks[i](x)
|
||||||
if debug:
|
|
||||||
debug_stats(f"After post_upsample_res_blocks[{i}]", x)
|
|
||||||
|
|
||||||
# Final conv
|
# Final conv
|
||||||
x = self.final_conv(x)
|
x = self.final_conv(x)
|
||||||
if debug:
|
|
||||||
debug_stats("After final_conv", x)
|
|
||||||
|
|
||||||
# Convert back to channels first (B, C, F, H, W)
|
# Convert back to channels first (B, C, F, H, W)
|
||||||
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||||
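As a rough illustration of the `(B, C, F, H*scale, W*scale)` contract in the updated docstring, the sketch below computes the expected output shape for a given latent shape and scale. `upsampled_shape` is a hypothetical helper, not part of this commit, and it assumes the scaled height and width are whole numbers. How the actual resampler treats sizes that do not divide evenly is not shown in this diff.

```python
from fractions import Fraction
from typing import Tuple

Shape5D = Tuple[int, int, int, int, int]


def upsampled_shape(shape: Shape5D, scale: float) -> Shape5D:
    """Expected channels-first output shape (B, C, F, H*scale, W*scale).

    Hypothetical helper for illustration only. Raises ValueError if the
    scaled height or width would not be an integer.
    """
    b, c, f, h, w = shape
    # Represent the scale as an exact ratio: 1.5 -> 3/2, 2.0 -> 2/1.
    ratio = Fraction(scale).limit_denominator(16)
    new_h, new_w = h * ratio, w * ratio
    if new_h.denominator != 1 or new_w.denominator != 1:
        raise ValueError(f"{h}x{w} is not evenly scalable by {scale}")
    return (b, c, f, int(new_h), int(new_w))


print(upsampled_shape((1, 128, 8, 16, 16), 2.0))  # (1, 128, 8, 32, 32)
print(upsampled_shape((1, 128, 8, 16, 16), 1.5))  # (1, 128, 8, 24, 24)
```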
@@ -315,33 +376,49 @@ def upsample_latents(
|
|||||||
return latent
|
return latent
|
||||||
|
|
||||||
|
|
||||||
def load_upsampler(weights_path: str) -> LatentUpsampler:
|
def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
||||||
"""Load upsampler from safetensors weights.
|
"""Load upsampler from safetensors weights.
|
||||||
|
|
||||||
|
Auto-detects whether the weights are for x2 or x1.5 upscaling based on
|
||||||
|
the upsampler keys (only x1.5 weights contain upsampler.blur_down.kernel):
|
||||||
|
- x2: upsampler.0.weight shape [4*mid, mid, 3, 3] (4096 out channels)
|
||||||
|
- x1.5: upsampler.conv.weight shape [9*mid, mid, 3, 3] (9216 out channels)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
weights_path: Path to upsampler weights file
|
weights_path: Path to upsampler weights file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded LatentUpsampler model
|
Tuple of (LatentUpsampler model, spatial_scale)
|
||||||
"""
|
"""
|
||||||
print(f"Loading spatial upsampler from {weights_path}...")
|
print(f"Loading spatial upsampler from {weights_path}...")
|
||||||
raw_weights = mx.load(weights_path)
|
raw_weights = mx.load(weights_path)
|
||||||
|
|
||||||
# Check weight shapes to determine mid_channels
|
# Detect mid_channels from res_blocks
|
||||||
# res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
|
|
||||||
sample_key = "res_blocks.0.conv1.weight"
|
sample_key = "res_blocks.0.conv1.weight"
|
||||||
if sample_key in raw_weights:
|
if sample_key in raw_weights:
|
||||||
mid_channels = raw_weights[sample_key].shape[0]
|
mid_channels = raw_weights[sample_key].shape[0]
|
||||||
else:
|
else:
|
||||||
mid_channels = 1024 # default
|
mid_channels = 1024
|
||||||
|
|
||||||
print(f" Detected mid_channels: {mid_channels}")
|
# Detect upsampler type from the checkpoint keys (presence of blur_down.kernel)
|
||||||
|
# x2 uses sequential: upsampler.0.weight (4*mid out channels)
|
||||||
|
# x1.5 uses named: upsampler.conv.weight (9*mid out channels) + upsampler.blur_down.kernel
|
||||||
|
rational_resampler = "upsampler.blur_down.kernel" in raw_weights
|
||||||
|
if rational_resampler:
|
||||||
|
# x1.5: conv out = 9 * mid_channels (3^2 * mid for PixelShuffle(3))
|
||||||
|
spatial_scale = 1.5
|
||||||
|
else:
|
||||||
|
spatial_scale = 2.0
|
||||||
|
|
||||||
|
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
|
||||||
|
|
||||||
# Create model
|
# Create model
|
||||||
upsampler = LatentUpsampler(
|
upsampler = LatentUpsampler(
|
||||||
in_channels=128,
|
in_channels=128,
|
||||||
mid_channels=mid_channels,
|
mid_channels=mid_channels,
|
||||||
num_blocks_per_stage=4,
|
num_blocks_per_stage=4,
|
||||||
|
spatial_scale=spatial_scale,
|
||||||
|
rational_resampler=rational_resampler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sanitize weights - convert from PyTorch to MLX format
|
# Sanitize weights - convert from PyTorch to MLX format
|
||||||
@@ -349,7 +426,7 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
|
|||||||
for key, value in raw_weights.items():
|
for key, value in raw_weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
|
|
||||||
# LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
|
# x2 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
|
||||||
if key.startswith("upsampler.0."):
|
if key.startswith("upsampler.0."):
|
||||||
new_key = key.replace("upsampler.0.", "upsampler.conv.")
|
new_key = key.replace("upsampler.0.", "upsampler.conv.")
|
||||||
|
|
||||||
@@ -358,7 +435,7 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
|
|||||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||||
if "weight" in new_key and value.ndim == 4:
|
if ("weight" in new_key or "kernel" in new_key) and value.ndim == 4:
|
||||||
value = mx.transpose(value, (0, 2, 3, 1))
|
value = mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
@@ -368,4 +445,4 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
|
|||||||
|
|
||||||
print(f" Loaded {len(sanitized)} weights")
|
print(f" Loaded {len(sanitized)} weights")
|
||||||
|
|
||||||
return upsampler
|
return upsampler, spatial_scale
|
||||||
|
|||||||
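The variant detection added to `load_upsampler` can be sketched in isolation as below. This is a minimal sketch that operates on a plain dict of key names to shapes rather than on loaded MLX arrays. `detect_upsampler_config` is a hypothetical name, and the blur kernel shape in the example is a placeholder, since only the presence of that key matters to the check.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def detect_upsampler_config(shapes: Dict[str, Shape]) -> Tuple[int, float, bool]:
    """Return (mid_channels, spatial_scale, rational_resampler).

    Mirrors the checks in the diff: mid_channels comes from the first
    ResBlock conv (falling back to 1024), and the x1.5 variant is
    recognised by the presence of "upsampler.blur_down.kernel".
    """
    sample_key = "res_blocks.0.conv1.weight"
    mid_channels = shapes[sample_key][0] if sample_key in shapes else 1024
    rational_resampler = "upsampler.blur_down.kernel" in shapes
    spatial_scale = 1.5 if rational_resampler else 2.0
    return mid_channels, spatial_scale, rational_resampler


# Example key sets; the blur kernel shape is a made-up placeholder.
x2_shapes = {
    "res_blocks.0.conv1.weight": (1024, 1024, 3, 3, 3),
    "upsampler.0.weight": (4096, 1024, 3, 3),
}
x15_shapes = {
    "res_blocks.0.conv1.weight": (1024, 1024, 3, 3, 3),
    "upsampler.conv.weight": (9216, 1024, 3, 3),
    "upsampler.blur_down.kernel": (1, 1, 5, 5),
}

print(detect_upsampler_config(x2_shapes))   # (1024, 2.0, False)
print(detect_upsampler_config(x15_shapes))  # (1024, 1.5, True)
```

Because `load_upsampler` now returns a tuple, existing callers need to unpack it, for example `upsampler, spatial_scale = load_upsampler(weights_path)`.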