From f8e371e9ce757c1e451c278c4b6246a9a966b9a8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 17 Mar 2026 15:14:57 +0100 Subject: [PATCH] Enhance upsampler weight detection logic in LTX-2 model; improve clarity in comments and streamline spatial scale determination for x1.5 and x2 formats --- mlx_video/models/ltx_2/upsampler.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mlx_video/models/ltx_2/upsampler.py b/mlx_video/models/ltx_2/upsampler.py index 9ede781..1056687 100644 --- a/mlx_video/models/ltx_2/upsampler.py +++ b/mlx_video/models/ltx_2/upsampler.py @@ -401,13 +401,17 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]: mid_channels = 1024 # Detect upsampler type from conv output channels - # x2 uses sequential: upsampler.0.weight (4*mid out channels) - # x1.5 uses named: upsampler.conv.weight (9*mid out channels) + upsampler.blur_down.kernel - rational_resampler = "upsampler.blur_down.kernel" in raw_weights - if rational_resampler: - # x1.5: conv out = 9 * mid_channels (3^2 * mid for PixelShuffle(3)) - spatial_scale = 1.5 + # x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2)) + # x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample + # Both formats may have upsampler.blur_down.kernel, so use channel count + conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight" + if conv_key in raw_weights: + out_channels = raw_weights[conv_key].shape[0] + ratio = out_channels // mid_channels + rational_resampler = ratio == 9 # 3^2 for PixelShuffle(3) + blur downsample + spatial_scale = 1.5 if rational_resampler else 2.0 else: + rational_resampler = False spatial_scale = 2.0 print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")