diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 4d6dc48..a187542 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -294,7 +294,7 @@ class LTX2VideoDecoder(nn.Module): dims=3, in_channels=1024, stride=(2, 2, 2), - residual=True, # CRITICAL: Must match PyTorch config! + residual=True, out_channels_reduction_factor=2, spatial_padding_mode=spatial_padding_mode, ), @@ -303,7 +303,7 @@ class LTX2VideoDecoder(nn.Module): dims=3, in_channels=512, stride=(2, 2, 2), - residual=True, # CRITICAL: Must match PyTorch config! + residual=True, out_channels_reduction_factor=2, spatial_padding_mode=spatial_padding_mode, ), @@ -312,7 +312,7 @@ class LTX2VideoDecoder(nn.Module): dims=3, in_channels=256, stride=(2, 2, 2), - residual=True, # CRITICAL: Must match PyTorch config! + residual=True, out_channels_reduction_factor=2, spatial_padding_mode=spatial_padding_mode, ),