Add custom spatial upscaling support to LTX-2 video generation; introduce spatial_upscaler parameter and enhance resolution handling for two-stage pipelines

This commit is contained in:
Prince Canuma
2026-03-17 02:23:47 +01:00
parent cc302d79b0
commit 57f66bcae2
3 changed files with 234 additions and 98 deletions

View File

@@ -115,65 +115,135 @@ class GroupNorm3d(nn.Module):
class PixelShuffle2D(nn.Module):
    """Pixel shuffle for 2D spatial upsampling with per-axis factors.

    Moves groups of ``rh * rw`` channels into the spatial dimensions, so a
    channels-last input of shape (N, H, W, C) becomes
    (N, H * rh, W * rw, C // (rh * rw)).
    """

    def __init__(self, upscale_factor_h: int = 2, upscale_factor_w: int = 2):
        super().__init__()
        self.rh = upscale_factor_h
        self.rw = upscale_factor_w

    def __call__(self, x: mx.array) -> mx.array:
        """Rearrange channels into space.

        Args:
            x: Array of shape (N, H, W, C) where C = out_channels * rh * rw.

        Returns:
            Array of shape (N, H * rh, W * rw, out_channels).

        Raises:
            ValueError: If C is not a multiple of rh * rw.
        """
        n, h, w, c = x.shape
        rh, rw = self.rh, self.rw
        if c % (rh * rw) != 0:
            # Fail early with a clear message instead of an opaque reshape error.
            raise ValueError(
                f"PixelShuffle2D expects channels divisible by {rh * rw}, got {c}"
            )
        out_c = c // (rh * rw)

        # Split the channel axis: (N, H, W, out_c, rh, rw)
        x = mx.reshape(x, (n, h, w, out_c, rh, rw))

        # Interleave the factors with their spatial axes: (N, H, rh, W, rw, out_c)
        x = mx.transpose(x, (0, 1, 4, 2, 5, 3))

        # Merge each factor into its axis: (N, H*rh, W*rw, out_c)
        x = mx.reshape(x, (n, h * rh, w * rw, out_c))

        return x
class BlurDownsample(nn.Module):
    """Anti-aliased strided downsampling using a fixed 5x5 binomial blur.

    The same single-channel kernel is applied to every channel independently
    by folding the channel axis into the batch axis before the convolution.
    Per the original notes, the PyTorch source uses a depthwise conv and the
    kernel is stored as (1, 1, 5, 5) in the safetensors checkpoint.
    """

    def __init__(self, stride: int = 2):
        super().__init__()
        self.stride = stride

        # Normalized outer product of the binomial taps (1, 4, 6, 4, 1).
        # Acts as a default; loaded weights may overwrite it.
        taps = mx.array([1.0, 4.0, 6.0, 4.0, 1.0])
        binomial = mx.outer(taps, taps)
        binomial = binomial / binomial.sum()

        # Laid out as (O, H, W, I) = (1, 5, 5, 1): one filter, one input channel.
        self.kernel = binomial.reshape(1, 5, 5, 1)

    def __call__(self, x: mx.array) -> mx.array:
        """Blur and downsample a channels-last (N, H, W, C) array."""
        batch, height, width, channels = x.shape

        # Replicate the border by 2 pixels per side to cover the 5x5 footprint.
        padded = mx.pad(x, [(0, 0), (2, 2), (2, 2), (0, 0)], mode="edge")

        # Treat every channel as its own single-channel image:
        # (N, H+4, W+4, C) -> (N, C, H+4, W+4) -> (N*C, H+4, W+4, 1)
        planes = mx.transpose(padded, (0, 3, 1, 2))
        planes = mx.reshape(planes, (batch * channels, height + 4, width + 4, 1))

        blurred = mx.conv2d(planes, self.kernel, stride=(self.stride, self.stride))
        out_h, out_w = blurred.shape[1], blurred.shape[2]

        # Undo the fold: (N*C, H_out, W_out, 1) -> (N, C, H_out, W_out) -> (N, H_out, W_out, C)
        blurred = mx.reshape(blurred, (batch, channels, out_h, out_w))
        return mx.transpose(blurred, (0, 2, 3, 1))
