Fix graininess bug in the LTX-2.3 decoder

This commit is contained in:
Prince Canuma
2026-03-14 21:56:03 +01:00
parent 5644492f7d
commit eb0d1355e4

View File

@@ -316,11 +316,12 @@ class LTX2VideoDecoder(nn.Module):
elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
residual = block_def[4] if len(block_def) > 4 else True
self.up_blocks[idx] = DepthToSpaceUpsample(
dims=3,
in_channels=ch,
stride=stride,
residual=True,
residual=residual,
out_channels_reduction_factor=reduction,
spatial_padding_mode=spatial_padding_mode,
)
@@ -406,7 +407,7 @@ class LTX2VideoDecoder(nn.Module):
model_path = Path(model_path)
config_dict = {}
# Load config from directory
config_path = model_path / "config.json"
if config_path.exists():
@@ -425,9 +426,14 @@ class LTX2VideoDecoder(nn.Module):
# Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights)
# Determine spatial padding mode from config
spatial_padding_mode_str = config_dict.get("spatial_padding_mode", "reflect")
spatial_padding_mode = PaddingModeType(spatial_padding_mode_str)
model = cls(
timestep_conditioning=config_dict.get("timestep_conditioning", False),
decoder_blocks=decoder_blocks,
spatial_padding_mode=spatial_padding_mode,
)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
@@ -477,6 +483,7 @@ class LTX2VideoDecoder(nn.Module):
# Second pass: determine d2s strides using the channel progression
# For each d2s block, the next res block tells us the expected output channels
blocks = []
d2s_strides = []
for i, block in enumerate(raw_blocks):
if block[0] == "res":
blocks.append(block)
@@ -508,9 +515,27 @@ class LTX2VideoDecoder(nn.Module):
else:
stride = (2, 2, 2)
d2s_strides.append(stride)
blocks.append(("d2s", in_ch, reduction, stride))
return blocks if blocks else None
if not blocks:
return None
# Determine residual flag: LTX-2 has uniform (2,2,2) strides with reduction=2 → residual=True;
# LTX-2.3 has mixed strides or a reduction other than 2 (e.g. reduction=1) → residual=False
has_mixed_strides = len(set(d2s_strides)) > 1
has_non_standard_reduction = any(b[2] != 2 for b in blocks if b[0] == "d2s")
use_residual = not has_mixed_strides and not has_non_standard_reduction
# Apply residual flag to all d2s blocks
final_blocks = []
for block in blocks:
if block[0] == "d2s":
final_blocks.append(("d2s", block[1], block[2], block[3], use_residual))
else:
final_blocks.append(block)
return final_blocks