Refactor LTXModel: Change transformer_blocks from list to dictionary
This commit is contained in:
@@ -384,8 +384,9 @@ class LTXModel(nn.Module):
|
|||||||
video_config = config.get_video_config()
|
video_config = config.get_video_config()
|
||||||
audio_config = config.get_audio_config()
|
audio_config = config.get_audio_config()
|
||||||
|
|
||||||
self.transformer_blocks = [
|
|
||||||
BasicAVTransformerBlock(
|
self.transformer_blocks = {
|
||||||
|
idx: BasicAVTransformerBlock(
|
||||||
idx=idx,
|
idx=idx,
|
||||||
video=video_config,
|
video=video_config,
|
||||||
audio=audio_config,
|
audio=audio_config,
|
||||||
@@ -393,7 +394,7 @@ class LTXModel(nn.Module):
|
|||||||
norm_eps=config.norm_eps,
|
norm_eps=config.norm_eps,
|
||||||
)
|
)
|
||||||
for idx in range(config.num_layers)
|
for idx in range(config.num_layers)
|
||||||
]
|
}
|
||||||
|
|
||||||
def _process_transformer_blocks(
|
def _process_transformer_blocks(
|
||||||
self,
|
self,
|
||||||
@@ -401,7 +402,7 @@ class LTXModel(nn.Module):
|
|||||||
audio: Optional[TransformerArgs],
|
audio: Optional[TransformerArgs],
|
||||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||||
"""Process through all transformer blocks."""
|
"""Process through all transformer blocks."""
|
||||||
for block in self.transformer_blocks:
|
for block in self.transformer_blocks.values():
|
||||||
video, audio = block(video=video, audio=audio)
|
video, audio = block(video=video, audio=audio)
|
||||||
return video, audio
|
return video, audio
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user