diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 105082c..be4e794 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -316,11 +316,12 @@ class LTX2VideoDecoder(nn.Module): elif block_type == "d2s": reduction = block_def[2] if len(block_def) > 2 else 2 stride = block_def[3] if len(block_def) > 3 else (2, 2, 2) + residual = block_def[4] if len(block_def) > 4 else True self.up_blocks[idx] = DepthToSpaceUpsample( dims=3, in_channels=ch, stride=stride, - residual=True, + residual=residual, out_channels_reduction_factor=reduction, spatial_padding_mode=spatial_padding_mode, ) @@ -406,7 +407,7 @@ class LTX2VideoDecoder(nn.Module): model_path = Path(model_path) config_dict = {} - + # Load config from directory config_path = model_path / "config.json" if config_path.exists(): @@ -425,9 +426,14 @@ class LTX2VideoDecoder(nn.Module): # Infer block structure from weights decoder_blocks = cls._infer_blocks(weights) + # Determine spatial padding mode from config + spatial_padding_mode_str = config_dict.get("spatial_padding_mode", "reflect") + spatial_padding_mode = PaddingModeType(spatial_padding_mode_str) + model = cls( timestep_conditioning=config_dict.get("timestep_conditioning", False), decoder_blocks=decoder_blocks, + spatial_padding_mode=spatial_padding_mode, ) weights = model.sanitize(weights) model.load_weights(list(weights.items()), strict=strict) @@ -477,6 +483,7 @@ class LTX2VideoDecoder(nn.Module): # Second pass: determine d2s strides using the channel progression # For each d2s block, the next res block tells us the expected output channels blocks = [] + d2s_strides = [] for i, block in enumerate(raw_blocks): if block[0] == "res": blocks.append(block) @@ -508,9 +515,27 @@ class LTX2VideoDecoder(nn.Module): else: stride = (2, 2, 2) + d2s_strides.append(stride) blocks.append(("d2s", in_ch, reduction, stride)) - return blocks if blocks else None + if not blocks: + return None + + # Determine residual flag: LTX-2 has uniform (2,2,2) strides with reduction=2 → residual=True + # LTX-2.3 has mixed strides or reduction=1 → residual=False + has_mixed_strides = len(set(d2s_strides)) > 1 + has_non_standard_reduction = any(b[2] != 2 for b in blocks if b[0] == "d2s") + use_residual = not has_mixed_strides and not has_non_standard_reduction + + # Apply residual flag to all d2s blocks + final_blocks = [] + for block in blocks: + if block[0] == "d2s": + final_blocks.append(("d2s", block[1], block[2], block[3], use_residual)) + else: + final_blocks.append(block) + + return final_blocks