Refactor LTXModel: Change transformer_blocks from list to dictionary

This commit is contained in:
Prince Canuma
2026-01-15 03:47:52 +01:00
parent e7067fea11
commit 3fcd8f90be

View File

@@ -384,8 +384,9 @@ class LTXModel(nn.Module):
video_config = config.get_video_config()
audio_config = config.get_audio_config()
self.transformer_blocks = [
BasicAVTransformerBlock(
self.transformer_blocks = {
idx: BasicAVTransformerBlock(
idx=idx,
video=video_config,
audio=audio_config,
@@ -393,7 +394,7 @@ class LTXModel(nn.Module):
norm_eps=config.norm_eps,
)
for idx in range(config.num_layers)
]
}
def _process_transformer_blocks(
self,
@@ -401,7 +402,7 @@ class LTXModel(nn.Module):
audio: Optional[TransformerArgs],
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks."""
for block in self.transformer_blocks:
for block in self.transformer_blocks.values():
video, audio = block(video=video, audio=audio)
return video, audio