This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,15 +1,14 @@
from pathlib import Path
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
LTXModelType,
LTXRopeType,
TransformerConfig,
)
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx_2.transformer import (
@@ -58,11 +57,17 @@ class TransformerArgsPreprocessor:
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
timestep_emb, embedded_timestep = self.adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
# Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
timestep_emb = mx.reshape(
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
)
embedded_timestep = mx.reshape(
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
)
return timestep_emb, embedded_timestep
@@ -74,9 +79,15 @@ class TransformerArgsPreprocessor:
hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
timestep_emb, embedded_timestep = adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
timestep_emb = mx.reshape(
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
)
embedded_timestep = mx.reshape(
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
)
return timestep_emb, embedded_timestep
def _prepare_context(
@@ -107,7 +118,9 @@ class TransformerArgsPreprocessor:
# Convert boolean/int mask to float mask
# 0 -> -inf (masked), 1 -> 0 (not masked)
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
mask = mx.reshape(
mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
)
return mask
def _prepare_positional_embeddings(
@@ -132,9 +145,15 @@ class TransformerArgsPreprocessor:
def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
timestep, embedded_timestep = self._prepare_timestep(
modality.timesteps, x.shape[0], hidden_dtype=x.dtype
)
context, attention_mask = self._prepare_context(
modality.context, x, modality.context_mask
)
attention_mask = self._prepare_attention_mask(
attention_mask, modality.latent.dtype
)
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
if modality.positional_embeddings is not None:
@@ -152,8 +171,13 @@ class TransformerArgsPreprocessor:
prompt_timestep = None
prompt_embedded_timestep = None
if self.prompt_adaln is not None and modality.sigma is not None:
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
prompt_timestep, prompt_embedded_timestep = (
self._prepare_timestep_with_adaln(
self.prompt_adaln,
modality.sigma,
x.shape[0],
hidden_dtype=x.dtype,
)
)
return TransformerArgs(
@@ -229,11 +253,13 @@ class MultiModalTransformerArgsPreprocessor:
)
# Prepare cross-attention timestep embeddings
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
cross_scale_shift_timestep, cross_gate_timestep = (
self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
)
)
return replace(
@@ -254,17 +280,25 @@ class MultiModalTransformerArgsPreprocessor:
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
scale_shift_timestep, _ = self.cross_scale_shift_adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
scale_shift_timestep = mx.reshape(
scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
)
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
gate_timestep, _ = self.cross_gate_adaln(
timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype
)
gate_timestep = mx.reshape(
gate_timestep, (batch_size, -1, gate_timestep.shape[-1])
)
return scale_shift_timestep, gate_timestep
class LTXModel(nn.Module):
def __init__(self, config: LTXModelConfig):
super().__init__()
@@ -285,18 +319,25 @@ class LTXModel(nn.Module):
self._init_video(config)
if config.model_type.is_audio_enabled():
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
self.audio_positional_embedding_max_pos = (
config.audio_positional_embedding_max_pos
)
self.audio_num_attention_heads = config.audio_num_attention_heads
self.audio_inner_dim = config.audio_inner_dim
self._init_audio(config)
# Initialize cross-modal components
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
if (
config.model_type.is_video_enabled()
and config.model_type.is_audio_enabled()
):
cross_pe_max_pos = max(
config.positional_embedding_max_pos[0],
config.audio_positional_embedding_max_pos[0],
)
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
self.av_ca_timestep_scale_multiplier = (
config.av_ca_timestep_scale_multiplier
)
self.audio_cross_attention_dim = config.audio_cross_attention_dim
self._init_audio_video(config)
@@ -308,10 +349,14 @@ class LTXModel(nn.Module):
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=adaln_coefficient
)
if config.has_prompt_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
self.prompt_adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=2
)
else:
self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
@@ -323,13 +368,19 @@ class LTXModel(nn.Module):
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_patchify_proj = nn.Linear(
config.audio_in_channels, self.audio_inner_dim, bias=True
)
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient
)
if config.has_prompt_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim, embedding_coefficient=2
)
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels,
@@ -338,7 +389,9 @@ class LTXModel(nn.Module):
# Output components
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
self.audio_norm_out = nn.LayerNorm(
self.audio_inner_dim, eps=config.norm_eps, affine=False
)
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
def _init_audio_video(self, config: LTXModelConfig) -> None:
@@ -361,8 +414,13 @@ class LTXModel(nn.Module):
embedding_coefficient=1,
)
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
def _init_preprocessors(
self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]
) -> None:
if (
config.model_type.is_video_enabled()
and config.model_type.is_audio_enabled()
):
# Multi-modal preprocessors
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
@@ -468,7 +526,8 @@ class LTXModel(nn.Module):
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
for idx, block in self.transformer_blocks.items():
video, audio = block(
video=video, audio=audio,
video=video,
audio=audio,
skip_video_self_attn=(idx in stg_v_set),
skip_audio_self_attn=(idx in stg_a_set),
skip_cross_modal=skip_cross_modal,
@@ -483,7 +542,7 @@ class LTXModel(nn.Module):
x: mx.array,
embedded_timestep: mx.array,
) -> mx.array:
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
@@ -526,8 +585,12 @@ class LTXModel(nn.Module):
raise ValueError("Audio is not enabled for this model")
# Preprocess arguments
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
video_args = (
self.video_args_preprocessor.prepare(video) if video is not None else None
)
audio_args = (
self.audio_args_preprocessor.prepare(audio) if audio is not None else None
)
# Process transformer blocks
video_out, audio_out = self._process_transformer_blocks(
@@ -567,7 +630,7 @@ class LTXModel(nn.Module):
def sanitize(self, weights: dict) -> dict:
sanitized = {}
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
if not has_raw_prefix:
return weights
@@ -577,7 +640,10 @@ class LTXModel(nn.Module):
if not key.startswith("model.diffusion_model."):
continue
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
if (
"audio_embeddings_connector" in key
or "video_embeddings_connector" in key
):
continue
# Remove 'model.diffusion_model.' prefix
@@ -612,9 +678,11 @@ class LTXModel(nn.Module):
for weight_file in model_path.glob("*.safetensors"):
weights.update(mx.load(str(weight_file)))
sanitized = model.sanitize(weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
sanitized = {
k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
for k, v in sanitized.items()
}
model.load_weights(list(sanitized.items()), strict=strict)
mx.eval(model.parameters())
@@ -625,7 +693,7 @@ class LTXModel(nn.Module):
class X0Model(nn.Module):
def __init__(self, velocity_model: LTXModel):
super().__init__()
self.velocity_model = velocity_model
@@ -639,13 +707,18 @@ class X0Model(nn.Module):
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
vx, ax = self.velocity_model(
video, audio,
video,
audio,
stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal,
)
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
denoised_video = (
to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
)
denoised_audio = (
to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
)
return denoised_video, denoised_audio