diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index a187542..ab22374 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -201,14 +201,15 @@ class ResBlockGroup(nn.Module): embedding_dim=channels * 4 ) - self.res_blocks = [ - ResnetBlock3DSimple( + # Use dict with int keys for MLX to track parameters properly + self.res_blocks = { + i: ResnetBlock3DSimple( channels, spatial_padding_mode, timestep_conditioning=timestep_conditioning ) - for _ in range(num_layers) - ] + for i in range(num_layers) + } def __call__( self, @@ -227,7 +228,7 @@ class ResBlockGroup(nn.Module): # Reshape to (B, 4*C, 1, 1, 1) for broadcasting timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1) - for res_block in self.res_blocks: + for res_block in self.res_blocks.values(): x = res_block(x, causal=causal, timestep_embed=timestep_embed) return x @@ -287,10 +288,10 @@ class LTX2VideoDecoder(nn.Module): self.conv_in = ConvInWrapper() # Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample - - self.up_blocks = [ - ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - DepthToSpaceUpsample( + # Use dict with int keys for MLX to track parameters properly + self.up_blocks = { + 0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning), + 1: DepthToSpaceUpsample( dims=3, in_channels=1024, stride=(2, 2, 2), @@ -298,8 +299,8 @@ class LTX2VideoDecoder(nn.Module): out_channels_reduction_factor=2, spatial_padding_mode=spatial_padding_mode, ), - ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - DepthToSpaceUpsample( + 2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning), + 3: DepthToSpaceUpsample( dims=3, in_channels=512, stride=(2, 2, 2), @@ -307,8 +308,8 @@ class LTX2VideoDecoder(nn.Module): out_channels_reduction_factor=2, spatial_padding_mode=spatial_padding_mode, ), - ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - DepthToSpaceUpsample( + 4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning), + 5: DepthToSpaceUpsample( dims=3, in_channels=256, stride=(2, 2, 2), @@ -316,8 +317,8 @@ class LTX2VideoDecoder(nn.Module): out_channels_reduction_factor=2, spatial_padding_mode=spatial_padding_mode, ), - ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning), - ] + 6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning), + } final_out_channels = out_channels * patch_size * patch_size class ConvOutWrapper(nn.Module): @@ -396,10 +397,10 @@ class LTX2VideoDecoder(nn.Module): if debug: debug_stats("After conv_in", x) - for i, block in enumerate(self.up_blocks): + for i, block in self.up_blocks.items(): if isinstance(block, ResBlockGroup): x = block(x, causal=causal, timestep=scaled_timestep) - else: + else: x = block(x, causal=causal) if debug: block_type = type(block).__name__ @@ -443,10 +444,10 @@ class LTX2VideoDecoder(nn.Module): return x -def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder: +def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder: from pathlib import Path - - decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning) + import json + from safetensors import safe_open model_path = Path(model_path) @@ -461,6 +462,25 @@ def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX raise FileNotFoundError(f"VAE weights not found at {model_path}") print(f"Loading VAE decoder from {weights_path}...") + + # Read config from safetensors metadata to auto-detect timestep_conditioning + if timestep_conditioning is None: + try: + with safe_open(str(weights_path), framework="numpy") as f: + metadata = f.metadata() + if metadata and "config" in metadata: + configs = json.loads(metadata["config"]) + vae_config = configs.get("vae", {}) + timestep_conditioning = vae_config.get("timestep_conditioning", False) + print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights") + else: + timestep_conditioning = False + except Exception as e: + print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False") + timestep_conditioning = False + + decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning) + weights = mx.load(str(weights_path)) # Determine prefix based on weight keys