Add LTX-2.3 model architecture with prompt-conditioned adaptive layer normalization (adaln) support. Introduce gating mechanisms in attention modules and update transformer configurations to accommodate new parameters. Refactor video and audio processing to utilize adaptive normalization, improving model flexibility and performance. Update weight loading and initialization logic to support dynamic block structures in the decoder.
This commit is contained in:
@@ -301,15 +301,14 @@ def upsample_latents(
|
||||
latent_std: mx.array,
|
||||
debug: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
# Un-normalize: latent * std + mean
|
||||
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
|
||||
latent_std = latent_std.reshape(1, -1, 1, 1, 1)
|
||||
latent = latent * latent_std + latent_mean
|
||||
|
||||
|
||||
# Upsample
|
||||
latent = upsampler(latent, debug=debug)
|
||||
|
||||
|
||||
# Re-normalize: (latent - mean) / std
|
||||
latent = (latent - latent_mean) / latent_std
|
||||
|
||||
@@ -350,19 +349,18 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
|
||||
for key, value in raw_weights.items():
|
||||
new_key = key
|
||||
|
||||
# LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
|
||||
if key.startswith("upsampler.0."):
|
||||
new_key = key.replace("upsampler.0.", "upsampler.conv.")
|
||||
|
||||
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
|
||||
if "conv" in key and "weight" in key and value.ndim == 5:
|
||||
if "weight" in new_key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||
if "conv" in key and "weight" in key and value.ndim == 4:
|
||||
if "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
# Map upsampler.conv to upsampler.conv (SpatialRationalResampler)
|
||||
# Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel
|
||||
if key.startswith("upsampler."):
|
||||
new_key = key # Keep as is for SpatialRationalResampler
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
# Load weights
|
||||
|
||||
Reference in New Issue
Block a user