Refactor LTXModel: Change transformer_blocks from list to dictionary

This commit is contained in:
Prince Canuma
2026-01-15 03:47:52 +01:00
parent e7067fea11
commit 3fcd8f90be

View File

@@ -384,8 +384,9 @@ class LTXModel(nn.Module):
video_config = config.get_video_config() video_config = config.get_video_config()
audio_config = config.get_audio_config() audio_config = config.get_audio_config()
self.transformer_blocks = [
BasicAVTransformerBlock( self.transformer_blocks = {
idx: BasicAVTransformerBlock(
idx=idx, idx=idx,
video=video_config, video=video_config,
audio=audio_config, audio=audio_config,
@@ -393,7 +394,7 @@ class LTXModel(nn.Module):
norm_eps=config.norm_eps, norm_eps=config.norm_eps,
) )
for idx in range(config.num_layers) for idx in range(config.num_layers)
] }
def _process_transformer_blocks( def _process_transformer_blocks(
self, self,
@@ -401,7 +402,7 @@ class LTXModel(nn.Module):
audio: Optional[TransformerArgs], audio: Optional[TransformerArgs],
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks.""" """Process through all transformer blocks."""
for block in self.transformer_blocks: for block in self.transformer_blocks.values():
video, audio = block(video=video, audio=audio) video, audio = block(video=video, audio=audio)
return video, audio return video, audio