diff --git a/mlx_video/models/ltx_2/upsampler.py b/mlx_video/models/ltx_2/upsampler.py index 9ede781..1056687 100644 --- a/mlx_video/models/ltx_2/upsampler.py +++ b/mlx_video/models/ltx_2/upsampler.py @@ -401,13 +401,17 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]: mid_channels = 1024 # Detect upsampler type from conv output channels - # x2 uses sequential: upsampler.0.weight (4*mid out channels) - # x1.5 uses named: upsampler.conv.weight (9*mid out channels) + upsampler.blur_down.kernel - rational_resampler = "upsampler.blur_down.kernel" in raw_weights - if rational_resampler: - # x1.5: conv out = 9 * mid_channels (3^2 * mid for PixelShuffle(3)) - spatial_scale = 1.5 + # x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2)) + # x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample + # Both formats may have upsampler.blur_down.kernel, so use channel count + conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight" + if conv_key in raw_weights: + out_channels = raw_weights[conv_key].shape[0] + ratio = out_channels // mid_channels + rational_resampler = ratio == 9 # 3^2 for PixelShuffle(3) + blur downsample + spatial_scale = 1.5 if rational_resampler else 2.0 else: + rational_resampler = False spatial_scale = 2.0 print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")