Enhance upsampler weight detection logic in LTX-2 model; improve clarity in comments and streamline spatial scale determination for x1.5 and x2 formats
This commit is contained in:
@@ -401,13 +401,17 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
||||
mid_channels = 1024
|
||||
|
||||
# Detect upsampler type from conv output channels
|
||||
# x2 uses sequential: upsampler.0.weight (4*mid out channels)
|
||||
# x1.5 uses named: upsampler.conv.weight (9*mid out channels) + upsampler.blur_down.kernel
|
||||
rational_resampler = "upsampler.blur_down.kernel" in raw_weights
|
||||
if rational_resampler:
|
||||
# x1.5: conv out = 9 * mid_channels (3^2 * mid for PixelShuffle(3))
|
||||
spatial_scale = 1.5
|
||||
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
|
||||
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
|
||||
# Both formats may have upsampler.blur_down.kernel, so use channel count
|
||||
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
|
||||
if conv_key in raw_weights:
|
||||
out_channels = raw_weights[conv_key].shape[0]
|
||||
ratio = out_channels // mid_channels
|
||||
rational_resampler = ratio == 9 # 3^2 for PixelShuffle(3) + blur downsample
|
||||
spatial_scale = 1.5 if rational_resampler else 2.0
|
||||
else:
|
||||
rational_resampler = False
|
||||
spatial_scale = 2.0
|
||||
|
||||
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
|
||||
|
||||
Reference in New Issue
Block a user