Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.
This commit is contained in:
@@ -26,7 +26,7 @@ class TransformerArgsPreprocessor:
|
||||
self,
|
||||
patchify_proj: nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: PixArtAlphaTextProjection,
|
||||
caption_projection: Optional[PixArtAlphaTextProjection],
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
num_attention_heads: int,
|
||||
@@ -35,10 +35,12 @@ class TransformerArgsPreprocessor:
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
double_precision_rope: bool = False,
|
||||
prompt_adaln: Optional[AdaLayerNormSingle] = None,
|
||||
):
|
||||
self.patchify_proj = patchify_proj
|
||||
self.adaln = adaln
|
||||
self.caption_projection = caption_projection
|
||||
self.prompt_adaln = prompt_adaln
|
||||
self.inner_dim = inner_dim
|
||||
self.max_pos = max_pos
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -64,6 +66,19 @@ class TransformerArgsPreprocessor:
|
||||
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_timestep_with_adaln(
|
||||
self,
|
||||
adaln: AdaLayerNormSingle,
|
||||
timestep: mx.array,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_context(
|
||||
self,
|
||||
context: mx.array,
|
||||
@@ -72,9 +87,8 @@ class TransformerArgsPreprocessor:
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# Context is already processed through embeddings connector in text encoder
|
||||
# Here we just apply the caption projection
|
||||
context = self.caption_projection(context)
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||
return context, attention_mask
|
||||
|
||||
@@ -134,6 +148,14 @@ class TransformerArgsPreprocessor:
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
|
||||
# Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps
|
||||
prompt_timestep = None
|
||||
prompt_embedded_timestep = None
|
||||
if self.prompt_adaln is not None and modality.sigma is not None:
|
||||
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
|
||||
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
|
||||
)
|
||||
|
||||
return TransformerArgs(
|
||||
x=x,
|
||||
context=context,
|
||||
@@ -145,6 +167,8 @@ class TransformerArgsPreprocessor:
|
||||
cross_scale_shift_timestep=None,
|
||||
cross_gate_timestep=None,
|
||||
enabled=modality.enabled,
|
||||
prompt_timesteps=prompt_timestep,
|
||||
prompt_embedded_timestep=prompt_embedded_timestep,
|
||||
)
|
||||
|
||||
|
||||
@@ -154,7 +178,7 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
self,
|
||||
patchify_proj: nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: PixArtAlphaTextProjection,
|
||||
caption_projection: Optional[PixArtAlphaTextProjection],
|
||||
cross_scale_shift_adaln: AdaLayerNormSingle,
|
||||
cross_gate_adaln: AdaLayerNormSingle,
|
||||
inner_dim: int,
|
||||
@@ -168,6 +192,7 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
rope_type: LTXRopeType,
|
||||
av_ca_timestep_scale_multiplier: int,
|
||||
double_precision_rope: bool = False,
|
||||
prompt_adaln: Optional[AdaLayerNormSingle] = None,
|
||||
):
|
||||
self.simple_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=patchify_proj,
|
||||
@@ -181,6 +206,7 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
rope_type=rope_type,
|
||||
double_precision_rope=double_precision_rope,
|
||||
prompt_adaln=prompt_adaln,
|
||||
)
|
||||
self.cross_scale_shift_adaln = cross_scale_shift_adaln
|
||||
self.cross_gate_adaln = cross_gate_adaln
|
||||
@@ -280,11 +306,17 @@ class LTXModel(nn.Module):
|
||||
|
||||
def _init_video(self, config: LTXModelConfig) -> None:
|
||||
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
|
||||
adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
|
||||
self.scale_shift_table = mx.zeros((2, self.inner_dim))
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
|
||||
@@ -292,13 +324,17 @@ class LTXModel(nn.Module):
|
||||
|
||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
||||
|
||||
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.audio_caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.audio_caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
|
||||
# Output components
|
||||
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
||||
@@ -331,7 +367,7 @@ class LTXModel(nn.Module):
|
||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=self.caption_projection,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
|
||||
inner_dim=self.inner_dim,
|
||||
@@ -345,11 +381,12 @@ class LTXModel(nn.Module):
|
||||
rope_type=config.rope_type,
|
||||
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=self.audio_caption_projection,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
@@ -363,12 +400,13 @@ class LTXModel(nn.Module):
|
||||
rope_type=config.rope_type,
|
||||
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
elif config.model_type.is_video_enabled():
|
||||
self.video_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=self.caption_projection,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=config.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
@@ -377,12 +415,13 @@ class LTXModel(nn.Module):
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
elif config.model_type.is_audio_enabled():
|
||||
self.audio_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=self.audio_caption_projection,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=config.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
@@ -391,13 +430,13 @@ class LTXModel(nn.Module):
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
|
||||
video_config = config.get_video_config()
|
||||
audio_config = config.get_audio_config()
|
||||
|
||||
|
||||
self.transformer_blocks = {
|
||||
idx: BasicAVTransformerBlock(
|
||||
idx=idx,
|
||||
@@ -405,6 +444,7 @@ class LTXModel(nn.Module):
|
||||
audio=audio_config,
|
||||
rope_type=config.rope_type,
|
||||
norm_eps=config.norm_eps,
|
||||
has_prompt_adaln=config.has_prompt_adaln,
|
||||
)
|
||||
for idx in range(config.num_layers)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user