Fix LTX-2.3 decoder grainy bug
This commit is contained in:
@@ -316,11 +316,12 @@ class LTX2VideoDecoder(nn.Module):
|
||||
elif block_type == "d2s":
|
||||
reduction = block_def[2] if len(block_def) > 2 else 2
|
||||
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
||||
residual = block_def[4] if len(block_def) > 4 else True
|
||||
self.up_blocks[idx] = DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=ch,
|
||||
stride=stride,
|
||||
residual=True,
|
||||
residual=residual,
|
||||
out_channels_reduction_factor=reduction,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
@@ -425,9 +426,14 @@ class LTX2VideoDecoder(nn.Module):
|
||||
# Infer block structure from weights
|
||||
decoder_blocks = cls._infer_blocks(weights)
|
||||
|
||||
# Determine spatial padding mode from config
|
||||
spatial_padding_mode_str = config_dict.get("spatial_padding_mode", "reflect")
|
||||
spatial_padding_mode = PaddingModeType(spatial_padding_mode_str)
|
||||
|
||||
model = cls(
|
||||
timestep_conditioning=config_dict.get("timestep_conditioning", False),
|
||||
decoder_blocks=decoder_blocks,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
@@ -477,6 +483,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
# Second pass: determine d2s strides using the channel progression
|
||||
# For each d2s block, the next res block tells us the expected output channels
|
||||
blocks = []
|
||||
d2s_strides = []
|
||||
for i, block in enumerate(raw_blocks):
|
||||
if block[0] == "res":
|
||||
blocks.append(block)
|
||||
@@ -508,9 +515,27 @@ class LTX2VideoDecoder(nn.Module):
|
||||
else:
|
||||
stride = (2, 2, 2)
|
||||
|
||||
d2s_strides.append(stride)
|
||||
blocks.append(("d2s", in_ch, reduction, stride))
|
||||
|
||||
return blocks if blocks else None
|
||||
if not blocks:
|
||||
return None
|
||||
|
||||
# Determine residual flag: LTX-2 has uniform (2,2,2) strides with reduction=2 → residual=True
|
||||
# LTX-2.3 has mixed strides or reduction=1 → residual=False
|
||||
has_mixed_strides = len(set(d2s_strides)) > 1
|
||||
has_non_standard_reduction = any(b[2] != 2 for b in blocks if b[0] == "d2s")
|
||||
use_residual = not has_mixed_strides and not has_non_standard_reduction
|
||||
|
||||
# Apply residual flag to all d2s blocks
|
||||
final_blocks = []
|
||||
for block in blocks:
|
||||
if block[0] == "d2s":
|
||||
final_blocks.append(("d2s", block[1], block[2], block[3], use_residual))
|
||||
else:
|
||||
final_blocks.append(block)
|
||||
|
||||
return final_blocks
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user