Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.
This commit is contained in:
@@ -344,6 +344,7 @@ def denoise_distilled(
|
||||
context=text_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
sigma=mx.full((b,), sigma, dtype=dtype),
|
||||
)
|
||||
|
||||
audio_modality = None
|
||||
@@ -359,6 +360,7 @@ def denoise_distilled(
|
||||
context=audio_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
sigma=mx.full((ab,), sigma, dtype=dtype),
|
||||
)
|
||||
|
||||
velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
|
||||
@@ -493,6 +495,8 @@ def denoise_dev(
|
||||
else:
|
||||
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
||||
|
||||
sigma_array = mx.full((b,), sigma, dtype=dtype)
|
||||
|
||||
# Positive conditioning pass
|
||||
video_modality_pos = Modality(
|
||||
latent=latents_flat,
|
||||
@@ -502,6 +506,7 @@ def denoise_dev(
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_rope,
|
||||
sigma=sigma_array,
|
||||
)
|
||||
velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
|
||||
|
||||
@@ -523,6 +528,7 @@ def denoise_dev(
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
positional_embeddings=precomputed_rope,
|
||||
sigma=sigma_array,
|
||||
)
|
||||
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
|
||||
|
||||
@@ -957,10 +963,18 @@ def generate_video(
|
||||
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Read transformer config to detect model version
|
||||
import json
|
||||
transformer_config_path = model_path / "transformer" / "config.json"
|
||||
has_prompt_adaln = False
|
||||
if transformer_config_path.exists():
|
||||
with open(transformer_config_path) as f:
|
||||
has_prompt_adaln = json.load(f).get("has_prompt_adaln", False)
|
||||
|
||||
# Load text encoder
|
||||
with console.status("[blue]📝 Loading text encoder...[/]", spinner="dots"):
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
text_encoder = LTX2TextEncoder()
|
||||
text_encoder = LTX2TextEncoder(has_prompt_adaln=has_prompt_adaln)
|
||||
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||
mx.eval(text_encoder.parameters())
|
||||
console.print("[green]✓[/] Text encoder loaded")
|
||||
@@ -1084,7 +1098,10 @@ def generate_video(
|
||||
|
||||
# Upsample latents
|
||||
with console.status("[magenta]🔍 Upsampling latents 2x...[/]", spinner="dots"):
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
upscaler_files = sorted(model_path.glob("*spatial-upscaler-x2*.safetensors"))
|
||||
if not upscaler_files:
|
||||
raise FileNotFoundError(f"No spatial upscaler found in {model_path}")
|
||||
upsampler = load_upsampler(str(upscaler_files[0]))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = VideoDecoder.from_pretrained(str(model_path / "vae" / "decoder"))
|
||||
|
||||
Reference in New Issue
Block a user