class SpatialUpsampler2x(nn.Module):
    """Fixed 2x spatial upsampler: a 3x3 Conv2d followed by PixelShuffle2D(2, 2)."""

    def __init__(self, mid_channels: int = 1024):
        super().__init__()
        self.scale = 2.0

        # The checkpoint stores this conv at sequential index 0
        # (upsampler.0.weight); the loader's sanitize step renames it to
        # upsampler.conv.weight.
        self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = PixelShuffle2D(2, 2)

    def __call__(self, x: mx.array) -> mx.array:
        """Upsample a channels-last (N, D, H, W, C) array to (N, D, 2H, 2W, C)."""
        batch, frames, height, width, channels = x.shape

        # Fold frames into the batch so the 2D ops run on each frame.
        frames_2d = mx.reshape(x, (batch * frames, height, width, channels))
        shuffled = self.pixel_shuffle(self.conv(frames_2d))

        return mx.reshape(
            shuffled, (batch, frames, height * 2, width * 2, channels)
        )
class SpatialRationalResampler(nn.Module):
    """Rational spatial resampler for non-integer scale factors (e.g., 1.5x).

    The scale is expressed as a fraction num/den. The input is upsampled by
    ``num`` with a Conv2d + PixelShuffle2D, then downsampled by ``den`` with
    BlurDownsample. For scale=1.5 that is 3x up followed by 2x down (3/2).
    """

    def __init__(self, mid_channels: int = 1024, scale: float = 1.5):
        super().__init__()
        self.scale = scale

        # e.g. 1.5 -> numerator 3, denominator 2
        num, den = _rational_for_scale(scale)
        self.num = num
        self.den = den

        # The conv must emit num^2 groups of mid_channels for PixelShuffle2D(num, num).
        self.conv = nn.Conv2d(
            mid_channels, num * num * mid_channels, kernel_size=3, padding=1
        )
        self.pixel_shuffle = PixelShuffle2D(num, num)
        self.blur_down = BlurDownsample(stride=den)

    def __call__(self, x: mx.array) -> mx.array:
        """Resample a channels-last (N, D, H, W, C) array frame by frame.

        Returns:
            Array of shape (N, D, H_out, W_out, C), where H_out and W_out are
            whatever spatial size the blur-downsample step produces.
        """
        n, d, h, w, c = x.shape

        # Fold frames into the batch so the 2D ops run on each frame.
        x = mx.reshape(x, (n * d, h, w, c))

        x = self.conv(x)
        x = self.pixel_shuffle(x)  # H*num, W*num
        x = self.blur_down(x)  # H*num/den, W*num/den

        # Read the output size from the tensor rather than assuming a fixed 2x.
        _, h_out, w_out, _ = x.shape
        x = mx.reshape(x, (n, d, h_out, w_out, c))

        return x
