Enhance upsampler weight detection logic in LTX-2 model; improve clarity in comments and streamline spatial scale determination for x1.5 and x2 formats

This commit is contained in:
Prince Canuma
2026-03-17 15:14:57 +01:00
parent 57f66bcae2
commit f8e371e9ce

View File

@@ -401,13 +401,17 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
mid_channels = 1024 mid_channels = 1024
# Detect upsampler type from conv output channels # Detect upsampler type from conv output channels
# x2 uses sequential: upsampler.0.weight (4*mid out channels) # x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
# x1.5 uses named: upsampler.conv.weight (9*mid out channels) + upsampler.blur_down.kernel # x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
rational_resampler = "upsampler.blur_down.kernel" in raw_weights # Both formats may have upsampler.blur_down.kernel, so use channel count
if rational_resampler: conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
# x1.5: conv out = 9 * mid_channels (3^2 * mid for PixelShuffle(3)) if conv_key in raw_weights:
spatial_scale = 1.5 out_channels = raw_weights[conv_key].shape[0]
ratio = out_channels // mid_channels
rational_resampler = ratio == 9 # 3^2 for PixelShuffle(3) + blur downsample
spatial_scale = 1.5 if rational_resampler else 2.0
else: else:
rational_resampler = False
spatial_scale = 2.0 spatial_scale = 2.0
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}") print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")