Fix grainy output bug in the LTX-2.3 decoder

This commit is contained in:
Prince Canuma
2026-03-14 21:56:03 +01:00
parent 5644492f7d
commit eb0d1355e4

View File

@@ -316,11 +316,12 @@ class LTX2VideoDecoder(nn.Module):
elif block_type == "d2s": elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2 reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2) stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
residual = block_def[4] if len(block_def) > 4 else True
self.up_blocks[idx] = DepthToSpaceUpsample( self.up_blocks[idx] = DepthToSpaceUpsample(
dims=3, dims=3,
in_channels=ch, in_channels=ch,
stride=stride, stride=stride,
residual=True, residual=residual,
out_channels_reduction_factor=reduction, out_channels_reduction_factor=reduction,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
) )
@@ -425,9 +426,14 @@ class LTX2VideoDecoder(nn.Module):
# Infer block structure from weights # Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights) decoder_blocks = cls._infer_blocks(weights)
# Determine spatial padding mode from config
spatial_padding_mode_str = config_dict.get("spatial_padding_mode", "reflect")
spatial_padding_mode = PaddingModeType(spatial_padding_mode_str)
model = cls( model = cls(
timestep_conditioning=config_dict.get("timestep_conditioning", False), timestep_conditioning=config_dict.get("timestep_conditioning", False),
decoder_blocks=decoder_blocks, decoder_blocks=decoder_blocks,
spatial_padding_mode=spatial_padding_mode,
) )
weights = model.sanitize(weights) weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict) model.load_weights(list(weights.items()), strict=strict)
@@ -477,6 +483,7 @@ class LTX2VideoDecoder(nn.Module):
# Second pass: determine d2s strides using the channel progression # Second pass: determine d2s strides using the channel progression
# For each d2s block, the next res block tells us the expected output channels # For each d2s block, the next res block tells us the expected output channels
blocks = [] blocks = []
d2s_strides = []
for i, block in enumerate(raw_blocks): for i, block in enumerate(raw_blocks):
if block[0] == "res": if block[0] == "res":
blocks.append(block) blocks.append(block)
@@ -508,9 +515,27 @@ class LTX2VideoDecoder(nn.Module):
else: else:
stride = (2, 2, 2) stride = (2, 2, 2)
d2s_strides.append(stride)
blocks.append(("d2s", in_ch, reduction, stride)) blocks.append(("d2s", in_ch, reduction, stride))
return blocks if blocks else None if not blocks:
return None
# Determine residual flag: LTX-2 uses uniform (2,2,2) strides with reduction=2 → residual=True
# LTX-2.3 uses mixed strides or a reduction other than 2 (e.g. reduction=1) → residual=False
has_mixed_strides = len(set(d2s_strides)) > 1
has_non_standard_reduction = any(b[2] != 2 for b in blocks if b[0] == "d2s")
use_residual = not has_mixed_strides and not has_non_standard_reduction
# Apply residual flag to all d2s blocks
final_blocks = []
for block in blocks:
if block[0] == "d2s":
final_blocks.append(("d2s", block[1], block[2], block[3], use_residual))
else:
final_blocks.append(block)
return final_blocks