def _rational_for_scale(scale: float) -> Tuple[int, int]:
"""Convert a float scale to a rational fraction (numerator, denominator)."""
from fractions import Fraction
frac = Fraction(scale).limit_denominator(10)
return frac.numerator, frac.denominator
class ResBlock3D(nn.Module):
def __init__(self, channels: int):
@@ -201,17 +271,19 @@ class ResBlock3D(nn.Module):
class LatentUpsampler(nn.Module):
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
spatial_scale: float = 2.0,
rational_resampler: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.spatial_scale = spatial_scale
# Initial projection
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
@@ -221,7 +293,10 @@ class LatentUpsampler(nn.Module):
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
# Upsampler: 2D spatial upsampling (frame-by-frame)
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)
if rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale)
else:
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
@@ -230,14 +305,14 @@ class LatentUpsampler(nn.Module):
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
"""Upsample latents by 2x spatially.
"""Upsample latents spatially.
Args:
latent: Input tensor of shape (B, C, F, H, W) - channels first
debug: If True, print intermediate values for debugging
Returns:
Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
"""
def debug_stats(name, t):
if debug:
@@ -250,41 +325,27 @@ class LatentUpsampler(nn.Module):
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
x = mx.transpose(latent, (0, 2, 3, 4, 1))
if debug:
debug_stats("After transpose to channels-last", x)
# Initial conv
x = self.initial_conv(x)
if debug:
debug_stats("After initial_conv", x)
x = self.initial_norm(x)
if debug:
debug_stats("After initial_norm", x)
x = nn.silu(x)
if debug:
debug_stats("After silu", x)
# Pre-upsample blocks
for i in sorted(self.res_blocks.keys()):
x = self.res_blocks[i](x)
if debug:
debug_stats(f"After res_blocks[{i}]", x)
# Upsample (2D spatial, frame-by-frame)
x = self.upsampler(x)
if debug:
debug_stats("After upsampler (spatial 2x)", x)
debug_stats(f"After upsampler (spatial {self.spatial_scale}x)", x)
# Post-upsample blocks
for i in sorted(self.post_upsample_res_blocks.keys()):
x = self.post_upsample_res_blocks[i](x)
if debug:
debug_stats(f"After post_upsample_res_blocks[{i}]", x)
# Final conv
x = self.final_conv(x)
if debug:
debug_stats("After final_conv", x)
# Convert back to channels first (B, C, F, H, W)
x = mx.transpose(x, (0, 4, 1, 2, 3))
@@ -315,33 +376,49 @@ def upsample_latents(
return latent
def load_upsampler(weights_path: str) -> LatentUpsampler:
def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
"""Load upsampler from safetensors weights.
Auto-detects whether the weights are for x2 or x1.5 upscaling based on
the upsampler conv output channels:
- x2: upsampler.0.weight shape [4*mid, mid, 3, 3] (4096 out channels)
- x1.5: upsampler.conv.weight shape [9*mid, mid, 3, 3] (9216 out channels)
Args:
weights_path: Path to upsampler weights file
Returns:
Loaded LatentUpsampler model
Tuple of (LatentUpsampler model, spatial_scale)
"""
print(f"Loading spatial upsampler from {weights_path}...")
raw_weights = mx.load(weights_path)
# Check weight shapes to determine mid_channels
# res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
# Detect mid_channels from res_blocks
sample_key = "res_blocks.0.conv1.weight"
if sample_key in raw_weights:
mid_channels = raw_weights[sample_key].shape[0]
else:
mid_channels = 1024 # default
mid_channels = 1024
print(f" Detected mid_channels: {mid_channels}")
# Detect upsampler type from conv output channels
# x2 uses sequential: upsampler.0.weight (4*mid out channels)
# x1.5 uses named: upsampler.conv.weight (9*mid out channels) + upsampler.blur_down.kernel
rational_resampler = "upsampler.blur_down.kernel" in raw_weights
if rational_resampler:
# x1.5: conv out = 9 * mid_channels (3^2 * mid for PixelShuffle(3))
spatial_scale = 1.5
else:
spatial_scale = 2.0
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
# Create model
upsampler = LatentUpsampler(
in_channels=128,
mid_channels=mid_channels,
num_blocks_per_stage=4,
spatial_scale=spatial_scale,
rational_resampler=rational_resampler,
)
# Sanitize weights - convert from PyTorch to MLX format
@@ -349,7 +426,7 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
for key, value in raw_weights.items():
new_key = key
# LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
# x2 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
if key.startswith("upsampler.0."):
new_key = key.replace("upsampler.0.", "upsampler.conv.")
@@ -358,7 +435,7 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "weight" in new_key and value.ndim == 4:
if ("weight" in new_key or "kernel" in new_key) and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
@@ -368,4 +445,4 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
print(f" Loaded {len(sanitized)} weights")
return upsampler
return upsampler, spatial_scale