Add the LTX-2.3 model architecture with support for prompt-conditioned adaptive layer normalization (AdaLN). Introduce gating mechanisms in the attention modules and update the transformer configurations to accommodate the new parameters. Refactor video and audio processing to use adaptive normalization, improving model flexibility and performance. Update the weight-loading and initialization logic to support dynamic block structures in the decoder.

This commit is contained in:
Prince Canuma
2026-03-10 16:47:36 +01:00
parent d028b239fb
commit 207c223354
8 changed files with 545 additions and 239 deletions

View File

@@ -26,7 +26,7 @@ class TransformerArgsPreprocessor:
self,
patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
caption_projection: Optional[PixArtAlphaTextProjection],
inner_dim: int,
max_pos: List[int],
num_attention_heads: int,
@@ -35,10 +35,12 @@ class TransformerArgsPreprocessor:
positional_embedding_theta: float,
rope_type: LTXRopeType,
double_precision_rope: bool = False,
prompt_adaln: Optional[AdaLayerNormSingle] = None,
):
self.patchify_proj = patchify_proj
self.adaln = adaln
self.caption_projection = caption_projection
self.prompt_adaln = prompt_adaln
self.inner_dim = inner_dim
self.max_pos = max_pos
self.num_attention_heads = num_attention_heads
@@ -64,6 +66,19 @@ class TransformerArgsPreprocessor:
return timestep_emb, embedded_timestep
def _prepare_timestep_with_adaln(
    self,
    adaln: AdaLayerNormSingle,
    timestep: mx.array,
    batch_size: int,
    hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]:
    """Run a caller-supplied AdaLN module on a scaled timestep.

    The timestep is multiplied by ``self.timestep_scale_multiplier``,
    flattened to one dimension, and passed to ``adaln``. Both outputs are
    then reshaped to ``(batch_size, -1, last_dim)``, where ``last_dim`` is
    each output's own trailing dimension.

    Args:
        adaln: The adaptive layer-norm module to evaluate.
        timestep: Timestep values; flattened before the ``adaln`` call.
        batch_size: Leading dimension used when reshaping the outputs.
        hidden_dtype: Forwarded unchanged to ``adaln`` as ``hidden_dtype``.

    Returns:
        The two reshaped outputs of ``adaln``, in the order it returns them.
    """
    scaled = timestep * self.timestep_scale_multiplier
    modulation, embedding = adaln(scaled.reshape(-1), hidden_dtype=hidden_dtype)
    # Restore the batch axis; the middle axis is inferred from the element count.
    return (
        mx.reshape(modulation, (batch_size, -1, modulation.shape[-1])),
        mx.reshape(embedding, (batch_size, -1, embedding.shape[-1])),
    )
def _prepare_context(
self,
context: mx.array,
@@ -72,9 +87,8 @@ class TransformerArgsPreprocessor:
) -> Tuple[mx.array, Optional[mx.array]]:
batch_size = x.shape[0]
# Context is already processed through embeddings connector in text encoder
# Here we just apply the caption projection
context = self.caption_projection(context)
if self.caption_projection is not None:
context = self.caption_projection(context)
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
return context, attention_mask
@@ -134,6 +148,14 @@ class TransformerArgsPreprocessor:
num_attention_heads=self.num_attention_heads,
)
# Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps
prompt_timestep = None
prompt_embedded_timestep = None
if self.prompt_adaln is not None and modality.sigma is not None:
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
)
return TransformerArgs(
x=x,
context=context,
@@ -145,6 +167,8 @@ class TransformerArgsPreprocessor:
cross_scale_shift_timestep=None,
cross_gate_timestep=None,
enabled=modality.enabled,
prompt_timesteps=prompt_timestep,
prompt_embedded_timestep=prompt_embedded_timestep,
)
@@ -154,7 +178,7 @@ class MultiModalTransformerArgsPreprocessor:
self,
patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
caption_projection: Optional[PixArtAlphaTextProjection],
cross_scale_shift_adaln: AdaLayerNormSingle,
cross_gate_adaln: AdaLayerNormSingle,
inner_dim: int,
@@ -168,6 +192,7 @@ class MultiModalTransformerArgsPreprocessor:
rope_type: LTXRopeType,
av_ca_timestep_scale_multiplier: int,
double_precision_rope: bool = False,
prompt_adaln: Optional[AdaLayerNormSingle] = None,
):
self.simple_preprocessor = TransformerArgsPreprocessor(
patchify_proj=patchify_proj,
@@ -181,6 +206,7 @@ class MultiModalTransformerArgsPreprocessor:
positional_embedding_theta=positional_embedding_theta,
rope_type=rope_type,
double_precision_rope=double_precision_rope,
prompt_adaln=prompt_adaln,
)
self.cross_scale_shift_adaln = cross_scale_shift_adaln
self.cross_gate_adaln = cross_gate_adaln
@@ -280,11 +306,17 @@ class LTXModel(nn.Module):
def _init_video(self, config: LTXModelConfig) -> None:
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
hidden_size=self.inner_dim,
)
adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
if config.has_prompt_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
else:
self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
hidden_size=self.inner_dim,
)
self.scale_shift_table = mx.zeros((2, self.inner_dim))
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
@@ -292,13 +324,17 @@ class LTXModel(nn.Module):
def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels,
hidden_size=self.audio_inner_dim,
)
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
if config.has_prompt_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels,
hidden_size=self.audio_inner_dim,
)
# Output components
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
@@ -331,7 +367,7 @@ class LTXModel(nn.Module):
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
caption_projection=getattr(self, "caption_projection", None),
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
inner_dim=self.inner_dim,
@@ -345,11 +381,12 @@ class LTXModel(nn.Module):
rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "prompt_adaln_single", None),
)
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
caption_projection=getattr(self, "audio_caption_projection", None),
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
inner_dim=self.audio_inner_dim,
@@ -363,12 +400,13 @@ class LTXModel(nn.Module):
rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
)
elif config.model_type.is_video_enabled():
self.video_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
caption_projection=getattr(self, "caption_projection", None),
inner_dim=self.inner_dim,
max_pos=config.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
@@ -377,12 +415,13 @@ class LTXModel(nn.Module):
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "prompt_adaln_single", None),
)
elif config.model_type.is_audio_enabled():
self.audio_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
caption_projection=getattr(self, "audio_caption_projection", None),
inner_dim=self.audio_inner_dim,
max_pos=config.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads,
@@ -391,13 +430,13 @@ class LTXModel(nn.Module):
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
)
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
video_config = config.get_video_config()
audio_config = config.get_audio_config()
self.transformer_blocks = {
idx: BasicAVTransformerBlock(
idx=idx,
@@ -405,6 +444,7 @@ class LTXModel(nn.Module):
audio=audio_config,
rope_type=config.rope_type,
norm_eps=config.norm_eps,
has_prompt_adaln=config.has_prompt_adaln,
)
for idx in range(config.num_layers)
}