diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index 2b4bec2..dea4089 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -384,8 +384,9 @@ class LTXModel(nn.Module): video_config = config.get_video_config() audio_config = config.get_audio_config() - self.transformer_blocks = [ - BasicAVTransformerBlock( + + self.transformer_blocks = { + idx: BasicAVTransformerBlock( idx=idx, video=video_config, audio=audio_config, @@ -393,7 +394,7 @@ class LTXModel(nn.Module): norm_eps=config.norm_eps, ) for idx in range(config.num_layers) - ] + } def _process_transformer_blocks( self, @@ -401,7 +402,7 @@ class LTXModel(nn.Module): audio: Optional[TransformerArgs], ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: """Process through all transformer blocks.""" - for block in self.transformer_blocks: + for block in self.transformer_blocks.values(): video, audio = block(video=video, audio=audio) return video, audio