Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.

This commit is contained in:
Prince Canuma
2026-03-10 16:47:36 +01:00
parent d028b239fb
commit 207c223354
8 changed files with 545 additions and 239 deletions

View File

@@ -301,15 +301,14 @@ def upsample_latents(
latent_std: mx.array,
debug: bool = False,
) -> mx.array:
# Un-normalize: latent * std + mean
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
latent_std = latent_std.reshape(1, -1, 1, 1, 1)
latent = latent * latent_std + latent_mean
# Upsample
latent = upsampler(latent, debug=debug)
# Re-normalize: (latent - mean) / std
latent = (latent - latent_mean) / latent_std
@@ -350,19 +349,18 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
for key, value in raw_weights.items():
new_key = key
# LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
if key.startswith("upsampler.0."):
new_key = key.replace("upsampler.0.", "upsampler.conv.")
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 5:
if "weight" in new_key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 4:
if "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
# Map upsampler.conv to upsampler.conv (SpatialRationalResampler)
# Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel
if key.startswith("upsampler."):
new_key = key # Keep as is for SpatialRationalResampler
sanitized[new_key] = value
# Load weights