Enhance upsampler weight detection logic in LTX-2 model; improve clarity in comments and streamline spatial scale determination for x1.5 and x2 formats
This commit is contained in:
@@ -401,13 +401,17 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
|||||||
mid_channels = 1024
|
mid_channels = 1024
|
||||||
|
|
||||||
# Detect upsampler type from conv output channels
|
# Detect upsampler type from conv output channels
|
||||||
# x2 uses sequential: upsampler.0.weight (4*mid out channels)
|
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
|
||||||
# x1.5 uses named: upsampler.conv.weight (9*mid out channels) + upsampler.blur_down.kernel
|
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
|
||||||
rational_resampler = "upsampler.blur_down.kernel" in raw_weights
|
# Both formats may have upsampler.blur_down.kernel, so use channel count
|
||||||
if rational_resampler:
|
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
|
||||||
# x1.5: conv out = 9 * mid_channels (3^2 * mid for PixelShuffle(3))
|
if conv_key in raw_weights:
|
||||||
spatial_scale = 1.5
|
out_channels = raw_weights[conv_key].shape[0]
|
||||||
|
ratio = out_channels // mid_channels
|
||||||
|
rational_resampler = ratio == 9 # 3^2 for PixelShuffle(3) + blur downsample
|
||||||
|
spatial_scale = 1.5 if rational_resampler else 2.0
|
||||||
else:
|
else:
|
||||||
|
rational_resampler = False
|
||||||
spatial_scale = 2.0
|
spatial_scale = 2.0
|
||||||
|
|
||||||
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
|
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
|
||||||
|
|||||||
Reference in New Issue
Block a user