Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.

This commit is contained in:
Prince Canuma
2026-03-10 16:47:36 +01:00
parent d028b239fb
commit 207c223354
8 changed files with 545 additions and 239 deletions

View File

@@ -344,6 +344,7 @@ def denoise_distilled(
context=text_embeddings,
context_mask=None,
enabled=True,
sigma=mx.full((b,), sigma, dtype=dtype),
)
audio_modality = None
@@ -359,6 +360,7 @@ def denoise_distilled(
context=audio_embeddings,
context_mask=None,
enabled=True,
sigma=mx.full((ab,), sigma, dtype=dtype),
)
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
@@ -493,6 +495,8 @@ def denoise_dev(
else:
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
sigma_array = mx.full((b,), sigma, dtype=dtype)
# Positive conditioning pass
video_modality_pos = Modality(
latent=latents_flat,
@@ -502,6 +506,7 @@ def denoise_dev(
context_mask=None,
enabled=True,
positional_embeddings=precomputed_rope,
sigma=sigma_array,
)
velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
@@ -523,6 +528,7 @@ def denoise_dev(
context_mask=None,
enabled=True,
positional_embeddings=precomputed_rope,
sigma=sigma_array,
)
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
@@ -957,10 +963,18 @@ def generate_video(
mx.random.seed(seed)
# Read transformer config to detect model version
import json
transformer_config_path = model_path / "transformer" / "config.json"
has_prompt_adaln = False
if transformer_config_path.exists():
with open(transformer_config_path) as f:
has_prompt_adaln = json.load(f).get("has_prompt_adaln", False)
# Load text encoder
with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"):
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder()
text_encoder = LTX2TextEncoder(has_prompt_adaln=has_prompt_adaln)
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters())
console.print("[green]✓[/] Text encoder loaded")
@@ -1084,7 +1098,10 @@ def generate_video(
# Upsample latents
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
if not upscaler_files:
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
upsampler = load_upsampler(str(upscaler_files[0]))
mx.eval(upsampler.parameters())
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))