Refactor LTXModel: Change transformer_blocks from list to dictionary
This commit is contained in:
@@ -384,8 +384,9 @@ class LTXModel(nn.Module):
|
||||
video_config = config.get_video_config()
|
||||
audio_config = config.get_audio_config()
|
||||
|
||||
self.transformer_blocks = [
|
||||
BasicAVTransformerBlock(
|
||||
|
||||
self.transformer_blocks = {
|
||||
idx: BasicAVTransformerBlock(
|
||||
idx=idx,
|
||||
video=video_config,
|
||||
audio=audio_config,
|
||||
@@ -393,7 +394,7 @@ class LTXModel(nn.Module):
|
||||
norm_eps=config.norm_eps,
|
||||
)
|
||||
for idx in range(config.num_layers)
|
||||
]
|
||||
}
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self,
|
||||
@@ -401,7 +402,7 @@ class LTXModel(nn.Module):
|
||||
audio: Optional[TransformerArgs],
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Process through all transformer blocks."""
|
||||
for block in self.transformer_blocks:
|
||||
for block in self.transformer_blocks.values():
|
||||
video, audio = block(video=video, audio=audio)
|
||||
return video, audio
|
||||
|
||||
|
||||
Reference in New Issue
Block a user