Merge pull request #22 from Blaizzy/pc/unify-apis

Unify APIs and add LTX-2.3
Prince Canuma
2026-03-18 22:04:21 +01:00
committed by GitHub
105 changed files with 13091 additions and 6639 deletions

6
.gitignore vendored

@@ -1,5 +1,9 @@
.env
claude.md
.claude/*
CLAUDE.md
config.json
*.safetensors
*.safetensors.index.json
.DS_Store
**.pyc
__pycache__/*

155
README.md

@@ -4,8 +4,6 @@ MLX-Video is the best package for inference and finetuning of Image-Video-Audio
## Installation
Install from source:
### Option 1: Install with pip (requires git):
```bash
pip install git+https://github.com/Blaizzy/mlx-video.git
@@ -16,7 +14,7 @@ pip install git+https://github.com/Blaizzy/mlx-video.git
uv pip install git+https://github.com/Blaizzy/mlx-video.git
```
Supported models:
## Supported Models
- [**LTX-2**](https://huggingface.co/Lightricks/LTX-Video) — 19B parameter video generation model from Lightricks
- [**Wan2.1**](https://github.com/Wan-Video/Wan2.1) — 1.3B / 14B parameter T2V models (single-model pipeline)
@@ -24,36 +22,53 @@ Supported models:
## Features
- Text-to-video generation with multiple model families
- LTX-2: Two-stage pipeline with 2x spatial upscaling
- Wan2.1/2.2: Flow-matching diffusion with classifier-free guidance
**LTX-2 / LTX-2.3**
- Text-to-Video (T2V), Image-to-Video (I2V), Audio-to-Video (A2V)
- Audio-Video joint generation
- Multi-pipeline: distilled, dev, dev-two-stage, dev-two-stage-hq
- 2x spatial upscaling for images and videos
- Prompt enhancement via Gemma
**Wan2.1 / Wan2.2**
- Text-to-Video (T2V) — 1.3B and 14B models
- Image-to-Video (I2V) — 14B model
- Flow-matching diffusion with classifier-free guidance
- LoRA support (e.g. Wan2.2-Lightning for 4-step generation)
**General**
- Optimized for Apple Silicon using MLX
---
## LTX-2
> **Info:** Currently, only the distilled variant is supported. Full LTX-2 feature support is coming soon.
### Text-to-Video Generation
```bash
uv run mlx_video.generate --prompt "Two dogs of the poodle breed wearing sunglasses, close up, cinematic, sunset" -n 100 --width 768
# Text-to-Video (distilled, fastest)
uv run mlx_video.ltx_2.generate --prompt "Two dogs wearing sunglasses, cinematic, sunset" -n 97 --width 768
# Image-to-Video
uv run mlx_video.ltx_2.generate --prompt "A person dancing" --image photo.jpg
# Audio-to-Video
uv run mlx_video.ltx_2.generate --audio-file music.wav --prompt "A band playing music"
# Dev pipeline with CFG (higher quality)
uv run mlx_video.ltx_2.generate --pipeline dev --prompt "A cinematic scene" --cfg-scale 3.0
# Dev two-stage HQ (highest quality)
uv run mlx_video.ltx_2.generate --pipeline dev-two-stage-hq \
--prompt "A cinematic scene of ocean waves at golden hour" \
--model-repo prince-canuma/LTX-2-dev
```
<img src="https://github.com/Blaizzy/mlx-video/raw/main/examples/poodles.gif" width="512" alt="Poodles demo">
With custom settings:
**Converting weights:**
Pre-converted weights are available on HuggingFace ([LTX-2-distilled](https://huggingface.co/prince-canuma/LTX-2-distilled), [LTX-2-dev](https://huggingface.co/prince-canuma/LTX-2-dev), [LTX-2.3-distilled](https://huggingface.co/prince-canuma/LTX-2.3-distilled), [LTX-2.3-dev](https://huggingface.co/prince-canuma/LTX-2.3-dev)), or convert from the original Lightricks checkpoint:
```bash
python -m mlx_video.generate \
--prompt "Ocean waves crashing on a beach at sunset" \
--height 768 \
--width 768 \
--num-frames 65 \
--seed 123 \
--output my_video.mp4
```
### LTX-2 CLI Options
@@ -69,33 +84,27 @@ python -m mlx_video.generate \
| `--save-frames` | false | Save individual frames as images |
| `--model-repo` | Lightricks/LTX-2 | HuggingFace model repository |
### How It Works (LTX-2)
1. **Stage 1**: Generate at half resolution (e.g., 384×384) with 8 denoising steps
2. **Upsample**: 2× spatial upsampling via LatentUpsampler
3. **Stage 2**: Refine at full resolution (e.g., 768×768) with 3 denoising steps
4. **Decode**: VAE decoder converts latents to RGB video
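As a rough illustration of the resolution schedule in the steps above, the sketch below derives the stage-1 size from the requested output size. The function name `stage_resolutions` is hypothetical and not part of `mlx_video`. It assumes "half resolution" means halving width and height exactly, so both must be even.

```python
from typing import Dict, Tuple


def stage_resolutions(width: int, height: int) -> Dict[str, Tuple[int, int]]:
    """Return (width, height) for each stage of a two-stage, 2x-upsampled run.

    Hypothetical helper for illustration only; not an mlx_video API.
    """
    if width % 2 or height % 2:
        raise ValueError("width and height must be even to halve exactly")
    return {
        "stage_1": (width // 2, height // 2),  # generated at half resolution
        "stage_2": (width, height),            # refined after 2x upsampling
    }


print(stage_resolutions(768, 768))
```

For a 768×768 request this gives a 384×384 first stage, matching the example sizes in the list.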
---
## Wan2.1 / Wan2.2
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
Both [Wan2.1](https://github.com/Wan-Video/Wan2.1) and [Wan2.2](https://github.com/Wan-Video/Wan2.2) are text-to-video diffusion models built on a DiT (Diffusion Transformer) backbone with a T5 text encoder and 3D VAE.
### Step 0: Download and Convert Weights
See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan/README.md) for details.
See the dedicated Wan2.1/Wan2.2 [README.md](mlx_video/models/wan_2/README.md) for details.
### Step 1: Generate Video
```bash
# Wan2.1 — uses defaults from config (50 steps, shift=5.0, guide=5.0)
python -m mlx_video.generate_wan \
python -m mlx_video.wan_2.generate \
--model-dir wan21_mlx \
--prompt "A cat playing piano in a cozy room"
# Wan2.2 — uses defaults from config (40 steps, shift=12.0, guide=3.0,4.0)
python -m mlx_video.generate_wan \
python -m mlx_video.wan_2.generate \
--model-dir wan22_mlx \
--prompt "A cat playing piano in a cozy room"
```
@@ -103,7 +112,7 @@ python -m mlx_video.generate_wan \
With custom settings:
```bash
python -m mlx_video.generate_wan \
python -m mlx_video.wan_2.generate \
--model-dir wan21_mlx \
--prompt "Ocean waves at sunset, cinematic, 4K" \
--negative-prompt "blurry, low quality" \
@@ -117,13 +126,12 @@ python -m mlx_video.generate_wan \
--output-path my_video.mp4
```
The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model). You can also override any parameter via CLI flags.
The pipeline auto-detects the model version from `config.json` and selects the right pipeline mode (single or dual model).
#### Image-to-Video (I2V-14B)
### Image-to-Video (I2V-14B)
```bash
# Generate video from an input image
python -m mlx_video.generate_wan \
python -m mlx_video.wan_2.generate \
--model-dir wan22_i2v_mlx \
--prompt "The camera slowly zooms in as the subject begins to move" \
--image start.png \
@@ -131,9 +139,30 @@ python -m mlx_video.generate_wan \
--output-path my_video.mp4
```
The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and uses channel concatenation (`y` tensor with 4 mask + 16 image latent channels) to condition generation on the first frame.
### LoRA Support
#### Generation CLI Options
LoRAs can be used with the `--lora-high` and `--lora-low` command line switches.
For example, using the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA for 4-step generation:
```bash
python -m mlx_video.wan_2.generate \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \
--height 704 \
--num-frames 41 \
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
--steps 4 \
--guide-scale 1 \
--trim-first-frames 1 \
--seed 2391784614 \
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
```
![Poodles](examples/poodles-wan.gif)
### Wan CLI Options
| Option | Default | Description |
|--------|---------|-------------|
@@ -150,29 +179,7 @@ The I2V-14B model encodes the input image through the Wan2.1 VAE encoder and use
| `--seed` | -1 (random) | Random seed for reproducibility |
| `--output-path` | `output.mp4` | Output video path |
## LoRA Support
LoRA's can be used with the `--lora-high` and `--lora-low` command line switches.
For example, for using the the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, use the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
```bash
python -m mlx_video.generate_wan \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \
--height 704 \
--num-frames 41 \
--prompt "Two dogs of the poodle breed sitting on a beach wearing sunglasses, nodding with their heads, close up, cinematic, sunset" \
--steps 4 \
--guide-scale 1 \
--trim-first-frames 1 \
--seed 2391784614 \
--lora-high /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors 1 \
--lora-low /Volumes/SSD/Wan-AI/lightx2v/Wan2.2-Lightning/Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors 1
```
Which results in
![Poodles](examples/poodles-wan.gif)
---
## Requirements
@@ -181,36 +188,6 @@ Which results in
- MLX >= 0.22.0
- For weight conversion: PyTorch (`pip install torch`)
## Project Structure
```
mlx_video/
├── generate.py # LTX-2 generation pipeline
├── generate_wan.py # Wan2.1/2.2 generation pipeline
├── convert.py # LTX-2 weight conversion
├── convert_wan.py # Wan weight conversion (PyTorch → MLX)
├── postprocess.py # Video post-processing utilities
├── utils.py # Helper functions
└── models/
├── ltx/ # LTX-2 model
│ ├── ltx.py # DiT transformer
│ ├── config.py # Configuration
│ ├── transformer.py # Transformer blocks
│ ├── attention.py # Multi-head attention with RoPE
│ ├── text_encoder.py # Gemma 3 text encoder
│ ├── upsampler.py # 2x spatial upsampler
│ └── video_vae/ # VAE encoder/decoder
└── wan/ # Wan2.1/2.2 model
├── config.py # Configuration (2.1 & 2.2 presets)
├── model.py # WanModel (DiT transformer)
├── transformer.py # Attention blocks with 6-element modulation
├── attention.py # Self/cross attention with QK-norm
├── rope.py # 3-way factorized RoPE
├── text_encoder.py # T5 UMT5-XXL encoder
├── vae.py # 3D causal VAE decoder
└── scheduler.py # Flow-matching Euler scheduler
```
## License
MIT


@@ -1,14 +1,50 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights
import os
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
# Audio VAE components
from mlx_video.models.ltx_2.audio_vae import (
AudioDecoder,
AudioEncoder,
AudioLatentShape,
AudioPatchifier,
PerChannelStatistics,
Vocoder,
decode_audio,
)
# Conditioning
from mlx_video.models.ltx_2.conditioning import VideoConditionByLatentIndex
# Utilities
from mlx_video.models.ltx_2.utils import (
convert_audio_encoder,
get_model_path,
load_config,
load_safetensors,
save_weights,
)
from mlx_video.models.wan_2 import WanModel, WanModelConfig
__all__ = [
# Models
"LTXModel",
"LTXModelConfig",
# Audio VAE
"AudioDecoder",
"AudioEncoder",
"Vocoder",
"decode_audio",
"AudioPatchifier",
"AudioLatentShape",
"PerChannelStatistics",
# Conditioning
"VideoConditionByLatentIndex",
# Utilities
"convert_audio_encoder",
"get_model_path",
"load_safetensors",
"load_config",
"save_weights",
# Wan Models
"WanModel",
"WanModelConfig",
"load_transformer_weights",
"load_vae_weights",
]
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"


@@ -0,0 +1,3 @@
from .smart_turn import Model, ModelConfig
__all__ = ["Model", "ModelConfig"]


@@ -1,3 +0,0 @@
"""Conditioning modules for LTX-2 video generation."""
from mlx_video.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning


@@ -1,688 +0,0 @@
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
from mlx_video.models.ltx.ltx import LTXModel
def get_model_path(
    path_or_hf_repo: str,
    revision: Optional[str] = None,
) -> Path:
    """Get local path to model, downloading if necessary.
    Args:
        path_or_hf_repo: Local path or HuggingFace repo ID
        revision: Git revision for HF repo
    Returns:
        Path to model directory
    """
    model_path = Path(path_or_hf_repo)
    if model_path.exists():
        return model_path
    # Download from HuggingFace
    model_path = Path(
        snapshot_download(
            repo_id=path_or_hf_repo,
            revision=revision,
            allow_patterns=[
                "*.safetensors",
                "*.json",
                "config.json",
            ],
        )
    )
    return model_path
def load_safetensors(path: Path) -> Dict[str, mx.array]:
    """Load weights from safetensors file(s) using MLX.
    Args:
        path: Path to model directory or single safetensors file
    Returns:
        Dictionary of weights
    """
    weights = {}
    if path.is_file():
        # Single file - use mx.load directly (handles bfloat16)
        return mx.load(str(path))
    else:
        # Directory - load all safetensors files
        safetensor_files = list(path.glob("*.safetensors"))
        for sf_path in safetensor_files:
            file_weights = mx.load(str(sf_path))
            weights.update(file_weights)
    return weights
def load_transformer_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load transformer weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of transformer weights
"""
# Try distilled model first, then dev
weight_files = [
model_path / "ltx-2-19b-distilled.safetensors",
model_path / "ltx-2-19b-dev.safetensors",
]
for weight_file in weight_files:
if weight_file.exists():
print(f"Loading transformer weights from {weight_file.name}...")
return mx.load(str(weight_file))
raise FileNotFoundError(f"No transformer weights found in {model_path}")
def load_vae_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load VAE weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of VAE weights
"""
vae_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
if vae_path.exists():
print(f"Loading VAE weights from {vae_path}...")
return mx.load(str(vae_path))
raise FileNotFoundError(f"VAE weights not found at {vae_path}")
def load_audio_vae_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load audio VAE weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of audio VAE weights
"""
# Try different possible paths for audio VAE weights
audio_vae_paths = [
model_path / "audio_vae" / "diffusion_pytorch_model.safetensors",
model_path / "audio_vae.safetensors",
]
# Also check in main model weights
main_paths = [
model_path / "ltx-2-19b-distilled.safetensors",
model_path / "ltx-2-19b-dev.safetensors",
]
for audio_path in audio_vae_paths:
if audio_path.exists():
print(f"Loading audio VAE weights from {audio_path}...")
return mx.load(str(audio_path))
# Check main model weights for audio_vae keys
for main_path in main_paths:
if main_path.exists():
print(f"Loading audio VAE weights from {main_path.name}...")
all_weights = mx.load(str(main_path))
# Filter to only audio_vae keys
audio_weights = {k: v for k, v in all_weights.items() if "audio_vae" in k}
if audio_weights:
return audio_weights
raise FileNotFoundError(f"Audio VAE weights not found in {model_path}")
def load_vocoder_weights(model_path: Path) -> Dict[str, mx.array]:
"""Load vocoder weights from LTX-2 model.
Args:
model_path: Path to LTX-2 model directory
Returns:
Dictionary of vocoder weights
"""
# Try different possible paths for vocoder weights
vocoder_paths = [
model_path / "vocoder" / "diffusion_pytorch_model.safetensors",
model_path / "vocoder.safetensors",
]
# Also check in main model weights
main_paths = [
model_path / "ltx-2-19b-distilled.safetensors",
model_path / "ltx-2-19b-dev.safetensors",
]
for vocoder_path in vocoder_paths:
if vocoder_path.exists():
print(f"Loading vocoder weights from {vocoder_path}...")
return mx.load(str(vocoder_path))
# Check main model weights for vocoder keys
for main_path in main_paths:
if main_path.exists():
print(f"Loading vocoder weights from {main_path.name}...")
all_weights = mx.load(str(main_path))
# Filter to only vocoder keys
vocoder_weights = {k: v for k, v in all_weights.items() if "vocoder" in k}
if vocoder_weights:
return vocoder_weights
raise FileNotFoundError(f"Vocoder weights not found in {model_path}")
def sanitize_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize transformer weight names from PyTorch LTX-2 format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for transformer
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
if not key.startswith("model.diffusion_model."):
continue
# Remove 'model.diffusion_model.' prefix
new_key = key.replace("model.diffusion_model.", "")
# Handle to_out.0 -> to_out (MLX doesn't use Sequential numbering)
new_key = new_key.replace(".to_out.0.", ".to_out.")
# Handle feed-forward net naming
# PyTorch: ff.net.0.proj -> ff.net_0_proj (or similar)
# MLX FeedForward: uses proj_in, proj_out
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
# Handle AdaLN naming - keep emb wrapper, just fix linear naming
# PyTorch: adaln_single.emb.timestep_embedder.linear_1 -> adaln_single.emb.timestep_embedder.linear1
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle caption projection (keep linear1/linear2 naming for compatibility)
# These are already mapped correctly in the sanitization
sanitized[new_key] = value
return sanitized
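The renaming rules in `sanitize_transformer_weights` can be exercised on plain strings, without loading any weights. The following is a minimal sketch, not the project's code: `rename_key` is a hypothetical helper, the example keys are made up for illustration, and it strips the prefix once rather than calling `str.replace` on it as the function above does.

```python
from typing import List, Optional, Tuple

PREFIX = "model.diffusion_model."

# (old, new) substring pairs, copied from the replace() calls above.
RENAMES: List[Tuple[str, str]] = [
    (".to_out.0.", ".to_out."),
    (".ff.net.0.proj.", ".ff.proj_in."),
    (".ff.net.2.", ".ff.proj_out."),
    (".audio_ff.net.0.proj.", ".audio_ff.proj_in."),
    (".audio_ff.net.2.", ".audio_ff.proj_out."),
    (".linear_1.", ".linear1."),
    (".linear_2.", ".linear2."),
]


def rename_key(key: str) -> Optional[str]:
    """Map a PyTorch-style key to its MLX-style name, or None if skipped."""
    if not key.startswith(PREFIX):
        return None  # VAE, vocoder, and other non-transformer weights
    new_key = key[len(PREFIX):]
    for old, new in RENAMES:
        new_key = new_key.replace(old, new)
    return new_key


# Hypothetical example keys, chosen only to trigger each rule.
for k in [
    "model.diffusion_model.transformer_blocks.0.attn1.to_out.0.weight",
    "model.diffusion_model.transformer_blocks.3.ff.net.0.proj.bias",
    "vae.decoder.conv_in.weight",
]:
    print(k, "->", rename_key(k))
```

Keeping the rules in a table like this makes it easy to see that the `.ff.` patterns cannot match inside `.audio_ff.`, since the character before `ff` there is `_`, not `.`.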
def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize VAE weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for VAE decoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip position_ids (not needed)
if "position_ids" in key:
continue
# Only process VAE decoder weights (skip audio_vae, etc.)
if not key.startswith("vae."):
continue
# Handle per-channel statistics key mapping
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
continue
elif key.startswith("vae.decoder."):
# Strip the vae.decoder. prefix for decoder weights
new_key = key.replace("vae.decoder.", "")
else:
# Skip other vae.* keys that are not decoder weights
continue
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
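The axis permutations used for the convolution weights above can be checked on shapes alone. This is a small sketch under the assumption that `mx.transpose(value, perm)` follows the usual convention where output axis `i` takes input axis `perm[i]`. The helper `permuted_shape` and the example shapes are hypothetical.

```python
from typing import Sequence, Tuple

CONV3D_PERM = (0, 2, 3, 4, 1)  # (O, I, D, H, W) -> (O, D, H, W, I)
CONV2D_PERM = (0, 2, 3, 1)     # (O, I, H, W)    -> (O, H, W, I)


def permuted_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Return the shape that results from transposing `shape` with `perm`."""
    if sorted(perm) != list(range(len(shape))):
        raise ValueError("perm must be a permutation of the axes of shape")
    return tuple(shape[axis] for axis in perm)


# Example: a 3x3x3 Conv3d kernel with 64 input and 128 output channels.
print(permuted_shape((128, 64, 3, 3, 3), CONV3D_PERM))  # channels-last layout
print(permuted_shape((128, 64, 3, 3), CONV2D_PERM))
```

In both cases the input-channel axis moves from position 1 to the last position, which is the channels-last layout described in the comments above.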
def sanitize_vae_encoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize VAE encoder weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for VAE encoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip position_ids (not needed)
if "position_ids" in key:
continue
# Only process VAE encoder weights
if not key.startswith("vae."):
continue
# Handle per-channel statistics key mapping
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics._mean_of_means"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics._std_of_means"
else:
# Skip other per_channel_statistics keys
continue
elif key.startswith("vae.encoder."):
# Strip the vae.encoder. prefix for encoder weights
new_key = key.replace("vae.encoder.", "")
else:
# Skip other vae.* keys that are not encoder weights
continue
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for audio VAE decoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle audio_vae.decoder weights
if key.startswith("audio_vae.decoder."):
new_key = key.replace("audio_vae.decoder.", "")
elif key.startswith("audio_vae.per_channel_statistics."):
# Map per-channel statistics
if "mean-of-means" in key:
new_key = "per_channel_statistics._mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics._std_of_means"
else:
continue # Skip other statistics keys
else:
continue # Skip non-decoder keys
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def sanitize_vocoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize vocoder weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for vocoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle vocoder weights
if key.startswith("vocoder."):
new_key = key.replace("vocoder.", "")
# Handle ModuleList indices -> dict keys
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
# Handle Conv1d weight shape conversion
# PyTorch: (out_channels, in_channels, kernel)
# MLX: (out_channels, kernel, in_channels)
if "weight" in new_key and value.ndim == 3:
if "ups" in new_key:
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = mx.transpose(value, (1, 2, 0))
else:
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = mx.transpose(value, (0, 2, 1))
sanitized[new_key] = value
return sanitized
def sanitize_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize weight names from PyTorch format to MLX format.
Generic function that handles both transformer and VAE weights.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip position_ids (not needed)
if "position_ids" in key:
continue
# Handle transformer weights
if key.startswith("model.diffusion_model."):
new_key = key.replace("model.diffusion_model.", "")
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def load_config(model_path: Path) -> Dict[str, Any]:
    """Load model configuration.
    Args:
        model_path: Path to model directory
    Returns:
        Configuration dictionary
    """
    config_path = model_path / "config.json"
    if config_path.exists():
        with open(config_path, "r") as f:
            return json.load(f)
    # Return default config
    return {}
def create_model_from_config(config: Dict[str, Any]) -> LTXModel:
"""Create model instance from configuration.
Args:
config: Configuration dictionary
Returns:
LTXModel instance
"""
# Map config to LTXModelConfig
model_config = LTXModelConfig(
model_type=LTXModelType.AudioVideo,
num_attention_heads=config.get("num_attention_heads", 32),
attention_head_dim=config.get("attention_head_dim", 128),
in_channels=config.get("in_channels", 128),
out_channels=config.get("out_channels", 128),
num_layers=config.get("num_layers", 48),
cross_attention_dim=config.get("cross_attention_dim", 4096),
caption_channels=config.get("caption_channels", 3840),
audio_num_attention_heads=config.get("audio_num_attention_heads", 32),
audio_attention_head_dim=config.get("audio_attention_head_dim", 64),
audio_in_channels=config.get("audio_in_channels", 128),
audio_out_channels=config.get("audio_out_channels", 128),
audio_cross_attention_dim=config.get("audio_cross_attention_dim", 2048),
positional_embedding_theta=config.get("positional_embedding_theta", 10000.0),
positional_embedding_max_pos=config.get("positional_embedding_max_pos", [20, 2048, 2048]),
audio_positional_embedding_max_pos=config.get("audio_positional_embedding_max_pos", [20]),
timestep_scale_multiplier=config.get("timestep_scale_multiplier", 1000),
av_ca_timestep_scale_multiplier=config.get("av_ca_timestep_scale_multiplier", 1000),
norm_eps=config.get("norm_eps", 1e-6),
)
return LTXModel(model_config)
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
dtype: Optional[str] = None,
quantize: bool = False,
q_bits: int = 4,
q_group_size: int = 64,
) -> Path:
"""Convert HuggingFace model to MLX format.
Args:
hf_path: HuggingFace model path or repo ID
mlx_path: Output path for MLX model
dtype: Target dtype (float16, float32, bfloat16)
quantize: Whether to quantize the model
q_bits: Quantization bits
q_group_size: Quantization group size
Returns:
Path to converted model
"""
print(f"Loading model from {hf_path}...")
model_path = get_model_path(hf_path)
# Load config
config = load_config(model_path)
# Load weights
print("Loading weights...")
weights = load_safetensors(model_path)
# Sanitize weights
print("Sanitizing weights...")
weights = sanitize_weights(weights)
# Convert dtype if specified
if dtype is not None:
dtype_map = {
"float16": mx.float16,
"float32": mx.float32,
"bfloat16": mx.bfloat16,
}
target_dtype = dtype_map.get(dtype, mx.float16)
print(f"Converting to {dtype}...")
weights = {
k: v.astype(target_dtype) if v.dtype in [mx.float32, mx.float16, mx.bfloat16] else v
for k, v in weights.items()
}
# Create output directory
output_path = Path(mlx_path)
output_path.mkdir(parents=True, exist_ok=True)
# Save weights
print(f"Saving weights to {output_path}...")
save_weights(output_path, weights)
# Save config
config_out_path = output_path / "config.json"
with open(config_out_path, "w") as f:
json.dump(config, f, indent=2)
print(f"Model converted successfully to {output_path}")
return output_path
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
    """Save weights in safetensors format.
    Args:
        path: Output directory
        weights: Dictionary of weights
    """
    from safetensors.numpy import save_file
    import numpy as np
    # Convert to numpy for safetensors
    np_weights = {k: np.array(v) for k, v in weights.items()}
    # Save to file
    save_file(np_weights, path / "model.safetensors")
def load_model(
path_or_hf_repo: str,
lazy: bool = False,
) -> LTXModel:
"""Load LTX model from path or HuggingFace.
Args:
path_or_hf_repo: Path to model or HuggingFace repo ID
lazy: Whether to use lazy loading
Returns:
Loaded LTXModel
"""
model_path = get_model_path(path_or_hf_repo)
# Load config
config = load_config(model_path)
# Create model
model = create_model_from_config(config)
# Load weights
weights = load_safetensors(model_path)
# Sanitize if needed
weights = sanitize_weights(weights)
# Load weights into model
model.load_weights(list(weights.items()))
if not lazy:
mx.eval(model.parameters())
return model
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert LTX-2 model to MLX format")
parser.add_argument(
"--hf-path",
type=str,
default="Lightricks/LTX-2",
help="HuggingFace model path or repo ID",
)
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_model",
help="Output path for MLX model",
)
parser.add_argument(
"--dtype",
type=str,
choices=["float16", "float32", "bfloat16"],
default="float16",
help="Target dtype",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Quantize the model",
)
parser.add_argument(
"--q-bits",
type=int,
default=4,
help="Quantization bits",
)
args = parser.parse_args()
convert(
hf_path=args.hf_path,
mlx_path=args.mlx_path,
dtype=args.dtype,
quantize=args.quantize,
q_bits=args.q_bits,
)


@@ -1,710 +0,0 @@
import argparse
import time
from pathlib import Path
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
# ANSI color codes
class Colors:
    CYAN = "\033[96m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    MAGENTA = "\033[95m"
    BOLD = "\033[1m"
    DIM = "\033[2m"
    RESET = "\033[0m"
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder_weights
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state
from mlx_video.utils import get_model_path
# Distilled sigma schedules
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
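The two schedules above line up with the step counts quoted in the README (8 steps for stage 1, 3 for stage 2) if each consecutive pair of sigmas is treated as one denoising step. That reading is an assumption here, since the body of the sampling loop is not shown. The helpers `num_steps` and `step_sizes` are hypothetical.

```python
from typing import List

STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]


def num_steps(sigmas: List[float]) -> int:
    """Number of steps, assuming one step per consecutive sigma pair."""
    return len(sigmas) - 1


def step_sizes(sigmas: List[float]) -> List[float]:
    """Amount by which sigma drops on each step."""
    return [current - nxt for current, nxt in zip(sigmas, sigmas[1:])]


print(num_steps(STAGE_1_SIGMAS), num_steps(STAGE_2_SIGMAS))
# Stage 2 reuses the last four values of the stage-1 schedule.
print(STAGE_2_SIGMAS == STAGE_1_SIGMAS[-len(STAGE_2_SIGMAS):])
```

Under this reading, stage 2 starts from a partially noised latent at sigma 0.909375 rather than from pure noise at sigma 1.0.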
def create_position_grid(
batch_size: int,
num_frames: int,
height: int,
width: int,
temporal_scale: int = 8,
spatial_scale: int = 32,
fps: float = 24.0,
causal_fix: bool = True,
) -> mx.array:
"""Create position grid for RoPE in pixel space.
Args:
batch_size: Batch size
num_frames: Number of frames (latent)
height: Height (latent)
width: Width (latent)
temporal_scale: VAE temporal scale factor (default 8)
spatial_scale: VAE spatial scale factor (default 32)
fps: Frames per second (default 24.0)
causal_fix: Apply causal fix for first frame (default True)
Returns:
Position grid of shape (B, 3, num_patches, 2) in pixel space
where dim 2 is [start, end) bounds for each patch
"""
# Patch size is (1, 1, 1) for LTX-2 - no spatial patching
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
# Generate grid coordinates for each dimension (frame, height, width)
t_coords = np.arange(0, num_frames, patch_size_t)
h_coords = np.arange(0, height, patch_size_h)
w_coords = np.arange(0, width, patch_size_w)
# Create meshgrid with indexing='ij' for (frame, height, width) order
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
# Stack to get shape (3, grid_t, grid_h, grid_w)
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
# Calculate end coordinates (start + patch_size)
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
patch_ends = patch_starts + patch_size_delta
# Stack start and end: shape (3, grid_t, grid_h, grid_w, 2)
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
# Flatten spatial/temporal dims: (3, num_patches, 2)
num_patches = num_frames * height * width
latent_coords = latent_coords.reshape(3, num_patches, 2)
# Broadcast to batch: (batch, 3, num_patches, 2)
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
# Convert latent coords to pixel coords by scaling with VAE factors
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
# Apply causal fix for first frame temporal axis
if causal_fix:
# VAE temporal stride for first frame is 1 instead of temporal_scale
pixel_coords[:, 0, :, :] = np.clip(
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
a_min=0,
a_max=None
)
# Convert temporal to time in seconds by dividing by fps
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
# Always return float32 for RoPE precision - bfloat16 causes quality degradation
return mx.array(pixel_coords, dtype=mx.float32)
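# Illustrative sketch (NumPy only, hypothetical helper name): the temporal
# [start, end) bounds that the causal fix above produces for the first few
# latent frames, using the same defaults (temporal_scale=8, fps=24).
import numpy as np


def _example_temporal_bounds(num_latent_frames, temporal_scale=8, fps=24.0):
    """Return (starts, ends) in seconds for each latent frame index."""
    idx = np.arange(num_latent_frames, dtype=np.float32)
    starts = idx * temporal_scale
    ends = (idx + 1) * temporal_scale
    # Same shift-and-clip as the causal fix: the first latent frame covers a
    # single pixel frame instead of `temporal_scale` frames.
    starts = np.clip(starts + 1 - temporal_scale, 0, None)
    ends = np.clip(ends + 1 - temporal_scale, 0, None)
    return starts / fps, ends / fps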
def denoise(
latents: mx.array,
positions: mx.array,
text_embeddings: mx.array,
transformer: LTXModel,
sigmas: list,
verbose: bool = True,
state: Optional[LatentState] = None,
) -> mx.array:
"""Run denoising loop with optional conditioning.
Args:
latents: Noisy latent tensor (B, C, F, H, W)
positions: Position embeddings
text_embeddings: Text conditioning embeddings
transformer: LTX model
sigmas: List of sigma values for denoising schedule
verbose: Whether to show progress bar
state: Optional LatentState for I2V conditioning
Returns:
Denoised latent tensor
"""
# If state is provided, use its latent (which may have conditioning applied)
dtype = latents.dtype
if state is not None:
latents = state.latent
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
num_tokens = f * h * w
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
# Compute per-token timesteps
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
if state is not None:
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
# Per-token timesteps: sigma * mask (preserve dtype)
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
# All tokens get the same timestep (use latent dtype)
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
video_modality = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=text_embeddings,
context_mask=None,
enabled=True,
)
velocity, _ = transformer(video=video_modality, audio=None)
mx.eval(velocity)
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
# Apply conditioning mask if state is provided
if state is not None:
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
mx.eval(denoised)
# Euler step (preserve dtype by converting Python floats to arrays)
if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
else:
latents = denoised
mx.eval(latents)
return latents
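# Illustrative sketch (NumPy only, hypothetical helper name): the Euler update
# used in the loop above, driven by an idealized denoiser that always returns
# the clean target. Under that assumption the remaining noise after each step
# is scaled by exactly sigma_next / sigma, and a final sigma of 0 returns the
# clean latent. The starting point (clean + sigmas[0] * noise) is chosen for
# illustration only and is not the pipeline's own initialization.
import numpy as np


def _example_euler_with_oracle(clean, noise, sigmas):
    """Run the Euler schedule with a perfect denoiser; return the final latent."""
    latents = clean + sigmas[0] * noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = clean  # stand-in for to_denoised(latents, velocity, sigma)
        if sigma_next > 0:
            latents = denoised + sigma_next * (latents - denoised) / sigma
        else:
            latents = denoised
    return latents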
def generate_video(
model_repo: str,
text_encoder_repo: str,
prompt: str,
height: int = 512,
width: int = 512,
num_frames: int = 33,
seed: int = 42,
fps: int = 24,
output_path: str = "output.mp4",
save_frames: bool = False,
verbose: bool = True,
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
image: Optional[str] = None,
image_strength: float = 1.0,
image_frame_idx: int = 0,
tiling: str = "auto",
stream: bool = False,
):
"""Generate video from text prompt, optionally conditioned on an image.
Args:
model_repo: Model repository ID
text_encoder_repo: Text encoder repository ID
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
seed: Random seed for reproducibility
fps: Frames per second for output video
output_path: Path to save the output video
save_frames: Whether to save individual frames as images
verbose: Whether to print progress
enhance_prompt: Whether to enhance prompt using Gemma
max_tokens: Max tokens for prompt enhancement
temperature: Temperature for prompt enhancement
image: Path to conditioning image for I2V (Image-to-Video)
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
image_frame_idx: Frame index to condition (0 = first frame)
tiling: Tiling mode for VAE decoding. Options:
- "auto": Automatically determine based on video size (default)
- "none": Disable tiling
- "default": 512px spatial, 64 frame temporal
- "aggressive": 256px spatial, 32 frame temporal (lowest memory)
- "conservative": 768px spatial, 96 frame temporal (faster)
- "spatial": Spatial tiling only
- "temporal": Temporal tiling only
"""
start_time = time.time()
# Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
num_frames = adjusted_num_frames
is_i2v = image is not None
mode_str = "I2V" if is_i2v else "T2V"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
# Get model path
model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8
mx.random.seed(seed)
# Load text encoder
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder()
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters())
# Optionally enhance the prompt
if enhance_prompt:
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
model_dtype = text_embeddings.dtype # bfloat16 from text encoder
mx.eval(text_embeddings)
del text_encoder
mx.clear_cache()
# Load transformer
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
# Load VAE encoder and encode image for I2V conditioning
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
del vae_encoder
mx.clear_cache()
# Stage 1: Generate at half resolution
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
# Apply I2V conditioning if provided
state1 = None
if is_i2v and stage1_image_latent is not None:
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
# Create initial state with zeros
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
state1 = LatentState(
latent=mx.zeros(latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
state1 = apply_conditioning(state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 1, noise_scale = 1.0 (first sigma)
noise = mx.random.normal(latent_shape, dtype=model_dtype)
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
scaled_mask = state1.denoise_mask * noise_scale
state1 = LatentState(
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state1.clean_latent,
denoise_mask=state1.denoise_mask,
)
latents = state1.latent
mx.eval(latents)
else:
# T2V: just use random noise
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
# Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'),
timestep_conditioning=None # Auto-detect from model metadata
)
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(latents)
del upsampler
mx.clear_cache()
# Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
# Apply I2V conditioning for stage 2 if provided
state2 = None
if is_i2v and stage2_image_latent is not None:
# PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser
state2 = LatentState(
latent=latents, # Start with upscaled latent
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
state2 = apply_conditioning(state2, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 2, noise_scale = stage_2_sigmas[0]
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
noise = mx.random.normal(latents.shape).astype(model_dtype)
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = state2.denoise_mask * noise_scale
state2 = LatentState(
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state2.clean_latent,
denoise_mask=state2.denoise_mask,
)
latents = state2.latent
mx.eval(latents)
else:
# T2V: add noise to all frames for refinement
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
noise = mx.random.normal(latents.shape).astype(model_dtype)
latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
del transformer
mx.clear_cache()
# Decode to video with tiling
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
# Select tiling configuration
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Stream mode: write frames as they're decoded
video_writer = None
stream_pbar = None
if stream and tiling_config is not None:
import cv2
fourcc = cv2.VideoWriter_fourcc(*'avc1')
video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))
stream_pbar = tqdm(total=num_frames, desc="Streaming", unit="frame")
def on_frames_ready(frames: mx.array, start_idx: int):
"""Callback to write frames as they're finalized."""
# frames: (B, 3, num_frames, H, W)
frames = mx.squeeze(frames, axis=0) # (3, num_frames, H, W)
frames = mx.transpose(frames, (1, 2, 3, 0)) # (num_frames, H, W, 3)
frames = mx.clip((frames + 1.0) / 2.0, 0.0, 1.0)
frames = (frames * 255).astype(mx.uint8)
frames_np = np.array(frames)
for frame in frames_np:
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
stream_pbar.update(1)
else:
on_frames_ready = None
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, tiling_mode=tiling, debug=verbose, on_frames_ready=on_frames_ready)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(latents)
mx.eval(video)
mx.clear_cache()
# Close progressive video writer if used
if video_writer is not None:
video_writer.release()
if stream_pbar is not None:
stream_pbar.close()
print(f"{Colors.GREEN}✅ Streamed video to{Colors.RESET} {output_path}")
# Still need video_np for save_frames option
video = mx.squeeze(video, axis=0)
video = mx.transpose(video, (1, 2, 3, 0))
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
else:
# Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
# Save video normally
try:
import cv2
h, w = video_np.shape[1], video_np.shape[2]
fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"{Colors.GREEN}✅ Saved {len(video_np)} frames to {frames_dir}{Colors.RESET}")
elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np
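# Illustrative sketch (hypothetical helper name): the frame-count rounding and
# latent-size arithmetic used by generate_video above, pulled out so the
# numbers can be checked in isolation. For example, 100 requested frames snap
# to 97, which corresponds to 13 latent frames.
def _example_latent_dims(height, width, num_frames):
    """Return (num_frames, latent_frames, stage1_hw, stage2_hw)."""
    if num_frames % 8 != 1:
        # Snap to the nearest value of the form 1 + 8*k.
        num_frames = round((num_frames - 1) / 8) * 8 + 1
    latent_frames = 1 + (num_frames - 1) // 8
    stage1_hw = (height // 2 // 32, width // 2 // 32)  # half-resolution stage
    stage2_hw = (height // 32, width // 32)  # full-resolution stage
    return num_frames, latent_frames, stage1_hw, stage2_hw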
def main():
parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2 (T2V and I2V)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Text-to-Video (T2V)
python -m mlx_video.generate --prompt "A cat walking on grass"
python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768
python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output-path my_video.mp4
# Image-to-Video (I2V)
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8
"""
)
parser.add_argument(
"--prompt", "-p",
type=str,
required=True,
help="Text description of the video to generate"
)
parser.add_argument(
"--height", "-H",
type=int,
default=512,
help="Output video height (default: 512, must be divisible by 64)"
)
parser.add_argument(
"--width", "-W",
type=int,
default=512,
help="Output video width (default: 512, must be divisible by 64)"
)
parser.add_argument(
"--num-frames", "-n",
type=int,
default=100,
help="Number of frames (default: 100; values not of the form 1 + 8*k are rounded to the nearest valid one)"
)
parser.add_argument(
"--seed", "-s",
type=int,
default=42,
help="Random seed for reproducibility (default: 42)"
)
parser.add_argument(
"--fps",
type=int,
default=24,
help="Frames per second for output video (default: 24)"
)
parser.add_argument(
"--output-path",
type=str,
default="output.mp4",
help="Output video path (default: output.mp4)"
)
parser.add_argument(
"--save-frames",
action="store_true",
help="Save individual frames as images"
)
parser.add_argument(
"--model-repo",
type=str,
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
parser.add_argument(
"--text-encoder-repo",
type=str,
default=None,
help="Text encoder repository to use (default: None, which reuses --model-repo)"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Verbose output"
)
parser.add_argument(
"--enhance-prompt",
action="store_true",
help="Enhance the prompt using Gemma before generation"
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="Maximum number of tokens to generate during prompt enhancement (default: 512)"
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Temperature for prompt enhancement (default: 0.7)"
)
parser.add_argument(
"--image", "-i",
type=str,
default=None,
help="Path to conditioning image for I2V (Image-to-Video) generation"
)
parser.add_argument(
"--image-strength",
type=float,
default=1.0,
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)"
)
parser.add_argument(
"--image-frame-idx",
type=int,
default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)"
)
parser.add_argument(
"--tiling",
type=str,
default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="Tiling mode for VAE decoding (default: auto). "
"auto=based on video size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
)
parser.add_argument(
"--stream",
action="store_true",
help="Stream frames to output file as they're decoded (requires tiling). Allows viewing partial results sooner."
)
args = parser.parse_args()
generate_video(
**vars(args)
)
if __name__ == "__main__":
main()

@@ -1,821 +0,0 @@
"""Audio-Video generation pipeline for LTX-2."""
import argparse
import time
from pathlib import Path
from typing import Optional
import mlx.core as mx
import numpy as np
from tqdm import tqdm
# ANSI color codes
class Colors:
    CYAN = "\033[96m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    MAGENTA = "\033[95m"
    BOLD = "\033[1m"
    DIM = "\033[2m"
    RESET = "\033[0m"
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
# Distilled sigma schedules
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
# Audio constants
AUDIO_SAMPLE_RATE = 24000 # Output audio sample rate
AUDIO_LATENT_SAMPLE_RATE = 16000 # VAE internal sample rate
AUDIO_HOP_LENGTH = 160
AUDIO_LATENT_DOWNSAMPLE_FACTOR = 4
AUDIO_LATENT_CHANNELS = 8 # Latent channels before patchifying
AUDIO_MEL_BINS = 16
AUDIO_LATENTS_PER_SECOND = AUDIO_LATENT_SAMPLE_RATE / AUDIO_HOP_LENGTH / AUDIO_LATENT_DOWNSAMPLE_FACTOR # 25
def create_video_position_grid(
batch_size: int,
num_frames: int,
height: int,
width: int,
temporal_scale: int = 8,
spatial_scale: int = 32,
fps: float = 24.0,
causal_fix: bool = True,
) -> mx.array:
"""Create position grid for video RoPE in pixel space."""
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
t_coords = np.arange(0, num_frames, patch_size_t)
h_coords = np.arange(0, height, patch_size_h)
w_coords = np.arange(0, width, patch_size_w)
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
patch_ends = patch_starts + patch_size_delta
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
num_patches = num_frames * height * width
latent_coords = latent_coords.reshape(3, num_patches, 2)
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
if causal_fix:
pixel_coords[:, 0, :, :] = np.clip(
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
a_min=0,
a_max=None
)
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
return mx.array(pixel_coords, dtype=mx.float32)
def create_audio_position_grid(
    batch_size: int,
    audio_frames: int,
    sample_rate: int = AUDIO_LATENT_SAMPLE_RATE,
    hop_length: int = AUDIO_HOP_LENGTH,
    downsample_factor: int = AUDIO_LATENT_DOWNSAMPLE_FACTOR,
    is_causal: bool = True,
) -> mx.array:
    """Create temporal position grid for audio RoPE.
    Audio positions are timestamps in seconds, shape (B, 1, T, 2).
    Matches PyTorch's AudioPatchifier.get_patch_grid_bounds exactly.
    """
    def get_audio_latent_time_in_sec(start_idx: int, end_idx: int) -> np.ndarray:
        """Convert latent indices to seconds (matching PyTorch's _get_audio_latent_time_in_sec)."""
        latent_frame = np.arange(start_idx, end_idx, dtype=np.float32)
        mel_frame = latent_frame * downsample_factor
        if is_causal:
            # Frame offset for causal alignment (PyTorch uses +1 - downsample_factor)
            mel_frame = np.clip(mel_frame + 1 - downsample_factor, 0, None)
        return mel_frame * hop_length / sample_rate
    # Start times: latent indices 0 to audio_frames
    start_times = get_audio_latent_time_in_sec(0, audio_frames)
    # End times: latent indices 1 to audio_frames+1 (shifted by 1)
    end_times = get_audio_latent_time_in_sec(1, audio_frames + 1)
    # Shape: (B, 1, T, 2)
    positions = np.stack([start_times, end_times], axis=-1)
    positions = positions[np.newaxis, np.newaxis, :, :]  # (1, 1, T, 2)
    positions = np.tile(positions, (batch_size, 1, 1, 1))
    return mx.array(positions, dtype=mx.float32)
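# Illustrative sketch (NumPy only, hypothetical helper name): the start times
# that the causal mapping above assigns to the first latent frames with the
# module defaults (16 kHz, hop length 160, downsample factor 4). Each mel
# frame then spans 160 / 16000 = 0.01 s.
import numpy as np


def _example_audio_start_times(num_latents, sample_rate=16000, hop_length=160, downsample_factor=4):
    """Return the start time in seconds of each audio latent frame."""
    latent_frame = np.arange(num_latents, dtype=np.float32)
    # Shift-and-clip so the first latent frame starts at mel frame 0.
    mel_frame = np.clip(latent_frame * downsample_factor + 1 - downsample_factor, 0, None)
    return mel_frame * hop_length / sample_rate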
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
    """Compute number of audio latent frames given video duration."""
    duration = num_video_frames / fps
    return round(duration * AUDIO_LATENTS_PER_SECOND)
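# Illustrative sketch (hypothetical helper name): the same arithmetic as
# compute_audio_frames, written out with the rate implied by the constants at
# the top of this file (16000 / 160 / 4 = 25 audio latents per second).
def _example_audio_frames(num_video_frames, fps=24.0, latents_per_second=25.0):
    """Return the number of audio latent frames for a clip of this length."""
    duration = num_video_frames / fps  # clip length in seconds
    return round(duration * latents_per_second)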
def denoise_av(
video_latents: mx.array,
audio_latents: mx.array,
video_positions: mx.array,
audio_positions: mx.array,
video_embeddings: mx.array,
audio_embeddings: mx.array,
transformer: LTXModel,
sigmas: list,
verbose: bool = True,
video_state: Optional[LatentState] = None,
) -> tuple[mx.array, mx.array]:
"""Run denoising loop for audio-video generation with optional I2V conditioning.
Args:
video_latents: Video latent tensor (B, C, F, H, W)
audio_latents: Audio latent tensor (B, C, T, F)
video_positions: Video position embeddings
audio_positions: Audio position embeddings
video_embeddings: Video text embeddings
audio_embeddings: Audio text embeddings
transformer: LTX model
sigmas: List of sigma values
verbose: Whether to show progress bar
video_state: Optional LatentState for I2V conditioning
Returns:
Tuple of (video_latents, audio_latents)
"""
dtype = video_latents.dtype
# If video state is provided, use its latent
if video_state is not None:
video_latents = video_state.latent
for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
# Flatten video latents
b, c, f, h, w = video_latents.shape
num_video_tokens = f * h * w
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
ab, ac, at, af = audio_latents.shape
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
# Compute per-token timesteps for video
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
if video_state is not None:
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
# Per-token timesteps: sigma * mask
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
# All tokens get the same timestep
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
video_modality = Modality(
latent=video_flat,
timesteps=video_timesteps,
positions=video_positions,
context=video_embeddings,
context_mask=None,
enabled=True,
)
audio_modality = Modality(
latent=audio_flat,
timesteps=mx.full((ab, at), sigma, dtype=dtype),
positions=audio_positions,
context=audio_embeddings,
context_mask=None,
enabled=True,
)
video_velocity, audio_velocity = transformer(video=video_modality, audio=audio_modality)
mx.eval(video_velocity, audio_velocity)
# Reshape velocities back
video_velocity = mx.reshape(mx.transpose(video_velocity, (0, 2, 1)), (b, c, f, h, w))
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) # (B, C, T, F)
# Compute denoised
video_denoised = to_denoised(video_latents, video_velocity, sigma)
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
# Apply conditioning mask for video if state is provided
if video_state is not None:
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
mx.eval(video_denoised, audio_denoised)
# Euler step - use dtype-preserving arrays to avoid float32 promotion
if sigma_next > 0:
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype)
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
else:
video_latents = video_denoised
audio_latents = audio_denoised
mx.eval(video_latents, audio_latents)
return video_latents, audio_latents
def load_audio_decoder(model_path: Path):
"""Load audio VAE decoder."""
from mlx_video.models.ltx.audio_vae import AudioDecoder, CausalityAxis, NormType
decoder = AudioDecoder(
ch=128,
out_ch=2, # stereo
ch_mult=(1, 2, 4),
num_res_blocks=2,
attn_resolutions={8, 16, 32},
resolution=256,
z_channels=AUDIO_LATENT_CHANNELS,
norm_type=NormType.PIXEL,
causality_axis=CausalityAxis.HEIGHT,
mel_bins=64, # Output mel bins
)
# Load weights from main model file
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_audio_vae_weights(raw_weights)
if sanitized:
decoder.load_weights(list(sanitized.items()), strict=False)
# Manually load per-channel statistics (they're plain mx.array, not tracked by load_weights)
if "per_channel_statistics._mean_of_means" in sanitized:
decoder.per_channel_statistics._mean_of_means = sanitized["per_channel_statistics._mean_of_means"]
if "per_channel_statistics._std_of_means" in sanitized:
decoder.per_channel_statistics._std_of_means = sanitized["per_channel_statistics._std_of_means"]
return decoder
def load_vocoder(model_path: Path):
"""Load vocoder for mel to waveform conversion."""
from mlx_video.models.ltx.audio_vae import Vocoder
vocoder = Vocoder(
resblock_kernel_sizes=[3, 7, 11],
upsample_rates=[6, 5, 2, 2, 2],
upsample_kernel_sizes=[16, 15, 8, 4, 4],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_initial_channel=1024,
stereo=True,
output_sample_rate=AUDIO_SAMPLE_RATE,
)
# Load weights
weight_file = model_path / "ltx-2-19b-distilled.safetensors"
if weight_file.exists():
raw_weights = mx.load(str(weight_file))
sanitized = sanitize_vocoder_weights(raw_weights)
if sanitized:
vocoder.load_weights(list(sanitized.items()), strict=False)
return vocoder
def save_audio(audio: np.ndarray, path: Path, sample_rate: int = AUDIO_SAMPLE_RATE):
"""Save audio to WAV file."""
import wave
# Ensure audio is in correct format (channels, samples) or (samples,)
if audio.ndim == 2:
# (channels, samples) -> (samples, channels)
audio = audio.T
# Normalize and convert to int16
audio = np.clip(audio, -1.0, 1.0)
audio_int16 = (audio * 32767).astype(np.int16)
with wave.open(str(path), 'wb') as wf:
wf.setnchannels(audio_int16.shape[1] if audio_int16.ndim == 2 else 1)  # channel count comes from the (samples, channels) array
wf.setsampwidth(2) # 16-bit
wf.setframerate(sample_rate)
wf.writeframes(audio_int16.tobytes())
def mux_video_audio(video_path: Path, audio_path: Path, output_path: Path):
"""Combine video and audio into final output using ffmpeg."""
import subprocess
cmd = [
"ffmpeg", "-y",
"-i", str(video_path),
"-i", str(audio_path),
"-c:v", "copy",
"-c:a", "aac",
"-shortest",
str(output_path)
]
try:
subprocess.run(cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print(f"{Colors.RED}FFmpeg error: {e.stderr.decode()}{Colors.RESET}")
return False
except FileNotFoundError:
print(f"{Colors.RED}FFmpeg not found. Please install ffmpeg.{Colors.RESET}")
return False
def generate_video_with_audio(
model_repo: str,
text_encoder_repo: Optional[str],
prompt: str,
height: int = 512,
width: int = 512,
num_frames: int = 33,
seed: int = 42,
fps: int = 24,
output_path: str = "output_av.mp4",
output_audio_path: Optional[str] = None,
verbose: bool = True,
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
image: Optional[str] = None,
image_strength: float = 1.0,
image_frame_idx: int = 0,
tiling: str = "auto",
):
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image.
Args:
model_repo: Model repository ID
text_encoder_repo: Text encoder repository ID
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
num_frames: Number of frames
seed: Random seed
fps: Frames per second
output_path: Output video path
output_audio_path: Output audio path
verbose: Whether to print progress
enhance_prompt: Whether to enhance prompt using Gemma
max_tokens: Max tokens for prompt enhancement
temperature: Temperature for prompt enhancement
image: Path to conditioning image for I2V
image_strength: Conditioning strength (1.0 = conditioned frame keeps the image latent)
image_frame_idx: Frame index to condition (0 = first frame)
tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal)
"""
start_time = time.time()
# Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Adjusted frames to {adjusted_num_frames}{Colors.RESET}")
num_frames = adjusted_num_frames
# Calculate audio frames
audio_frames = compute_audio_frames(num_frames, fps)
is_i2v = image is not None
mode_str = "I2V+Audio" if is_i2v else "T2V+Audio"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8
mx.random.seed(seed)
# Load text encoder with audio embeddings
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder()
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters())
# Optionally enhance prompt
if enhance_prompt:
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
# Get both video and audio embeddings
video_embeddings, audio_embeddings = text_encoder(prompt)
model_dtype = video_embeddings.dtype # bfloat16 from text encoder
mx.eval(video_embeddings, audio_embeddings)
del text_encoder
mx.clear_cache()
# Load transformer with AudioVideo config
print(f"{Colors.BLUE}🤖 Loading transformer (A/V mode)...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig(
model_type=LTXModelType.AudioVideo,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
# Audio config
audio_num_attention_heads=32,
audio_attention_head_dim=64,
audio_in_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS, # 8 * 16 = 128
audio_out_channels=AUDIO_LATENT_CHANNELS * AUDIO_MEL_BINS,
audio_cross_attention_dim=2048,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
audio_positional_embedding_max_pos=[20],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
# Load VAE encoder and encode image for I2V conditioning
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
# Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
del vae_encoder
mx.clear_cache()
# Initialize latents
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
# Create position grids - MUST stay float32 for RoPE precision
# bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32
audio_positions = create_audio_position_grid(1, audio_frames) # float32
mx.eval(video_positions, audio_positions)
# Apply I2V conditioning for stage 1 if provided
video_state1 = None
video_latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
if is_i2v and stage1_image_latent is not None:
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
video_state1 = LatentState(
latent=mx.zeros(video_latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
video_state1 = apply_conditioning(video_state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
noise = mx.random.normal(video_latent_shape).astype(model_dtype)
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
scaled_mask = video_state1.denoise_mask * noise_scale
video_state1 = LatentState(
latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=video_state1.clean_latent,
denoise_mask=video_state1.denoise_mask,
)
video_latents = video_state1.latent
mx.eval(video_latents)
else:
# T2V: just use random noise
video_latents = mx.random.normal(video_latent_shape).astype(model_dtype)
mx.eval(video_latents)
# Audio always uses pure noise (no I2V for audio)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_latents)
# Stage 1 denoising
video_latents, audio_latents = denoise_av(
video_latents, audio_latents,
video_positions, audio_positions,
video_embeddings, audio_embeddings,
transformer, STAGE_1_SIGMAS, verbose=verbose,
video_state=video_state1
)
# Upsample video latents
print(f"{Colors.MAGENTA}🔍 Upsampling video latents 2x...{Colors.RESET}")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'),
timestep_conditioning=None # Auto-detect from model metadata
)
video_latents = upsample_latents(video_latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(video_latents)
del upsampler
mx.clear_cache()
# Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
# Position grids stay float32 for RoPE precision
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32
mx.eval(video_positions)
# Apply I2V conditioning for stage 2 if provided
video_state2 = None
if is_i2v and stage2_image_latent is not None:
# PyTorch flow: start with upscaled latent -> apply conditioning -> apply noiser
video_state2 = LatentState(
latent=video_latents, # Start with upscaled latent
clean_latent=mx.zeros_like(video_latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
video_state2 = apply_conditioning(video_state2, [conditioning])
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = video_state2.denoise_mask * noise_scale
video_state2 = LatentState(
latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=video_state2.clean_latent,
denoise_mask=video_state2.denoise_mask,
)
video_latents = video_state2.latent
mx.eval(video_latents)
# Audio still gets noise (no I2V for audio)
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents)
else:
# T2V: add noise to all frames for refinement
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
video_latents = video_noise * noise_scale + video_latents * one_minus_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(video_latents, audio_latents)
video_latents, audio_latents = denoise_av(
video_latents, audio_latents,
video_positions, audio_positions,
video_embeddings, audio_embeddings,
transformer, STAGE_2_SIGMAS, verbose=verbose,
video_state=video_state2
)
del transformer
mx.clear_cache()
# Decode video with tiling
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
# Select tiling configuration
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(video_latents)
mx.eval(video)
# Convert video to uint8 frames
video = mx.squeeze(video, axis=0)
video = mx.transpose(video, (1, 2, 3, 0))
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
# Decode audio
print(f"{Colors.BLUE}🔊 Decoding audio...{Colors.RESET}")
audio_decoder = load_audio_decoder(model_path)
vocoder = load_vocoder(model_path)
mx.eval(audio_decoder.parameters(), vocoder.parameters())
mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram)
# Audio decoder output is already in vocoder format (B, C, T, F)
audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform)
audio_np = np.array(audio_waveform)
if audio_np.ndim == 3:
audio_np = audio_np[0] # Remove batch dim
del audio_decoder, vocoder, vae_decoder
mx.clear_cache()
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Save video (temporary without audio)
temp_video_path = output_path.with_suffix('.temp.mp4')
try:
import cv2
h, w = video_np.shape[1], video_np.shape[2]
fourcc = cv2.VideoWriter_fourcc(*'avc1')
out = cv2.VideoWriter(str(temp_video_path), fourcc, fps, (w, h))
for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
print(f"{Colors.GREEN}✅ Video encoded{Colors.RESET}")
except Exception as e:
print(f"{Colors.RED}❌ Video encoding failed: {e}{Colors.RESET}")
return None, None
# Save audio
audio_path = output_path.with_suffix('.wav') if output_audio_path is None else Path(output_audio_path)
save_audio(audio_np, audio_path, AUDIO_SAMPLE_RATE)
print(f"{Colors.GREEN}✅ Saved audio to{Colors.RESET} {audio_path}")
# Mux video and audio
print(f"{Colors.BLUE}🎬 Combining video and audio...{Colors.RESET}")
if mux_video_audio(temp_video_path, audio_path, output_path):
print(f"{Colors.GREEN}✅ Saved video with audio to{Colors.RESET} {output_path}")
temp_video_path.unlink() # Remove temp file
else:
# Fallback: keep video without audio
temp_video_path.rename(output_path)
print(f"{Colors.YELLOW}⚠️ Saved video without audio to{Colors.RESET} {output_path}")
elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s{Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np, audio_np
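The two noising steps in `generate_video_with_audio` follow the rule stated in its comments: `latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)`. The sketch below replays that rule on plain Python floats instead of `mx.array` values. `blend_noise` is a hypothetical helper name, not part of mlx_video, and the `noise_scale` of 0.5 is an illustrative number rather than the real first entry of `STAGE_2_SIGMAS`.

```python
def blend_noise(latent: float, noise: float, mask: float, noise_scale: float) -> float:
    """Mix noise into a latent value, gated by a per-frame denoise mask."""
    scaled_mask = mask * noise_scale
    return noise * scaled_mask + latent * (1.0 - scaled_mask)


# mask = 0: a fully conditioned frame keeps its latent untouched.
kept = blend_noise(latent=0.25, noise=-1.5, mask=0.0, noise_scale=1.0)

# mask = 1 with noise_scale = 1: the frame starts from pure noise (stage 1).
pure = blend_noise(latent=0.25, noise=-1.5, mask=1.0, noise_scale=1.0)

# mask = 1 with a smaller noise_scale: partial re-noising before refinement.
partial = blend_noise(latent=0.25, noise=-1.5, mask=1.0, noise_scale=0.5)
```

The same expression covers both T2V and I2V: in T2V the mask is effectively 1 everywhere, so the blend reduces to `noise * noise_scale + latent * (1 - noise_scale)`.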
def main():
parser = argparse.ArgumentParser(
description="Generate videos with synchronized audio using MLX LTX-2 (T2V and I2V)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Text-to-Video with Audio (T2V+Audio)
python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach"
python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt
python -m mlx_video.generate_av --prompt "..." --output-path my_video.mp4 --output-audio my_audio.wav
# Image-to-Video with Audio (I2V+Audio)
python -m mlx_video.generate_av --prompt "A person dancing" --image photo.jpg
python -m mlx_video.generate_av --prompt "Waves crashing" --image beach.png --image-strength 0.8
"""
)
parser.add_argument("--prompt", "-p", type=str, required=True,
help="Text description of the video/audio to generate")
parser.add_argument("--height", "-H", type=int, default=512,
help="Output video height (default: 512)")
parser.add_argument("--width", "-W", type=int, default=512,
help="Output video width (default: 512)")
parser.add_argument("--num-frames", "-n", type=int, default=65,
help="Number of frames (default: 65)")
parser.add_argument("--seed", "-s", type=int, default=42,
help="Random seed (default: 42)")
parser.add_argument("--fps", type=int, default=24,
help="Frames per second (default: 24)")
parser.add_argument("--output-path", type=str, default="output_av.mp4",
help="Output video path (default: output_av.mp4)")
parser.add_argument("--output-audio", type=str, default=None,
help="Output audio path (default: same as video with .wav)")
parser.add_argument("--model-repo", type=str, default="Lightricks/LTX-2",
help="Model repository (default: Lightricks/LTX-2)")
parser.add_argument("--text-encoder-repo", type=str, default=None,
help="Text encoder repository")
parser.add_argument("--verbose", action="store_true",
help="Verbose output")
parser.add_argument("--enhance-prompt", action="store_true",
help="Enhance prompt using Gemma")
parser.add_argument("--max-tokens", type=int, default=512,
help="Max tokens for prompt enhancement")
parser.add_argument("--temperature", type=float, default=0.7,
help="Temperature for prompt enhancement")
parser.add_argument("--image", "-i", type=str, default=None,
help="Path to conditioning image for I2V (Image-to-Video) generation")
parser.add_argument("--image-strength", type=float, default=1.0,
help="Conditioning strength for I2V (1.0 = keep the conditioning image latent, 0.0 = fully denoise the frame, default: 1.0)")
parser.add_argument("--image-frame-idx", type=int, default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)")
parser.add_argument("--tiling", type=str, default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="Tiling mode for VAE decoding (default: auto). "
"auto=based on size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f, "
"spatial=spatial tiling only, temporal=temporal tiling only")
args = parser.parse_args()
generate_video_with_audio(
model_repo=args.model_repo,
text_encoder_repo=args.text_encoder_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output_path,
output_audio_path=args.output_audio,
verbose=args.verbose,
enhance_prompt=args.enhance_prompt,
max_tokens=args.max_tokens,
temperature=args.temperature,
image=args.image,
image_strength=args.image_strength,
image_frame_idx=args.image_frame_idx,
tiling=args.tiling,
)
if __name__ == "__main__":
main()
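The size checks at the top of `generate_video_with_audio` reduce to a few lines of integer arithmetic. The sketch below reproduces them in isolation, assuming only what the function body shows: frame counts of the form 8k + 1, a temporal compression of 8, and a spatial compression of 32 with stage 1 running at half resolution. `snap_num_frames` and `latent_shapes` are hypothetical helper names used for illustration.

```python
def snap_num_frames(num_frames: int) -> int:
    """Snap to the nearest frame count of the form 8 * k + 1."""
    if num_frames % 8 != 1:
        return round((num_frames - 1) / 8) * 8 + 1
    return num_frames


def latent_shapes(height: int, width: int, num_frames: int):
    """Return (latent_frames, stage-1 (h, w), stage-2 (h, w))."""
    latent_frames = 1 + (num_frames - 1) // 8
    stage1 = (height // 2 // 32, width // 2 // 32)  # half-resolution pass
    stage2 = (height // 32, width // 32)            # full-resolution refinement
    return latent_frames, stage1, stage2


frames = snap_num_frames(100)
shapes = latent_shapes(512, 768, frames)
```

Python's `round` uses round-half-to-even, so a request for 13 frames snaps up to 17 while a request for 5 frames snaps down to 1.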


@@ -6,10 +6,7 @@ from mlx_video.lora.apply import (
apply_loras_to_model,
apply_loras_to_weights,
)
from mlx_video.lora.loader import (
load_lora_weights,
load_multiple_loras,
)
from mlx_video.lora.loader import load_lora_weights, load_multiple_loras
from mlx_video.lora.types import AppliedLoRA, LoRAConfig, LoRAWeights
__all__ = [


@@ -66,7 +66,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
candidates = [lora_key]
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
candidates.append(lora_key[len(prefix):])
candidates.append(lora_key[len(prefix) :])
for candidate in candidates:
# Try as-is
@@ -80,33 +80,36 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
transformed = transformed.replace(".ffn.0.", ".ffn.fc1.")
transformed = transformed.replace(".ffn.2.", ".ffn.fc2.")
if transformed.endswith(".ffn.0"):
transformed = transformed[:-len(".ffn.0")] + ".ffn.fc1"
transformed = transformed[: -len(".ffn.0")] + ".ffn.fc1"
if transformed.endswith(".ffn.2"):
transformed = transformed[:-len(".ffn.2")] + ".ffn.fc2"
transformed = transformed[: -len(".ffn.2")] + ".ffn.fc2"
# Text embedding: text_embedding.0 → text_embedding_0
transformed = transformed.replace("text_embedding.0.", "text_embedding_0.")
transformed = transformed.replace("text_embedding.2.", "text_embedding_1.")
if transformed.endswith("text_embedding.0"):
transformed = transformed[:-len("text_embedding.0")] + "text_embedding_0"
transformed = transformed[: -len("text_embedding.0")] + "text_embedding_0"
if transformed.endswith("text_embedding.2"):
transformed = transformed[:-len("text_embedding.2")] + "text_embedding_1"
transformed = transformed[: -len("text_embedding.2")] + "text_embedding_1"
# Time embedding: time_embedding.0 → time_embedding_0
transformed = transformed.replace("time_embedding.0.", "time_embedding_0.")
transformed = transformed.replace("time_embedding.2.", "time_embedding_1.")
if transformed.endswith("time_embedding.0"):
transformed = transformed[:-len("time_embedding.0")] + "time_embedding_0"
transformed = transformed[: -len("time_embedding.0")] + "time_embedding_0"
if transformed.endswith("time_embedding.2"):
transformed = transformed[:-len("time_embedding.2")] + "time_embedding_1"
transformed = transformed[: -len("time_embedding.2")] + "time_embedding_1"
# Time projection: time_projection.1 → time_projection
transformed = transformed.replace("time_projection.1.", "time_projection.")
if transformed.endswith("time_projection.1"):
transformed = transformed[:-len("time_projection.1")] + "time_projection"
transformed = transformed[: -len("time_projection.1")] + "time_projection"
# Patch embedding: patch_embedding → patch_embedding_proj
if "patch_embedding" in transformed and "patch_embedding_proj" not in transformed:
if (
"patch_embedding" in transformed
and "patch_embedding_proj" not in transformed
):
transformed = transformed.replace("patch_embedding", "patch_embedding_proj")
if f"{transformed}.weight" in model_keys or transformed in model_keys:
@@ -115,7 +118,7 @@ def _normalize_wan_lora_key(lora_key: str, model_keys: set) -> str:
# Return best attempt with prefix stripped
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key[len(prefix) :]
return lora_key
@@ -134,21 +137,25 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
normalized = lora_key[len(prefix):]
normalized = lora_key[len(prefix) :]
if f"{normalized}.weight" in model_keys or normalized in model_keys:
return normalized
transformed = normalized
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed[: -len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
transformed = transformed.replace(".ff.net.2.", ".ff.proj_out.")
transformed = transformed.replace(".ff.net.2", ".ff.proj_out")
transformed = transformed.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
transformed = transformed.replace(".audio_ff.net.0.proj", ".audio_ff.proj_in")
transformed = transformed.replace(
".audio_ff.net.0.proj.", ".audio_ff.proj_in."
)
transformed = transformed.replace(
".audio_ff.net.0.proj", ".audio_ff.proj_in"
)
transformed = transformed.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
transformed = transformed.replace(".audio_ff.net.2", ".audio_ff.proj_out")
@@ -158,7 +165,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
# Try transformations on the original key
transformed = lora_key
if transformed.endswith(".to_out.0"):
transformed = transformed[:-len(".to_out.0")] + ".to_out"
transformed = transformed[: -len(".to_out.0")] + ".to_out"
transformed = transformed.replace(".to_out.0.", ".to_out.")
transformed = transformed.replace(".ff.net.0.proj.", ".ff.proj_in.")
transformed = transformed.replace(".ff.net.0.proj", ".ff.proj_in")
@@ -170,7 +177,7 @@ def _normalize_ltx_lora_key(lora_key: str, model_keys: set) -> str:
for prefix in prefixes_to_strip:
if lora_key.startswith(prefix):
return lora_key[len(prefix):]
return lora_key[len(prefix) :]
return lora_key
@@ -226,7 +233,9 @@ def apply_loras_to_weights(
skipped_count += 1
skipped_modules.append(module_name)
if verbose and skipped_count <= 5:
print(f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND")
print(
f" DEBUG: '{module_name}' -> '{normalized_name}' -> NOT FOUND"
)
similar = [
k
for k in list(model_keys)[:1000]
@@ -251,13 +260,21 @@ def apply_loras_to_weights(
if is_quantized:
scales = modified_weights[scales_key]
biases = modified_weights[biases_key]
group_size = (original_weight.shape[-1] * 32) // (scales.shape[-1] * quantization_bits)
group_size = (original_weight.shape[-1] * 32) // (
scales.shape[-1] * quantization_bits
)
dequantized = mx.dequantize(
original_weight, scales, biases, group_size=group_size, bits=quantization_bits
original_weight,
scales,
biases,
group_size=group_size,
bits=quantization_bits,
)
modified = apply_lora_to_linear(dequantized, loras)
# Re-quantize with same parameters
new_w, new_scales, new_biases = mx.quantize(modified, group_size=group_size, bits=quantization_bits)
new_w, new_scales, new_biases = mx.quantize(
modified, group_size=group_size, bits=quantization_bits
)
modified_weights[weight_key] = new_w
modified_weights[scales_key] = new_scales
modified_weights[biases_key] = new_biases
@@ -346,9 +363,15 @@ def apply_loras_to_model(
parent = model
try:
for part in parts[:-1]:
parent = getattr(parent, part) if not part.isdigit() else parent[int(part)]
parent = (
getattr(parent, part) if not part.isdigit() else parent[int(part)]
)
leaf_name = parts[-1]
target = getattr(parent, leaf_name) if not leaf_name.isdigit() else parent[int(leaf_name)]
target = (
getattr(parent, leaf_name)
if not leaf_name.isdigit()
else parent[int(leaf_name)]
)
except (AttributeError, IndexError, TypeError):
skipped.append(lora_key)
if verbose:
@@ -358,8 +381,11 @@ def apply_loras_to_model(
if isinstance(target, nn.QuantizedLinear):
# Dequantize → merge LoRA → replace with bf16 Linear
weight = mx.dequantize(
target.weight, target.scales, target.biases,
group_size=target.group_size, bits=target.bits,
target.weight,
target.scales,
target.biases,
group_size=target.group_size,
bits=target.bits,
)
merged = apply_lora_to_linear(weight, loras)
new_linear = nn.Linear(merged.shape[1], merged.shape[0])
@@ -379,7 +405,9 @@ def apply_loras_to_model(
else:
skipped.append(lora_key)
if verbose:
print(f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear")
print(
f" DEBUG: '{module_path}' is {type(target).__name__}, not Linear"
)
continue
if applied_count > 0:
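The `group_size` expression in `apply_loras_to_weights` recovers the quantization group size from array widths alone. A minimal sketch of that arithmetic, assuming the usual packed layout in which each 32-bit word of the quantized weight holds `32 / bits` values and `scales` stores one entry per group. `infer_group_size` is a hypothetical name and the dimensions are illustrative.

```python
def infer_group_size(packed_cols: int, scales_cols: int, bits: int) -> int:
    """Recover the quantization group size from array widths."""
    return (packed_cols * 32) // (scales_cols * bits)


# Illustrative numbers: a 4096-wide input dimension at 4 bits, group size 64.
in_features, bits, group_size = 4096, 4, 64
packed_cols = in_features * bits // 32   # 512 packed 32-bit words per row
scales_cols = in_features // group_size  # 64 scale entries per row

recovered = infer_group_size(packed_cols, scales_cols, bits)
```

Deriving the group size this way lets the dequantize, merge, re-quantize round trip reuse the original parameters without reading them from a separate config.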


@@ -2,7 +2,7 @@
import re
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List
import mlx.core as mx


@@ -1,3 +1,2 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.models.wan_2 import WanModel, WanModelConfig


@@ -1,8 +0,0 @@
from mlx_video.models.ltx.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
)
from mlx_video.models.ltx.ltx import LTXModel, X0Model
from mlx_video.models.ltx.audio_vae import AudioDecoder, Vocoder, decode_audio


@@ -1,326 +0,0 @@
"""Audio VAE encoder and decoder for LTX-2."""
from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock
from .upsample import build_upsampling_path
LATENT_DOWNSAMPLE_FACTOR = 4
def build_mid_block(
channels: int,
temb_channels: int,
dropout: float,
norm_type: NormType,
causality_axis: CausalityAxis,
attn_type: AttentionType,
add_attention: bool,
) -> dict:
"""Build the middle block with two ResNet blocks and optional attention."""
mid = {}
mid["block_1"] = ResnetBlock(
in_channels=channels,
out_channels=channels,
temb_channels=temb_channels,
dropout=dropout,
norm_type=norm_type,
causality_axis=causality_axis,
)
mid["attn_1"] = (
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
)
mid["block_2"] = ResnetBlock(
in_channels=channels,
out_channels=channels,
temb_channels=temb_channels,
dropout=dropout,
norm_type=norm_type,
causality_axis=causality_axis,
)
return mid
def run_mid_block(mid: dict, features: mx.array) -> mx.array:
"""Run features through the middle block."""
features = mid["block_1"](features, temb=None)
if mid["attn_1"] is not None:
features = mid["attn_1"](features)
return mid["block_2"](features, temb=None)
class AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.
The decoder mirrors the encoder structure with configurable channel multipliers,
attention resolutions, and causal convolutions.
"""
def __init__(
self,
*,
ch: int = 128,
out_ch: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Set[int] | None = None,
resolution: int = 256,
z_channels: int = 8,
norm_type: NormType = NormType.PIXEL,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
dropout: float = 0.0,
mid_block_add_attention: bool = True,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: int | None = None,
) -> None:
"""
Initialize the AudioDecoder.
Args:
ch: Base number of feature channels
out_ch: Number of output channels (2 for stereo)
ch_mult: Multiplicative factors for channels at each resolution
num_res_blocks: Number of residual blocks per resolution
attn_resolutions: Resolutions at which to apply attention
resolution: Input spatial resolution
z_channels: Number of latent channels
norm_type: Normalization type
causality_axis: Axis for causal convolutions
dropout: Dropout probability
mid_block_add_attention: Whether to add attention in middle block
sample_rate: Audio sample rate
mel_hop_length: Hop length for mel spectrogram
is_causal: Whether to use causal convolutions
mel_bins: Number of mel frequency bins
"""
super().__init__()
if attn_resolutions is None:
attn_resolutions = {8, 16, 32}
# Internal behavioral defaults
resamp_with_conv = True
attn_type = AttentionType.VANILLA
# Per-channel statistics for denormalizing latents
# Uses ch (base channel count) to match the patchified latent dimension
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
# ch=128 matches this dimension, so use ch for per_channel_statistics
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.patchifier = AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate,
hop_length=mel_hop_length,
is_causal=is_causal,
)
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.out_ch = out_ch
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.z_channels = z_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
self.attn_type = attn_type
base_block_channels = ch * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, z_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d(
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
self.mid = build_mid_block(
channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
add_attention=mid_block_add_attention,
)
self.up, final_block_channels = build_upsampling_path(
ch=ch,
ch_mult=ch_mult,
num_resolutions=self.num_resolutions,
num_res_blocks=num_res_blocks,
resolution=resolution,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
attn_resolutions=attn_resolutions,
resamp_with_conv=resamp_with_conv,
initial_block_channels=base_block_channels,
)
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
self.conv_out = make_conv2d(
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
def __call__(self, sample: mx.array) -> mx.array:
"""
Decode latent features back to audio spectrograms.
Args:
sample: Encoded latent representation of shape (B, H, W, C) in MLX format
or (B, C, H, W) in PyTorch format (will be transposed)
Returns:
Reconstructed audio spectrogram
"""
# Handle input format - if channels are in dim 1, transpose to channels-last
if sample.ndim == 4 and sample.shape[1] == self.z_channels:
# PyTorch format (B, C, H, W) -> MLX format (B, H, W, C)
sample = mx.transpose(sample, (0, 2, 3, 1))
sample, target_shape = self._denormalize_latents(sample)
h = self.conv_in(sample)
h = run_mid_block(self.mid, h)
h = self._run_upsampling_path(h)
h = self._finalize_output(h)
return self._adjust_output_shape(h, target_shape)
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
"""Denormalize latents using per-channel statistics."""
# sample shape: (B, H, W, C) in MLX format
latent_shape = AudioLatentShape(
batch=sample.shape[0],
channels=sample.shape[3], # channels last
frames=sample.shape[1], # height = frames
mel_bins=sample.shape[2], # width = mel_bins
)
sample_patched = self.patchifier.patchify(sample)
sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis != CausalityAxis.NONE:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_shape = AudioLatentShape(
batch=latent_shape.batch,
channels=self.out_ch,
frames=target_frames,
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
)
return sample, target_shape
def _adjust_output_shape(
self,
decoded_output: mx.array,
target_shape: AudioLatentShape,
) -> mx.array:
"""
Adjust output shape to match target dimensions for variable-length audio.
Args:
decoded_output: Tensor of shape (B, H, W, C) in MLX format
target_shape: AudioLatentShape describing target dimensions
Returns:
Tensor adjusted to match target_shape exactly
"""
# Current output shape: (batch, frames, mel_bins, channels) in MLX format
_, current_time, current_freq, _ = decoded_output.shape
target_channels = target_shape.channels
target_time = target_shape.frames
target_freq = target_shape.mel_bins
# Step 1: Crop first to avoid exceeding target dimensions
decoded_output = decoded_output[
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
]
# Step 2: Calculate padding needed for time and frequency dimensions
time_padding_needed = target_time - decoded_output.shape[1]
freq_padding_needed = target_freq - decoded_output.shape[2]
# Step 3: Apply padding if needed
if time_padding_needed > 0 or freq_padding_needed > 0:
# MLX pad: [(before_0, after_0), ...]
# For (B, H, W, C): H=time, W=freq
padding = [
(0, 0), # batch
(0, max(time_padding_needed, 0)), # time
(0, max(freq_padding_needed, 0)), # freq
(0, 0), # channels
]
decoded_output = mx.pad(decoded_output, padding)
# Step 4: Final safety crop to ensure exact target shape
decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels]
# Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility
decoded_output = mx.transpose(decoded_output, (0, 3, 1, 2))
return decoded_output
def _run_upsampling_path(self, h: mx.array) -> mx.array:
"""Run through upsampling path."""
for level in reversed(range(self.num_resolutions)):
stage = self.up[level]
for block_idx in range(len(stage["block"])):
h = stage["block"][block_idx](h, temb=None)
if block_idx in stage["attn"]:
h = stage["attn"][block_idx](h)
if level != 0 and "upsample" in stage:
h = stage["upsample"](h)
return h
def _finalize_output(self, h: mx.array) -> mx.array:
"""Apply final normalization and convolution."""
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return mx.tanh(h) if self.tanh_out else h
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
"""
Decode an audio latent representation using the provided audio decoder and vocoder.
Args:
latent: Input audio latent tensor
audio_decoder: Model to decode the latent to spectrogram
vocoder: Model to convert spectrogram to audio waveform
Returns:
Decoded audio as a float tensor
"""
decoded_audio = audio_decoder(latent)
decoded_audio = vocoder(decoded_audio)
# Remove batch dimension if present
if decoded_audio.shape[0] == 1:
decoded_audio = decoded_audio[0]
return decoded_audio.astype(mx.float32)

@@ -1,12 +0,0 @@
"""Causality axis enum for specifying causal convolution dimensions."""
from enum import Enum
class CausalityAxis(Enum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"

@@ -1,142 +0,0 @@
"""Vocoder for converting mel spectrograms to audio waveforms."""
import math
from typing import List
import mlx.core as mx
import mlx.nn as nn
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
class Vocoder(nn.Module):
"""
Vocoder model for synthesizing audio from Mel spectrograms.
Based on HiFi-GAN architecture.
Args:
resblock_kernel_sizes: List of kernel sizes for the residual blocks
upsample_rates: List of upsampling rates
upsample_kernel_sizes: List of kernel sizes for the upsampling layers
resblock_dilation_sizes: List of dilation sizes for the residual blocks
upsample_initial_channel: Initial number of channels for upsampling
stereo: Whether to use stereo output
resblock: Type of residual block to use ("1" or "2")
output_sample_rate: Waveform sample rate
"""
def __init__(
self,
resblock_kernel_sizes: List[int] | None = None,
upsample_rates: List[int] | None = None,
upsample_kernel_sizes: List[int] | None = None,
resblock_dilation_sizes: List[List[int]] | None = None,
upsample_initial_channel: int = 1024,
stereo: bool = True,
resblock: str = "1",
output_sample_rate: int = 24000,
):
super().__init__()
# Initialize default values if not provided
if resblock_kernel_sizes is None:
resblock_kernel_sizes = [3, 7, 11]
if upsample_rates is None:
upsample_rates = [6, 5, 2, 2, 2]
if upsample_kernel_sizes is None:
upsample_kernel_sizes = [16, 15, 8, 4, 4]
if resblock_dilation_sizes is None:
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
self.output_sample_rate = output_sample_rate
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.upsample_rates = upsample_rates
self.upsample_kernel_sizes = upsample_kernel_sizes
self.upsample_initial_channel = upsample_initial_channel
in_channels = 128 if stereo else 64
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
# Upsampling layers using ConvTranspose1d
self.ups = {}
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
in_ch = upsample_initial_channel // (2**i)
out_ch = upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d(
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
# Residual blocks
self.resblocks = {}
block_idx = 0
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
block_idx += 1
out_channels = 2 if stereo else 1
final_channels = upsample_initial_channel // (2**self.num_upsamples)
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
self.upsample_factor = math.prod(upsample_rates)
def __call__(self, x: mx.array) -> mx.array:
"""
Forward pass of the vocoder.
Args:
x: Input Mel spectrogram tensor. Can be either:
- 3D: (batch_size, time, mel_bins) for mono - MLX format (N, L, C)
- 4D: (batch_size, 2, time, mel_bins) for stereo - PyTorch format (N, C, H, W)
Returns:
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
"""
# Input: (batch, channels, time, mel_bins) from audio decoder
# Transpose to (batch, channels, mel_bins, time)
x = mx.transpose(x, (0, 1, 3, 2))
if x.ndim == 4: # stereo
# x shape: (batch, 2, mel_bins, time)
# Rearrange to (batch, 2*mel_bins, time)
b, s, c, t = x.shape
x = x.reshape(b, s * c, t)
# MLX Conv1d expects (N, L, C), so transpose
# Current: (batch, channels, time) -> (batch, time, channels)
x = mx.transpose(x, (0, 2, 1))
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
start = i * self.num_kernels
end = start + self.num_kernels
# Apply residual blocks and average their outputs
block_outputs = []
for idx in range(start, end):
block_outputs.append(self.resblocks[idx](x))
# Stack and mean
x = mx.stack(block_outputs, axis=0)
x = mx.mean(x, axis=0)
# IMPORTANT: Use default leaky_relu slope (0.01), NOT LRELU_SLOPE (0.1)
# PyTorch uses F.leaky_relu(x) which defaults to 0.01
x = nn.leaky_relu(x) # Default negative_slope=0.01
x = self.conv_post(x)
x = mx.tanh(x)
# Transpose back to (batch, channels, time)
x = mx.transpose(x, (0, 2, 1))
return x

@@ -1,182 +0,0 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional
class LTXModelType(Enum):
AudioVideo = "ltx av model"
VideoOnly = "ltx video only model"
AudioOnly = "ltx audio only model"
def is_video_enabled(self) -> bool:
return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)
def is_audio_enabled(self) -> bool:
return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)
class LTXRopeType(Enum):
INTERLEAVED = "interleaved"
SPLIT = "split"
TWO_D = "2d"
class AttentionType(Enum):
DEFAULT = "default"
@dataclass
class BaseModelConfig:
@classmethod
def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig":
"""Create config from dictionary, filtering only valid parameters."""
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
def to_dict(self) -> dict[str, Any]:
"""Export config to dictionary."""
result = {}
for k, v in self.__dict__.items():
if v is not None:
if isinstance(v, Enum):
result[k] = v.value
elif hasattr(v, 'to_dict'):
result[k] = v.to_dict()
else:
result[k] = v
return result
@dataclass
class TransformerConfig(BaseModelConfig):
dim: int
heads: int
d_head: int
context_dim: int
@dataclass
class VideoVAEConfig(BaseModelConfig):
convolution_dimensions: int = 3
in_channels: int = 3
out_channels: int = 128
latent_channels: int = 128
patch_size: int = 4
encoder_blocks: List[tuple] = field(default_factory=lambda: [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
])
decoder_blocks: List[tuple] = field(default_factory=lambda: [
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
])
@dataclass
class LTXModelConfig(BaseModelConfig):
# Model type
model_type: LTXModelType = LTXModelType.AudioVideo
# Video transformer config
num_attention_heads: int = 32
attention_head_dim: int = 128
in_channels: int = 128
out_channels: int = 128
num_layers: int = 48
cross_attention_dim: int = 4096
caption_channels: int = 3840
# Audio transformer config
audio_num_attention_heads: int = 32
audio_attention_head_dim: int = 64
audio_in_channels: int = 128
audio_out_channels: int = 128
audio_cross_attention_dim: int = 2048
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video)
# Positional embedding config
positional_embedding_theta: float = 10000.0
positional_embedding_max_pos: Optional[List[int]] = None
audio_positional_embedding_max_pos: Optional[List[int]] = None
use_middle_indices_grid: bool = True
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED
double_precision_rope: bool = False
# Timestep config
timestep_scale_multiplier: int = 1000
av_ca_timestep_scale_multiplier: int = 1000
# Normalization
norm_eps: float = 1e-6
# Attention type
attention_type: AttentionType = AttentionType.DEFAULT
# VAE config
vae_config: Optional[VideoVAEConfig] = None
def __post_init__(self):
"""Set default values after initialization."""
if self.positional_embedding_max_pos is None:
self.positional_embedding_max_pos = [20, 2048, 2048]
if self.audio_positional_embedding_max_pos is None:
self.audio_positional_embedding_max_pos = [20]
# Convert string enum values if loading from dict
if isinstance(self.model_type, str):
self.model_type = LTXModelType(self.model_type)
if isinstance(self.rope_type, str):
self.rope_type = LTXRopeType(self.rope_type)
if isinstance(self.attention_type, str):
self.attention_type = AttentionType(self.attention_type)
@property
def inner_dim(self) -> int:
"""Video inner dimension."""
return self.num_attention_heads * self.attention_head_dim
@property
def audio_inner_dim(self) -> int:
"""Audio inner dimension."""
return self.audio_num_attention_heads * self.audio_attention_head_dim
def get_video_config(self) -> Optional[TransformerConfig]:
"""Get video transformer configuration."""
if not self.model_type.is_video_enabled():
return None
return TransformerConfig(
dim=self.inner_dim,
heads=self.num_attention_heads,
d_head=self.attention_head_dim,
context_dim=self.cross_attention_dim,
)
def get_audio_config(self) -> Optional[TransformerConfig]:
"""Get audio transformer configuration."""
if not self.model_type.is_audio_enabled():
return None
return TransformerConfig(
dim=self.audio_inner_dim,
heads=self.audio_num_attention_heads,
d_head=self.audio_attention_head_dim,
context_dim=self.audio_cross_attention_dim,
)

@@ -1,8 +0,0 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
)

@@ -1,187 +0,0 @@
"""Video VAE Encoder for LTX-2 Image-to-Video.
The encoder compresses input images/videos to latent representations.
Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation.
"""
from pathlib import Path
from typing import List, Tuple, Any, Optional
import json
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
def load_vae_encoder(model_path: str) -> VideoEncoder:
"""Load VAE encoder from safetensors file.
Args:
model_path: Path to the model weights (safetensors file or directory)
Returns:
Loaded VideoEncoder instance
"""
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE encoder from {weights_path}...")
# Read config from safetensors metadata
encoder_blocks = []
norm_layer = NormLayerType.PIXEL_NORM
latent_log_var = LogVarianceType.UNIFORM
patch_size = 4
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
# Parse encoder blocks
raw_blocks = vae_config.get("encoder_blocks", [])
for block in raw_blocks:
if isinstance(block, list) and len(block) == 2:
name, params = block
encoder_blocks.append((name, params))
# Parse other config
norm_str = vae_config.get("norm_layer", "pixel_norm")
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
var_str = vae_config.get("latent_log_var", "uniform")
if var_str == "uniform":
latent_log_var = LogVarianceType.UNIFORM
elif var_str == "per_channel":
latent_log_var = LogVarianceType.PER_CHANNEL
elif var_str == "constant":
latent_log_var = LogVarianceType.CONSTANT
else:
latent_log_var = LogVarianceType.NONE
patch_size = vae_config.get("patch_size", 4)
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
except Exception as e:
print(f" Could not read config from metadata: {e}")
# Use default config
encoder_blocks = [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
# Create encoder
encoder = VideoEncoder(
convolution_dimensions=3,
in_channels=3,
out_channels=128,
encoder_blocks=encoder_blocks,
patch_size=patch_size,
norm_layer=norm_layer,
latent_log_var=latent_log_var,
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
)
# Load weights
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.encoder."
stats_prefix = "vae.per_channel_statistics."
else:
prefix = "encoder."
stats_prefix = "per_channel_statistics."
# Load per-channel statistics for normalization
mean_key = f"{stats_prefix}mean-of-means"
std_key = f"{stats_prefix}std-of-means"
if mean_key in weights:
encoder.per_channel_statistics.mean = weights[mean_key]
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
if std_key in weights:
encoder.per_channel_statistics.std = weights[std_key]
print(f" Loaded latent std: shape {weights[std_key].shape}")
# Build encoder weights dict with key remapping
encoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
encoder_weights[new_key] = value
print(f" Found {len(encoder_weights)} encoder weights")
# Load weights
encoder.load_weights(list(encoder_weights.items()), strict=False)
print("VAE encoder loaded successfully")
return encoder
def encode_image(
image: mx.array,
encoder: VideoEncoder,
) -> mx.array:
"""Encode a single image to latent space.
Args:
image: Image tensor of shape (H, W, 3) in range [0, 1] or (B, H, W, 3)
encoder: Loaded VAE encoder
Returns:
Latent tensor of shape (1, 128, 1, H//32, W//32)
"""
# Add batch dimension if needed
if image.ndim == 3:
image = mx.expand_dims(image, axis=0) # (1, H, W, 3)
# Convert from (B, H, W, C) to (B, C, H, W)
image = mx.transpose(image, (0, 3, 1, 2)) # (B, 3, H, W)
# Normalize to [-1, 1]
if image.max() > 1.0:
image = image / 255.0
image = image * 2.0 - 1.0
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
image = mx.expand_dims(image, axis=2) # (B, 3, 1, H, W)
# Encode
latent = encoder(image)
return latent

@@ -0,0 +1,371 @@
# LTX-2 for MLX
MLX port of [LTX-2](https://huggingface.co/Lightricks/LTX-2), a 19B parameter video generation model from Lightricks with synchronized audio-video support.
## Pipelines
Four pipeline types are available via the `--pipeline` flag:
| Pipeline | Description | CFG | Stages | Speed |
|----------|-------------|-----|--------|-------|
| `distilled` (default) | Fixed sigma schedule, no CFG | No | 2 (8+3 steps) | Fastest |
| `dev` | Dynamic sigmas, constant CFG | Yes | 1 (30 steps) | Medium |
| `dev-two-stage` | Dev + LoRA refinement | Yes (stage 1) | 2 (30+3 steps) | Slow |
| `dev-two-stage-hq` | res_2s sampler + LoRA both stages | Yes (stage 1) | 2 (15+3 steps) | Slow, highest quality |
## Usage
### Text-to-Video (T2V)
```bash
# Distilled (default) - fast, two-stage
uv run mlx_video.generate --prompt "Two dogs wearing sunglasses, cinematic, sunset" -n 97 --width 768
# Dev - single-stage with CFG
uv run mlx_video.generate --pipeline dev --prompt "A cinematic scene" --cfg-scale 3.0
# Dev two-stage - dev + LoRA refinement
uv run mlx_video.generate --pipeline dev-two-stage \
--prompt "Two dogs of the poodle breed wearing sunglasses, close up, cinematic, sunset" \
-n 145 --width 1024 --height 768 \
--model-repo prince-canuma/LTX-2-dev \
--cfg-scale 3.0 --lora-strength 0.8 \
--enhance-prompt
# Dev two-stage HQ - res_2s sampler, LoRA both stages (highest quality)
uv run mlx_video.generate --pipeline dev-two-stage-hq \
--prompt "A cinematic scene of ocean waves at golden hour" \
--model-repo prince-canuma/LTX-2-dev
# HQ with custom LoRA strengths
uv run mlx_video.generate --pipeline dev-two-stage-hq \
--prompt "A sunset over mountains" \
--model-repo prince-canuma/LTX-2-dev \
--lora-strength-stage-1 0.3 --lora-strength-stage-2 0.6
```
### Image-to-Video (I2V)
```bash
# Distilled I2V
uv run mlx_video.generate --prompt "A person dancing" --image photo.jpg
# Dev I2V
uv run mlx_video.generate --pipeline dev --prompt "Waves crashing" --image beach.png --cfg-scale 3.5
```
### Audio-to-Video (A2V)
Generate video conditioned on an input audio file. Works with all four pipelines. The audio is encoded to latent space and frozen during denoising -- the transformer's cross-attention reads the audio signal to guide video generation.
```bash
# A2V - distilled (default, fastest)
uv run mlx_video.generate --audio-file music.wav --prompt "A band playing music"
# A2V - dev (single-stage with CFG)
uv run mlx_video.generate --pipeline dev --audio-file ocean.wav --prompt "Ocean waves"
# A2V - dev-two-stage (dev + LoRA refinement)
uv run mlx_video.generate --pipeline dev-two-stage --audio-file music.wav \
--prompt "A band playing music" --model-repo prince-canuma/LTX-2-dev
# A2V - dev-two-stage-hq (highest quality)
uv run mlx_video.generate --pipeline dev-two-stage-hq --audio-file music.wav \
--prompt "A band playing music" --model-repo prince-canuma/LTX-2-dev
# A2V + I2V (audio + image conditioning)
uv run mlx_video.generate --audio-file rain.wav --image forest.jpg --prompt "Rain in forest"
# A2V with custom start time
uv run mlx_video.generate --audio-file song.mp3 --audio-start-time 30.0 --prompt "Concert"
```
> **Note:** `--audio-file` (A2V) and `--audio` (generate audio) are mutually exclusive. Supported formats: WAV, FLAC, MP3, OGG, and video files with audio tracks.
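
One way this kind of mutual exclusivity can be enforced at the argument-parsing level is with an `argparse` mutually exclusive group. The sketch below is illustrative only: it assumes an `argparse`-based CLI and reuses the flag names from this README, while `build_audio_args` and `parses_ok` are hypothetical helpers that are not part of mlx-video.

```python
import argparse


def build_audio_args() -> argparse.ArgumentParser:
    """Hypothetical parser covering only the two audio flags from this README."""
    parser = argparse.ArgumentParser(prog="audio-flags-demo")
    group = parser.add_mutually_exclusive_group()
    # --audio: generate a new audio track alongside the video
    group.add_argument("--audio", "-a", action="store_true")
    # --audio-file: condition the video on an existing audio file (A2V)
    group.add_argument("--audio-file", type=str, default=None)
    return parser


def parses_ok(argv: list[str]) -> bool:
    """Return True if argv is accepted, False if argparse rejects it."""
    try:
        build_audio_args().parse_args(argv)
        return True
    except SystemExit:
        # argparse exits with an error message when both flags are given
        return False
```

With this setup, passing either flag alone is accepted, while passing both is rejected before any model is loaded.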
### Audio-Video Generation (experimental)
Generate synchronized audio alongside video from scratch:
```bash
uv run mlx_video.generate --prompt "Ocean waves crashing" --audio
uv run mlx_video.generate --pipeline dev --prompt "A jazz band playing" --audio --enhance-prompt
# With full guidance (STG + modality_scale, matches PyTorch defaults)
uv run mlx_video.generate --pipeline dev --prompt "Ocean waves crashing" --audio \
--stg-scale 1.0 --stg-blocks 29 --modality-scale 3.0
```
### LoRA
LoRA weights can be loaded from a file, directory, or HuggingFace repo:
```bash
# From HuggingFace repo
uv run mlx_video.generate --pipeline dev-two-stage \
--prompt "Camera dolly out of a forest" \
--lora-path Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out \
--lora-strength 1.0
# From local file
uv run mlx_video.generate --pipeline dev-two-stage \
--prompt "A scene" \
--lora-path ./my-lora/weights.safetensors
# From local directory (auto-detects .safetensors file)
uv run mlx_video.generate --pipeline dev-two-stage \
--prompt "A scene" \
--lora-path ./LTX-2-distilled/lora
```
### Upscaling
```bash
# Upscale an image 2x
uv run mlx_video.upscale --input photo.png --output upscaled.png
# Upscale a video 2x
uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4
# Upscale with refinement (higher quality, requires text prompt)
uv run mlx_video.upscale --input video.mp4 --output upscaled.mp4 --refine --prompt "A cinematic scene"
```
## CLI Options
### General
| Option | Default | Description |
|--------|---------|-------------|
| `--prompt`, `-p` | (required) | Text description of the video |
| `--pipeline` | `distilled` | Pipeline type: `distilled`, `dev`, `dev-two-stage`, or `dev-two-stage-hq` |
| `--height`, `-H` | 512 | Output height (divisible by 64 for two-stage, 32 for dev) |
| `--width`, `-W` | 512 | Output width (divisible by 64 for two-stage, 32 for dev) |
| `--num-frames`, `-n` | 33 | Number of frames (must be 1 + 8*k) |
| `--seed`, `-s` | 42 | Random seed for reproducibility |
| `--fps` | 24 | Frames per second |
| `--output-path`, `-o` | output.mp4 | Output video path |
| `--model-repo` | Lightricks/LTX-2 | HuggingFace model repository |
| `--text-encoder-repo` | None | Separate text encoder repo (if not in model repo) |
| `--save-frames` | false | Save individual frames as images |
| `--enhance-prompt` | false | Enhance prompt using Gemma |
| `--image`, `-i` | None | Conditioning image for I2V |
| `--image-strength` | 1.0 | Conditioning strength for I2V |
| `--audio`, `-a` | false | Enable synchronized audio generation |
| `--audio-file` | None | Path to audio file for A2V conditioning |
| `--audio-start-time` | 0.0 | Start time in seconds for audio file |
| `--tiling` | `auto` | VAE tiling mode: `auto`, `none`, `aggressive`, `conservative` |
| `--stream` | false | Stream frames as they decode |
| `--spatial-upscaler` | auto (x2) | Spatial upscaler file for two-stage pipelines (see below) |
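
The frame-count and resolution constraints in the table can be checked before launching a long generation. The helpers below are a hypothetical sketch (none of these function names exist in mlx-video), and they assume that every pipeline other than `dev` counts as two-stage, as the pipeline table suggests.

```python
def is_valid_num_frames(num_frames: int) -> bool:
    """True when num_frames can be written as 1 + 8*k for an integer k >= 0."""
    return num_frames >= 1 and (num_frames - 1) % 8 == 0


def nearest_valid_num_frames(num_frames: int) -> int:
    """Round to a nearby value of the form 1 + 8*k (never below 1)."""
    k = max(round((num_frames - 1) / 8), 0)
    return 1 + 8 * k


def required_divisor(pipeline: str) -> int:
    """Divisor for --width/--height: 32 for `dev`, 64 for two-stage pipelines."""
    # Assumption: anything other than `dev` is treated as two-stage
    return 32 if pipeline == "dev" else 64


def is_valid_size(size: int, pipeline: str) -> bool:
    """True when a width or height satisfies the pipeline's divisibility rule."""
    return size > 0 and size % required_divisor(pipeline) == 0
```

For example, 33, 97, and 145 all satisfy the `1 + 8*k` rule, while 100 does not and would round to 97.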
### Spatial Upscalers (LTX-2.3)
LTX-2.3 ships with multiple spatial upscaler variants. Use `--spatial-upscaler` to select one:
| Variant | Scale | Output (from 256x256) | Architecture |
|---------|-------|-----------------------|--------------|
| `ltx-2.3-spatial-upscaler-x2-1.0.safetensors` (default) | 2.0x | 512x512 | Conv2d + PixelShuffle(2) |
| `ltx-2.3-spatial-upscaler-x2-1.1.safetensors` | 2.0x | 512x512 | Same arch, newer weights |
| `ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors` | 1.5x | 384x384 | Conv2d + PixelShuffle(3) + BlurDownsample |
```bash
# Default (x2-1.0, auto-detected)
uv run mlx_video.generate --prompt "A sunset" --model-repo ./LTX-2.3-distilled
# x2-1.1 (newer weights)
uv run mlx_video.generate --prompt "A sunset" --model-repo ./LTX-2.3-distilled \
--spatial-upscaler ltx-2.3-spatial-upscaler-x2-1.1.safetensors
# x1.5 (smaller output, faster)
uv run mlx_video.generate --prompt "A sunset" --model-repo ./LTX-2.3-distilled \
--spatial-upscaler ltx-2.3-spatial-upscaler-x1.5-1.0.safetensors
```
> **Note:** Stage 1 always runs at half the target resolution. With x1.5, the final output is 75% of `--width`/`--height` (e.g., 512 target -> 256 stage 1 -> 384 output). With x2, the output matches the target exactly.
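
The arithmetic in the note can be written out as a small helper. This is a sketch of the sizing rule as described here, with a hypothetical function name; the integer truncation is an assumption about how fractional sizes would be handled.

```python
def two_stage_resolution(target: int, upscale: float = 2.0) -> tuple[int, int]:
    """Return (stage_1_size, final_size) for one spatial dimension."""
    stage_1 = target // 2             # stage 1 runs at half the target size
    final = int(stage_1 * upscale)    # the upscaler multiplies the stage 1 size
    return stage_1, final
```

With a 512 target, the x2 upscaler gives `(256, 512)` and the x1.5 upscaler gives `(256, 384)`, matching the example in the note.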
### Dev / Dev-Two-Stage
| Option | Default | Description |
|--------|---------|-------------|
| `--steps` | 30 | Number of denoising steps |
| `--cfg-scale` | 3.0 | CFG guidance scale |
| `--cfg-rescale` | 0.7 | CFG rescale factor (reduces over-saturation) |
| `--negative-prompt` | (default) | Negative prompt for CFG |
| `--apg` | false | Use Adaptive Projected Guidance (more stable for I2V) |
| `--stg-scale` | 0.0 | STG scale (PyTorch default: 1.0, requires `--audio`) |
| `--stg-blocks` | None | Transformer blocks for STG ([29] for LTX-2, [28] for LTX-2.3) |
| `--modality-scale` | 1.0 | Cross-modal guidance scale (PyTorch default: 3.0, requires `--audio`) |
### Dev-Two-Stage LoRA
| Option | Default | Description |
|--------|---------|-------------|
| `--lora-path` | auto-detect | Path to LoRA file, directory, or HuggingFace repo |
| `--lora-strength` | 1.0 | LoRA merge strength |
### Dev-Two-Stage HQ
| Option | Default | Description |
|--------|---------|-------------|
| `--lora-strength-stage-1` | 0.25 | LoRA strength for stage 1 |
| `--lora-strength-stage-2` | 0.5 | LoRA strength for stage 2 |
HQ defaults: 15 steps (vs 30), `cfg-rescale` 0.45 (vs 0.7), STG disabled. Uses the res_2s second-order sampler (2 model evals per step) for better quality at the same compute budget.
## How It Works
### Distilled Pipeline (default)
1. **Stage 1**: Generate at half resolution with 8 denoising steps (fixed sigmas)
2. **Upsample**: Spatial upsampling via LatentUpsampler (x2 or x1.5, selectable via `--spatial-upscaler`)
3. **Stage 2**: Refine at upsampled resolution with 3 denoising steps
4. **Decode**: VAE decoder converts latents to RGB video
### Dev Pipeline
1. **Generate**: Full resolution with configurable steps and constant CFG
2. **Decode**: VAE decoder converts latents to RGB video
### Dev Two-Stage Pipeline
1. **Stage 1**: Dev denoising at half resolution with CFG
2. **Upsample**: Spatial upsampling via LatentUpsampler (x2 or x1.5)
3. **Stage 2**: Distilled refinement at upsampled resolution with LoRA weights (3 steps, no CFG)
4. **Decode**: VAE decoder converts latents to RGB video
### Dev Two-Stage HQ Pipeline
1. **Stage 1**: res_2s denoising at half resolution with CFG + LoRA@0.25 (15 steps, 2 evals/step)
2. **Upsample**: Spatial upsampling via LatentUpsampler (x2 or x1.5)
3. **Stage 2**: res_2s refinement at upsampled resolution with LoRA@0.5 (3 steps, no CFG)
4. **Decode**: VAE decoder converts latents to RGB video
The res_2s sampler uses an exponential Rosenbrock-type Runge-Kutta integrator with SDE noise injection, producing higher quality results than Euler at the same compute budget (~30 total model evaluations).
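
The compute comparison can be made concrete by counting sampler-level model evaluations from the step counts listed above. This is a back-of-the-envelope sketch: it assumes the non-HQ two-stage pipeline uses a one-evaluation-per-step Euler sampler, and it ignores the extra forward passes that CFG adds in stage 1.

```python
def sampler_evals(stage_steps: list[int], evals_per_step: int) -> int:
    """Total model evaluations across stages, ignoring CFG's extra passes."""
    return sum(steps * evals_per_step for steps in stage_steps)


# dev-two-stage: 30 + 3 steps, assumed to be one evaluation per step (Euler)
dev_two_stage = sampler_evals([30, 3], evals_per_step=1)

# dev-two-stage-hq: 15 + 3 steps, two evaluations per step (res_2s)
dev_two_stage_hq = sampler_evals([15, 3], evals_per_step=2)
```

Under these assumptions the totals are 33 and 36, and stage 1 alone accounts for 30 evaluations in both pipelines, which is consistent with the "roughly 30" figure.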
### Audio-to-Video (A2V) Conditioning
A2V works by encoding input audio into the same latent space as generated audio, then **freezing** those latents during denoising:
1. Load audio file, resample to 16kHz, compute mel-spectrogram
2. `AudioEncoder(mel_spec)` produces audio latents `(B, 8, T, 16)`
3. Normalize via `PerChannelStatistics`
4. Freeze during denoising: `timesteps=0`, `sigma=0`, skip Euler/RK updates
5. Transformer's A2V cross-attention reads frozen audio to guide video generation
6. Output: denoised video + original input audio waveform (skip audio VAE decode)
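
The freezing step can be illustrated with a toy denoising loop. This is a simplified sketch, not the mlx-video implementation: `denoise_with_frozen_audio`, `euler_step`, and the `model` call signature are hypothetical, and the update rule `x + (sigma_next - sigma) * v` is the usual Euler form for velocity-predicting samplers, assumed here for illustration. The code uses only arithmetic operators, so it works on plain floats as well as array types.

```python
from typing import Any, Callable, Sequence


def euler_step(x: Any, velocity: Any, sigma: float, sigma_next: float) -> Any:
    """One Euler update: move x along the predicted velocity."""
    return x + (sigma_next - sigma) * velocity


def denoise_with_frozen_audio(
    video: Any,
    audio: Any,
    sigmas: Sequence[float],
    model: Callable[..., Any],
) -> tuple[Any, Any]:
    """Denoise the video latents while leaving the audio latents untouched."""
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # The audio is presented to the model as already clean (sigma = 0)
        video_velocity = model(video, audio, video_sigma=sigma, audio_sigma=0.0)
        video = euler_step(video, video_velocity, sigma, sigma_next)
        # No update is applied to the audio latents: they stay frozen
    return video, audio
```

Because the audio latents never change, the final output can reuse the original input waveform instead of decoding the audio latents again.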
## Converting Models
Convert original Lightricks/LTX-2 weights to the modular mlx-video format:
```bash
# Convert distilled model
uv run python -m mlx_video.models.ltx_2.convert \
--source Lightricks/LTX-2 --output ./LTX-2-distilled --variant distilled
# Convert dev model
uv run python -m mlx_video.models.ltx_2.convert \
--source Lightricks/LTX-2 --output ./LTX-2-dev --variant dev
```
This extracts 7 components from the monolithic checkpoint:
```
LTX-2-distilled/
├── transformer/ # DiT transformer (19B params)
├── vae/
│ ├── decoder/ # Video VAE decoder
│ └── encoder/ # Video VAE encoder
├── audio_vae/
│ ├── decoder/ # Audio VAE decoder
│ └── encoder/ # Audio VAE encoder
├── vocoder/ # Mel-spectrogram to waveform
└── text_projections/ # Text embedding projections
```
Pre-converted weights are available on HuggingFace:
- [prince-canuma/LTX-2-distilled](https://huggingface.co/prince-canuma/LTX-2-distilled)
- [prince-canuma/LTX-2-dev](https://huggingface.co/prince-canuma/LTX-2-dev)
- [prince-canuma/LTX-2.3-distilled](https://huggingface.co/prince-canuma/LTX-2.3-distilled)
- [prince-canuma/LTX-2.3-dev](https://huggingface.co/prince-canuma/LTX-2.3-dev)
## Model Specifications
- **Transformer**: 48 layers, 32 attention heads, 128 dim per head (19B parameters)
- **Latent channels**: 128
- **Patch size**: 4 (for VAE patchify/unpatchify)
- **Text encoder**: Gemma 3 with 3840-dim output
- **RoPE**: Split mode with double precision (LTX-2.3) or standard (LTX-2)
- **Audio VAE**: Encoder (~35M), Decoder (~50M), Vocoder (~13M)
### Audio VAE Architecture
```
Audio Encoder: mel-spectrogram -> latents (B, 8, T, 16)
- Channel multipliers: (1, 2, 4)
- ResNet blocks with optional attention
- GroupNorm or PixelNorm normalization
- Optional causal convolutions
Audio Decoder: latents -> mel-spectrogram
- Mirrors encoder with upsampling path
- Per-channel statistics for latent normalization
Vocoder: mel-spectrogram -> waveform (~13M params)
- HiFi-GAN style architecture
- Upsample rates: [6, 5, 2, 2, 2]
- ResBlock1 with dilations [1, 3, 5]
```
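
The frame rates implied by these settings follow from a few numbers that appear in this change: 16 kHz input audio, a mel hop length of 160, a latent downsample factor of 4, and the vocoder upsample rates above. The sketch below is plain arithmetic. The helper name `audio_rates` is made up for illustration, and it assumes that a HiFi-GAN style vocoder expands each mel frame by the product of its upsample rates.

```python
from math import prod


def audio_rates(
    sample_rate=16000,
    hop_length=160,
    latent_downsample=4,
    upsample_rates=(6, 5, 2, 2, 2),
):
    """Return (mel frames/s, latent frames/s, waveform samples per mel frame)."""
    mel_fps = sample_rate / hop_length
    latent_fps = mel_fps / latent_downsample
    samples_per_mel_frame = prod(upsample_rates)
    return mel_fps, latent_fps, samples_per_mel_frame


mel_fps, latent_fps, samples_per_mel_frame = audio_rates()
```

With the defaults this gives 100 mel frames per second, 25 latent frames per second, and 240 waveform samples per mel frame.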
## Project Structure
```
mlx_video/models/ltx_2/
├── __init__.py
├── config.py # LTXModelConfig, AudioEncoderModelConfig, AudioDecoderModelConfig
├── convert.py # Weight conversion from Lightricks/LTX-2
├── generate.py # Unified generation pipeline (T2V, I2V, A2V, +Audio)
├── postprocess.py # Video post-processing
├── samplers.py # Euler and res_2s samplers
├── utils.py # Shared utilities (get_model_path, load_safetensors, etc.)
├── ltx_2.py # Main LTXModel (DiT transformer with AV support)
├── transformer.py # Transformer blocks, Modality dataclass
├── attention.py # Multi-head attention with RoPE
├── feed_forward.py # Feed-forward layers
├── adaln.py # Adaptive Layer Normalization
├── rope.py # Rotary Position Embeddings (split/combined)
├── text_projection.py # Text embedding projection
├── text_encoder.py # Text encoder with AV embeddings support
├── upsampler.py # LatentUpsampler for 2-stage generation
├── conditioning/
│ ├── keyframe.py # Image-to-video keyframe conditioning
│ └── latent.py # Video-to-video latent conditioning
├── video_vae/
│ ├── decoder.py # VAE decoder with timestep conditioning
│ ├── encoder.py # VAE encoder for image/video encoding
│ ├── convolution.py # CausalConv3d, CausalConv2d
│ ├── ops.py # patchify, unpatchify, PerChannelStatistics
│ ├── resnet.py # ResBlock3D, ResBlockGroup
│ ├── sampling.py # DepthToSpaceUpsample, SpaceToDepthDownsample
│ └── video_vae.py # Full VAE (encoder + decoder)
└── audio_vae/
  ├── audio_vae.py # Audio encoder and decoder
  ├── audio_processor.py # Mel-spectrogram computation (librosa)
  ├── vocoder.py # Mel-spectrogram to waveform synthesis
  ├── ops.py # AudioPatchifier, PerChannelStatistics
  ├── resnet.py # ResNet blocks for audio
  ├── attention.py # Attention blocks for audio VAE
  ├── normalization.py # Normalization layers
  ├── causal_conv_2d.py # Causal 2D convolutions
  ├── downsample.py # Downsampling layers
  └── upsample.py # Upsampling layers
```
## LTX-2 vs LTX-2.3
LTX-2.3 introduces prompt-conditioned adaptive layer normalization (AdaLN) along with several other architectural changes:
| Feature | LTX-2 | LTX-2.3 |
|---------|--------|---------|
| AdaLN | Standard | Prompt-conditioned (`has_prompt_adaln=True`) |
| Attention gate | None | `2.0 * sigmoid(gate_logits)` |
| Scale-shift table | 6 params | 9 params (+ cross-attn Q) |
| Text encoder connectors | 2 blocks | 8 blocks with gate_logits |
| Feature extractor | V1 (batch-level) | V2 (per-token RMSNorm) |
| RoPE | Standard | Double precision |
| STG blocks | [29] | [28] |
| Text encoder repo | Included | Separate (`--text-encoder-repo`) |
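
The attention gate row can be checked numerically. The sketch below is a pure-Python illustration that assumes one gate value scales every channel of its head; `gate_value` and `apply_head_gate` are hypothetical helper names, not part of mlx-video.

```python
import math


def gate_value(logit):
    """2.0 * sigmoid(logit): lies in (0, 2) and equals 1 at logit 0."""
    return 2.0 / (1.0 + math.exp(-logit))


def apply_head_gate(head_outputs, gate_logits):
    """Scale each head's channels by that head's gate.

    head_outputs: one list of channel values per head.
    gate_logits: one logit per head.
    """
    return [
        [gate_value(logit) * channel for channel in head]
        for head, logit in zip(head_outputs, gate_logits)
    ]


heads = [[1.0, -2.0], [3.0, 4.0]]
unchanged = apply_head_gate(heads, [0.0, 0.0])
```

A logit of zero leaves the attention output untouched, while large positive or negative logits push the scale toward 2 or 0.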

View File

@@ -0,0 +1,7 @@
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
from mlx_video.models.ltx_2.config import (
    LTXModelConfig,
    LTXModelType,
    TransformerConfig,
)
from mlx_video.models.ltx_2.ltx_2 import LTXModel, X0Model

View File

@@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding
class AdaLayerNormSingle(nn.Module):
def __init__(
self,
embedding_dim: int,
@@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module):
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
self.linear = nn.Linear(
embedding_dim, embedding_coefficient * embedding_dim, bias=True
)
def __call__(
self,
@@ -56,15 +57,19 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
use_additional_conditions: bool = False,
timestep_proj_dim: int = 256,
):
super().__init__()
self.embedding_dim = embedding_dim
self.size_emb_dim = size_emb_dim
self.use_additional_conditions = use_additional_conditions
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
self.time_proj = Timesteps(
timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
self.timestep_embedder = TimestepEmbedding(
timestep_proj_dim, embedding_dim, out_dim=embedding_dim
)
if use_additional_conditions and size_emb_dim > 0:
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
@@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
# Add additional conditions if enabled
if self.use_additional_conditions and self.size_emb_dim > 0:
if resolution is not None and aspect_ratio is not None:
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
additional_embeds = self.additional_embedder(
resolution, aspect_ratio, hidden_dtype
)
timesteps_emb = timesteps_emb + additional_embeds
return timesteps_emb

View File

@@ -6,8 +6,8 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType
from mlx_video.models.ltx.rope import apply_rotary_emb
from mlx_video.models.ltx_2.config import LTXRopeType
from mlx_video.models.ltx_2.rope import apply_rotary_emb
def scaled_dot_product_attention(
@@ -67,17 +67,8 @@ class Attention(nn.Module):
dim_head: int = 64,
norm_eps: float = 1e-6,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
has_gate_logits: bool = False,
):
"""Initialize attention module.
Args:
query_dim: Dimension of query input
context_dim: Dimension of context (key/value) input. If None, same as query_dim
heads: Number of attention heads
dim_head: Dimension per head
norm_eps: Epsilon for RMS normalization
rope_type: Type of rotary position embedding
"""
super().__init__()
self.rope_type = rope_type
@@ -99,6 +90,10 @@ class Attention(nn.Module):
# Output projection
self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
# Per-head gating (LTX-2.3)
if has_gate_logits:
self.to_gate_logits = nn.Linear(query_dim, heads, bias=True)
def __call__(
self,
x: mx.array,
@@ -106,6 +101,7 @@ class Attention(nn.Module):
mask: Optional[mx.array] = None,
pe: Optional[Tuple[mx.array, mx.array]] = None,
k_pe: Optional[Tuple[mx.array, mx.array]] = None,
skip_attention: bool = False,
) -> mx.array:
"""Forward pass.
@@ -115,28 +111,44 @@ class Attention(nn.Module):
mask: Attention mask
pe: Position embeddings for query (and key if k_pe is None)
k_pe: Position embeddings for key (optional, uses pe if None)
skip_attention: If True, bypass Q*K*V attention and use value projection
only (for STG perturbation). Matches PyTorch all_perturbed=True.
Returns:
Attention output of shape (B, seq_len, query_dim)
"""
# Compute Q, K, V
q = self.to_q(x)
# Compute per-head gate early (from original input)
gate = None
if hasattr(self, "to_gate_logits"):
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context)
# Apply normalization
q = self.q_norm(q)
k = self.k_norm(k)
if skip_attention:
# STG: bypass Q*K*V attention, use value projection only
out = v
else:
# Standard attention
q = self.to_q(x)
k = self.to_k(context)
# Apply rotary position embeddings
if pe is not None:
q = apply_rotary_emb(q, pe, self.rope_type)
k_pe_to_use = pe if k_pe is None else k_pe
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
q = self.q_norm(q)
k = self.k_norm(k)
# Compute attention
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
if pe is not None:
q = apply_rotary_emb(q, pe, self.rope_type)
k_pe_to_use = pe if k_pe is None else k_pe
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
# Apply per-head gating
if gate is not None:
b, seq_len, _ = out.shape
out = mx.reshape(out, (b, seq_len, self.heads, self.dim_head))
out = out * gate[..., None]
out = mx.reshape(out, (b, seq_len, -1))
# Project output
return self.to_out(out)

View File

@@ -1,21 +1,28 @@
"""Audio VAE module for LTX-2 audio generation."""
from ..config import CausalityAxis
from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, decode_audio
from .audio_processor import ensure_stereo, load_audio, waveform_to_mel
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
from .causal_conv_2d import CausalConv2d, make_conv2d
from .causality_axis import CausalityAxis
from .downsample import Downsample, build_downsampling_path
from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock
from .upsample import Upsample, build_upsampling_path
from .vocoder import Vocoder
from .vocoder import Vocoder, load_vocoder
__all__ = [
# Main components
"AudioEncoder",
"AudioDecoder",
"Vocoder",
"load_vocoder",
"decode_audio",
# Audio processing
"load_audio",
"ensure_stereo",
"waveform_to_mel",
# Ops
"AudioLatentShape",
"AudioPatchifier",

View File

@@ -32,7 +32,9 @@ class AttnBlock(nn.Module):
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def __call__(self, x: mx.array) -> mx.array:
"""
@@ -103,6 +105,8 @@ def make_attn(
elif attn_type == AttentionType.NONE:
return Identity()
elif attn_type == AttentionType.LINEAR:
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
raise NotImplementedError(
f"Attention type {attn_type.value} is not supported yet."
)
else:
raise ValueError(f"Unknown attention type: {attn_type}")

View File

@@ -0,0 +1,136 @@
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
using librosa for macOS/MLX compatibility.
"""

import mlx.core as mx
import numpy as np


def load_audio(
    path: str,
    target_sr: int = 16000,
    start_time: float = 0.0,
    max_duration: float | None = None,
    mono: bool = False,
) -> tuple[np.ndarray, int]:
    """Load audio file, resample to target sample rate.

    Args:
        path: Path to audio file (WAV, FLAC, MP3, OGG, or video with audio track).
        target_sr: Target sample rate (default 16000 Hz).
        start_time: Start time in seconds.
        max_duration: Maximum duration in seconds. None = read to end.
        mono: If True, convert to mono. Default False (preserve channels).

    Returns:
        (waveform, sample_rate) where waveform is (channels, samples) float32 numpy array.
    """
    import librosa

    # librosa.load returns mono by default; we want to preserve stereo
    y, sr = librosa.load(
        path,
        sr=target_sr,
        mono=mono,
        offset=start_time,
        duration=max_duration,
    )

    # Ensure 2D: (channels, samples)
    if y.ndim == 1:
        y = y[np.newaxis, :]  # (1, samples)

    return y.astype(np.float32), sr


def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
    """Ensure waveform is stereo (2, samples). Duplicates mono if needed."""
    if waveform.ndim == 1:
        waveform = waveform[np.newaxis, :]
    if waveform.shape[0] == 1:
        waveform = np.concatenate([waveform, waveform], axis=0)
    elif waveform.shape[0] > 2:
        waveform = waveform[:2]
    return waveform


def waveform_to_mel(
    waveform: np.ndarray,
    sample_rate: int = 16000,
    n_fft: int = 1024,
    hop_length: int = 160,
    win_length: int = 1024,
    n_mels: int = 64,
    fmin: float = 0.0,
    fmax: float = 8000.0,
) -> mx.array:
    """Convert waveform to log-mel spectrogram matching PyTorch MelSpectrogram.

    PyTorch reference:
        MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
                       f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
                       mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")

    Args:
        waveform: (channels, samples) float32 numpy array.
        sample_rate: Sample rate of the waveform.
        n_fft: FFT size.
        hop_length: Hop length.
        win_length: Window length.
        n_mels: Number of mel bins.
        fmin: Minimum frequency for mel filterbank.
        fmax: Maximum frequency for mel filterbank.

    Returns:
        Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
    """
    import librosa

    # Ensure 2D
    if waveform.ndim == 1:
        waveform = waveform[np.newaxis, :]

    channels = waveform.shape[0]
    mels = []

    for ch in range(channels):
        # Magnitude spectrogram (power=1.0)
        S = np.abs(
            librosa.stft(
                waveform[ch],
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                center=True,
                pad_mode="reflect",
            )
        )

        # Mel filterbank with slaney normalization
        mel_basis = librosa.filters.mel(
            sr=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            norm="slaney",
        )
        mel = mel_basis @ S

        # Log scale
        mel = np.log(np.clip(mel, a_min=1e-5, a_max=None))

        # Transpose: (n_mels, time) -> (time, n_mels)
        mel = mel.T
        mels.append(mel)

    # Stack channels: (channels, time, n_mels)
    mel_spec = np.stack(mels, axis=0)

    # Add batch dim: (1, channels, time, n_mels)
    mel_spec = mel_spec[np.newaxis, ...]

    return mx.array(mel_spec, dtype=mx.float32)
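
As a sanity check on the documented return shape, the time dimension can be predicted without running librosa. The sketch below assumes the usual framing for a centered STFT, `1 + n_samples // hop_length` frames; `mel_shape` is a hypothetical helper written for this note, not part of the module.

```python
def mel_shape(n_samples, channels=2, hop_length=160, n_mels=64):
    """Predict the (1, channels, time, n_mels) shape of the log-mel output.

    Assumes a centered STFT, which yields 1 + n_samples // hop_length frames.
    """
    frames = 1 + n_samples // hop_length
    return (1, channels, frames, n_mels)


# One second of stereo audio at 16 kHz with the default hop length.
shape = mel_shape(16000)
```

For one second of stereo audio this predicts a spectrogram of shape `(1, 2, 101, 64)`.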

View File

@@ -0,0 +1,571 @@
"""Audio VAE encoder and decoder for LTX-2."""
from pathlib import Path
from typing import Dict
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock
from .upsample import build_upsampling_path
LATENT_DOWNSAMPLE_FACTOR = 4


def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> dict:
    """Build the middle block with two ResNet blocks and optional attention."""
    mid = {}
    mid["block_1"] = ResnetBlock(
        in_channels=channels,
        out_channels=channels,
        temb_channels=temb_channels,
        dropout=dropout,
        norm_type=norm_type,
        causality_axis=causality_axis,
    )
    mid["attn_1"] = (
        make_attn(channels, attn_type=attn_type, norm_type=norm_type)
        if add_attention
        else None
    )
    mid["block_2"] = ResnetBlock(
        in_channels=channels,
        out_channels=channels,
        temb_channels=temb_channels,
        dropout=dropout,
        norm_type=norm_type,
        causality_axis=causality_axis,
    )
    return mid


def run_mid_block(mid: dict, features: mx.array) -> mx.array:
    """Run features through the middle block."""
    features = mid["block_1"](features, temb=None)
    if mid["attn_1"] is not None:
        features = mid["attn_1"](features)
    return mid["block_2"](features, temb=None)
class AudioEncoder(nn.Module):
"""Encoder that compresses audio spectrograms into latent representations."""
def __init__(self, config: AudioEncoderModelConfig) -> None:
super().__init__()
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
self.sample_rate = config.sample_rate
self.mel_hop_length = config.mel_hop_length
self.is_causal = config.is_causal
self.mel_bins = config.mel_bins
self.patchifier = AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=config.sample_rate,
hop_length=config.mel_hop_length,
is_causal=config.is_causal,
)
self.ch = config.ch
self.temb_ch = 0
self.num_resolutions = len(config.ch_mult)
self.num_res_blocks = config.num_res_blocks
self.resolution = config.resolution
self.in_channels = config.in_channels
self.z_channels = config.z_channels
self.double_z = config.double_z
self.norm_type = config.norm_type
self.causality_axis = config.causality_axis
self.attn_type = config.attn_type
self.conv_in = make_conv2d(
config.in_channels,
self.ch,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
)
self.down, block_in = build_downsampling_path(
ch=config.ch,
ch_mult=config.ch_mult,
num_resolutions=self.num_resolutions,
num_res_blocks=config.num_res_blocks,
resolution=config.resolution,
temb_channels=self.temb_ch,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
attn_resolutions=config.attn_resolutions or set(),
resamp_with_conv=config.resamp_with_conv,
)
self.mid = build_mid_block(
channels=block_in,
temb_channels=self.temb_ch,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
add_attention=config.mid_block_add_attention,
)
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
self.conv_out = make_conv2d(
block_in,
out_channels,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio encoder weights from PyTorch format."""
sanitized = {}
for key, value in weights.items():
new_key = key
if key.startswith("audio_vae.encoder."):
new_key = key.replace("audio_vae.encoder.", "")
elif key.startswith("encoder."):
new_key = key.replace("encoder.", "")
elif key.startswith("audio_vae.per_channel_statistics."):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue
elif "per_channel_statistics" in key:
if "mean-of-means" in key or "latents_mean" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key or "latents_std" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue
elif key == "latents_mean":
new_key = "per_channel_statistics.mean_of_means"
elif key == "latents_std":
new_key = "per_channel_statistics.std_of_means"
else:
continue
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = (
value
if check_array_shape(value)
else mx.transpose(value, (0, 2, 3, 1))
)
sanitized[new_key] = value
return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
        """Load audio encoder from pretrained weights."""
        import json

        from mlx_video.models.ltx_2.config import AudioEncoderModelConfig

        model_path = Path(model_path)
        # Read the config inside a context manager so the file handle is closed.
        with open(model_path / "config.json") as f:
            config = AudioEncoderModelConfig.from_dict(json.load(f))
        encoder = cls(config)
        weights = mx.load(str(model_path / "model.safetensors"))
        encoder.load_weights(list(weights.items()), strict=True)
        return encoder
def __call__(self, spectrogram: mx.array) -> mx.array:
"""Encode audio spectrogram into normalized latent representation.
Args:
spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.
Returns:
Normalized latent (B, T', F', z_channels) in MLX channels-last format.
"""
if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels:
spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))
h = self.conv_in(spectrogram)
h = self._run_downsampling_path(h)
h = run_mid_block(self.mid, h)
h = self._finalize_output(h)
return self._normalize_latents(h)
def _run_downsampling_path(self, h: mx.array) -> mx.array:
for level in range(self.num_resolutions):
stage = self.down[level]
for block_idx in range(self.num_res_blocks):
h = stage["block"][block_idx](h, temb=None)
if block_idx in stage["attn"]:
h = stage["attn"][block_idx](h)
if level != self.num_resolutions - 1 and "downsample" in stage:
h = stage["downsample"](h)
return h
def _finalize_output(self, h: mx.array) -> mx.array:
h = self.norm_out(h)
h = nn.silu(h)
return self.conv_out(h)
def _normalize_latents(self, h: mx.array) -> mx.array:
"""Normalize encoder output using per-channel statistics.
Takes first half of channels (mean) when double_z=True,
then patchifies, normalizes, and unpatchifies.
"""
# h shape: (B, T', F', 2*z_channels) in MLX format
z_channels = self.z_channels
means = h[..., :z_channels]
latent_shape = AudioLatentShape(
batch=means.shape[0],
channels=means.shape[3],
frames=means.shape[1],
mel_bins=means.shape[2],
)
patched = self.patchifier.patchify(means)
normalized = self.per_channel_statistics.normalize(patched)
return self.patchifier.unpatchify(normalized, latent_shape)
class AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.
The decoder mirrors the encoder structure with configurable channel multipliers,
attention resolutions, and causal convolutions.
"""
def __init__(
self,
config: AudioDecoderModelConfig,
) -> None:
"""
Initialize the AudioDecoder.

Args:
    config: Decoder configuration. The fields read here are:
        ch: Base number of feature channels
        out_ch: Number of output channels (2 for stereo)
        ch_mult: Multiplicative factors for channels at each resolution
        num_res_blocks: Number of residual blocks per resolution
        attn_resolutions: Resolutions at which to apply attention
        resolution: Input spatial resolution
        z_channels: Number of latent channels
        norm_type: Normalization type
        causality_axis: Axis for causal convolutions
        dropout: Dropout probability
        mid_block_add_attention: Whether to add attention in middle block
        sample_rate: Audio sample rate
        mel_hop_length: Hop length for mel spectrogram
        is_causal: Whether to use causal convolutions
        mel_bins: Number of mel frequency bins
"""
super().__init__()
# Per-channel statistics for denormalizing latents
# Uses ch (base channel count) to match the patchified latent dimension
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
# ch=128 matches this dimension, so use ch for per_channel_statistics
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
self.sample_rate = config.sample_rate
self.mel_hop_length = config.mel_hop_length
self.is_causal = config.is_causal
self.mel_bins = config.mel_bins
self.patchifier = AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=config.sample_rate,
hop_length=config.mel_hop_length,
is_causal=config.is_causal,
)
self.ch = config.ch
self.temb_ch = 0
self.num_resolutions = len(config.ch_mult)
self.num_res_blocks = config.num_res_blocks
self.resolution = config.resolution
self.out_ch = config.out_ch
self.give_pre_end = config.give_pre_end
self.tanh_out = config.tanh_out
self.norm_type = config.norm_type
self.z_channels = config.z_channels
self.channel_multipliers = config.ch_mult
self.attn_resolutions = config.attn_resolutions
self.causality_axis = config.causality_axis
self.attn_type = config.attn_type
base_block_channels = config.ch * self.channel_multipliers[-1]
base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d(
config.z_channels,
base_block_channels,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
)
self.mid = build_mid_block(
channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
add_attention=config.mid_block_add_attention,
)
self.up, final_block_channels = build_upsampling_path(
ch=config.ch,
ch_mult=config.ch_mult,
num_resolutions=self.num_resolutions,
num_res_blocks=config.num_res_blocks,
resolution=config.resolution,
temb_channels=self.temb_ch,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
attn_resolutions=config.attn_resolutions,
resamp_with_conv=config.resamp_with_conv,
initial_block_channels=base_block_channels,
)
self.norm_out = build_normalization_layer(
final_block_channels, normtype=self.norm_type
)
self.conv_out = make_conv2d(
final_block_channels,
config.out_ch,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for audio VAE decoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle audio_vae.decoder weights
if key.startswith("audio_vae.decoder."):
new_key = key.replace("audio_vae.decoder.", "")
elif key.startswith("audio_vae.per_channel_statistics."):
# Map per-channel statistics
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue # Skip other statistics keys
else:
continue # Skip non-decoder keys
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = (
value
if check_array_shape(value)
else mx.transpose(value, (0, 2, 3, 1))
)
sanitized[new_key] = value
return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
        """Load audio VAE decoder from pretrained model."""
        import json

        from mlx_video.models.ltx_2.config import AudioDecoderModelConfig

        # Accept str or Path, matching AudioEncoder.from_pretrained.
        model_path = Path(model_path)
        # Read the config inside a context manager so the file handle is closed.
        with open(model_path / "config.json") as f:
            config = AudioDecoderModelConfig.from_dict(json.load(f))
        decoder = cls(config)
        weights = mx.load(str(model_path / "model.safetensors"))
        # weights = decoder.sanitize(weights)
        decoder.load_weights(list(weights.items()), strict=True)
        return decoder
def __call__(self, sample: mx.array) -> mx.array:
"""
Decode latent features back to audio spectrograms.
Args:
sample: Encoded latent representation of shape (B, H, W, C) in MLX format
or (B, C, H, W) in PyTorch format (will be transposed)
Returns:
Reconstructed audio spectrogram
"""
        # Handle input format - if channels are in dim 1, transpose to channels-last.
        # Check ndim first so shape[1] is only read on 4D inputs.
        if sample.ndim == 4 and sample.shape[1] == self.z_channels:
            # PyTorch format (B, C, H, W) -> MLX format (B, H, W, C)
            sample = mx.transpose(sample, (0, 2, 3, 1))
sample, target_shape = self._denormalize_latents(sample)
h = self.conv_in(sample)
h = run_mid_block(self.mid, h)
h = self._run_upsampling_path(h)
h = self._finalize_output(h)
return self._adjust_output_shape(h, target_shape)
def _denormalize_latents(
self, sample: mx.array
) -> tuple[mx.array, AudioLatentShape]:
"""Denormalize latents using per-channel statistics."""
# sample shape: (B, H, W, C) in MLX format
latent_shape = AudioLatentShape(
batch=sample.shape[0],
channels=sample.shape[3], # channels last
frames=sample.shape[1], # height = frames
mel_bins=sample.shape[2], # width = mel_bins
)
sample_patched = self.patchifier.patchify(sample)
sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis != CausalityAxis.NONE:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_shape = AudioLatentShape(
batch=latent_shape.batch,
channels=self.out_ch,
frames=target_frames,
mel_bins=(
self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
),
)
return sample, target_shape
def _adjust_output_shape(
self,
decoded_output: mx.array,
target_shape: AudioLatentShape,
) -> mx.array:
"""
Adjust output shape to match target dimensions for variable-length audio.
Args:
decoded_output: Tensor of shape (B, H, W, C) in MLX format
target_shape: AudioLatentShape describing target dimensions
Returns:
Tensor adjusted to match target_shape exactly
"""
# Current output shape: (batch, frames, mel_bins, channels) in MLX format
_, current_time, current_freq, _ = decoded_output.shape
target_channels = target_shape.channels
target_time = target_shape.frames
target_freq = target_shape.mel_bins
# Step 1: Crop first to avoid exceeding target dimensions
decoded_output = decoded_output[
:,
: min(current_time, target_time),
: min(current_freq, target_freq),
:target_channels,
]
# Step 2: Calculate padding needed for time and frequency dimensions
time_padding_needed = target_time - decoded_output.shape[1]
freq_padding_needed = target_freq - decoded_output.shape[2]
# Step 3: Apply padding if needed
if time_padding_needed > 0 or freq_padding_needed > 0:
# MLX pad: [(before_0, after_0), ...]
# For (B, H, W, C): H=time, W=freq
padding = [
(0, 0), # batch
(0, max(time_padding_needed, 0)), # time
(0, max(freq_padding_needed, 0)), # freq
(0, 0), # channels
]
decoded_output = mx.pad(decoded_output, padding)
# Step 4: Final safety crop to ensure exact target shape
decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels]
# Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility
decoded_output = mx.transpose(decoded_output, (0, 3, 1, 2))
return decoded_output
def _run_upsampling_path(self, h: mx.array) -> mx.array:
"""Run through upsampling path."""
for level in reversed(range(self.num_resolutions)):
stage = self.up[level]
for block_idx in range(len(stage["block"])):
h = stage["block"][block_idx](h, temb=None)
if block_idx in stage["attn"]:
h = stage["attn"][block_idx](h)
if level != 0 and "upsample" in stage:
h = stage["upsample"](h)
return h
def _finalize_output(self, h: mx.array) -> mx.array:
"""Apply final normalization and convolution."""
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return mx.tanh(h) if self.tanh_out else h


def decode_audio(
    latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
) -> mx.array:
    """
    Decode an audio latent representation using the provided audio decoder and vocoder.

    Args:
        latent: Input audio latent tensor
        audio_decoder: Model to decode the latent to spectrogram
        vocoder: Model to convert spectrogram to audio waveform

    Returns:
        Decoded audio as a float tensor
    """
    decoded_audio = audio_decoder(latent)
    decoded_audio = vocoder(decoded_audio)

    # Remove batch dimension if present
    if decoded_audio.shape[0] == 1:
        decoded_audio = decoded_audio[0]

    return decoded_audio.astype(mx.float32)

View File

@@ -5,7 +5,7 @@ from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
@@ -53,8 +53,16 @@ class CausalConv2d(nn.Module):
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
if self.causality_axis == CausalityAxis.NONE:
# Non-causal: symmetric padding
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
self.padding = (
pad_h // 2,
pad_h - pad_h // 2,
pad_w // 2,
pad_w - pad_w // 2,
)
elif self.causality_axis in (
CausalityAxis.WIDTH,
CausalityAxis.WIDTH_COMPATIBILITY,
):
# Causal on width: pad left (before width axis)
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
elif self.causality_axis == CausalityAxis.HEIGHT:
@@ -90,7 +98,10 @@ class CausalConv2d(nn.Module):
if any(p > 0 for p in self.padding):
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
x = mx.pad(
x,
[(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)],
)
return self.conv(x)
@@ -124,7 +135,14 @@ def make_conv2d(
if causality_axis is not None:
# For causal convolution, padding is handled internally by CausalConv2d
return CausalConv2d(
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
in_channels,
out_channels,
kernel_size,
stride,
dilation,
groups,
bias,
causality_axis,
)
else:
# For non-causal convolution, use symmetric padding if not specified
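Review note: the padding rule touched by this hunk can be sketched in plain Python. This is an illustration only, not the module's code. `padding_for` is a hypothetical helper, and it assumes the total padding per axis is `dilation * (kernel - 1)`, which the hunk does not show. Only the two branches visible above (non-causal and causal on width) are modelled.

```python
def padding_for(kernel_h, kernel_w, causal_width, dilation=1):
    """Return (top, bottom, left, right) padding for an (N, H, W, C) input.

    Assumption (not shown in the diff): total padding per axis is
    dilation * (kernel - 1).
    """
    pad_h = dilation * (kernel_h - 1)
    pad_w = dilation * (kernel_w - 1)
    if causal_width:
        # Causal on width: all of the width padding goes on the left.
        return (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
    # Non-causal: split the padding as evenly as possible on both sides.
    return (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)


symmetric = padding_for(3, 3, causal_width=False)
causal = padding_for(3, 3, causal_width=True)
```

Both variants add the same total padding, so at stride 1 the output keeps the input's spatial size. Only the placement of the width padding differs.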

View File

@@ -5,8 +5,8 @@ from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .causality_axis import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in MLX conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def __call__(self, x: mx.array) -> mx.array:
"""
@@ -116,10 +118,14 @@ def build_downsampling_path(
)
block_in = block_out
if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
stage["attn"][i_block] = make_attn(
block_in, attn_type=attn_type, norm_type=norm_type
)
if i_level != num_resolutions - 1:
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
stage["downsample"] = Downsample(
block_in, resamp_with_conv, causality_axis=causality_axis
)
curr_res = curr_res // 2
down_modules[i_level] = stage

View File

@@ -51,7 +51,9 @@ def build_normalization_layer(
A normalization layer
"""
if normtype == NormType.GROUP:
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
return nn.GroupNorm(
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
)
if normtype == NormType.PIXEL:
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
# PyTorch uses dim=1 for channels-first format (B, C, H, W)

View File

@@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module):
self.latent_channels = latent_channels
# Initialize buffers - will be loaded from weights
# Using underscores for MLX compatibility with weight loading
self._std_of_means = mx.ones((latent_channels,))
self._mean_of_means = mx.zeros((latent_channels,))
self.std_of_means = mx.ones((latent_channels,))
self.mean_of_means = mx.zeros((latent_channels,))
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latent representation."""
# Broadcast statistics to match x shape
# x shape: (B, C, ...) or (B, ..., C)
std = self._std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype)
std = self.std_of_means.astype(x.dtype)
mean = self.mean_of_means.astype(x.dtype)
return (x * std) + mean
def normalize(self, x: mx.array) -> mx.array:
"""Normalize latent representation."""
std = self._std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype)
std = self.std_of_means.astype(x.dtype)
mean = self.mean_of_means.astype(x.dtype)
return (x - mean) / std
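Review note: `normalize` and `un_normalize` are meant to be inverses. A minimal sketch with plain Python lists standing in for `mx.array` (the numbers are made up for illustration):

```python
def normalize(x, mean, std):
    # Per-channel (x - mean) / std, as in PerChannelStatistics.normalize.
    return [(v - m) / s for v, m, s in zip(x, mean, std)]


def un_normalize(x, mean, std):
    # Per-channel x * std + mean, as in PerChannelStatistics.un_normalize.
    return [v * s + m for v, m, s in zip(x, mean, std)]


# Illustrative values only; real statistics come from the checkpoint.
mean = [0.5, -1.0, 2.0]
std = [2.0, 0.5, 4.0]
latent = [1.5, -0.75, 10.0]

normalized = normalize(latent, mean, std)
restored = un_normalize(normalized, mean, std)
```

The rename from `_std_of_means` to `std_of_means` presumably matters for weight loading: MLX appears to leave underscore-prefixed attributes out of a module's parameter tree, so the old names would not have been filled from the checkpoint. That reading is an inference from MLX's behavior, not something this diff states.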

View File

@@ -1,12 +1,12 @@
"""ResNet blocks for audio VAE and vocoder."""
from typing import List, Tuple
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from ..config import CausalityAxis
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from .normalization import NormType, build_normalization_layer
LRELU_SLOPE = 0.1
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
self.conv1 = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
if temb_channels > 0:
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
self.dropout_rate = dropout
self.conv2 = make_conv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
out_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
else:
self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=1,
stride=1,
causality_axis=causality_axis,
)
def __call__(
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
if temb is not None and self.temb_channels > 0:
# temb: (B, temb_channels) -> (B, out_channels)
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
h = h + mx.expand_dims(
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
)
h = self.norm2(h)
h = nn.silu(h)

View File

@@ -5,9 +5,9 @@ from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
self.causality_axis = causality_axis
if self.with_conv:
self.conv = make_conv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
in_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
def __call__(self, x: mx.array) -> mx.array:
@@ -124,10 +128,14 @@ def build_upsampling_path(
)
block_in = block_out
if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
stage["attn"][i_block] = make_attn(
block_in, attn_type=attn_type, norm_type=norm_type
)
if level != 0:
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
stage["upsample"] = Upsample(
block_in, resamp_with_conv, causality_axis=causality_axis
)
curr_res *= 2
up_modules[level] = stage

View File

@@ -0,0 +1,737 @@
"""Vocoder for converting mel spectrograms to audio waveforms.
Supports:
- HiFi-GAN (LTX-2): ResBlock1 with LeakyReLU
- BigVGAN v2 (LTX-2.3): AMPBlock1 with Snake/SnakeBeta + anti-aliased resampling
- VocoderWithBWE (LTX-2.3): Base vocoder + bandwidth extension (16kHz -> 48kHz)
"""
import math
from pathlib import Path
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from ..config import VocoderModelConfig
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
def get_padding(kernel_size: int, dilation: int = 1) -> int:
return int((kernel_size * dilation - dilation) / 2)
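Review note: `get_padding` returns the "same" padding for a dilated convolution with an odd kernel. The sketch below checks this with the standard 1-D convolution length formula. `conv1d_out_len` is a hypothetical helper written for this note.

```python
def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def conv1d_out_len(length, kernel_size, dilation, padding, stride=1):
    # Standard output-length formula for a 1-D convolution.
    span = dilation * (kernel_size - 1) + 1
    return (length + 2 * padding - span) // stride + 1


# Output lengths for a 100-sample input over a few odd kernels and dilations.
lengths = {
    (k, d): conv1d_out_len(100, k, d, get_padding(k, d))
    for k in (3, 7, 11)
    for d in (1, 3, 5)
}
```

Every entry in `lengths` equals the input length, which is what lets the residual blocks add `x + xt` without cropping.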
# ---------------------------------------------------------------------------
# Snake / SnakeBeta activations (BigVGAN v2)
# ---------------------------------------------------------------------------
class Snake(nn.Module):
"""Snake activation: x + (1/alpha) * sin^2(alpha * x)."""
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, L, C) in MLX format
alpha = self.alpha # (C,)
if self.alpha_logscale:
alpha = mx.exp(alpha)
return x + (1.0 / (alpha + 1e-9)) * mx.power(mx.sin(x * alpha), 2)
class SnakeBeta(nn.Module):
"""SnakeBeta activation: x + (1/beta) * sin^2(alpha * x)."""
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
self.beta = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
def __call__(self, x: mx.array) -> mx.array:
alpha = self.alpha
beta = self.beta
if self.alpha_logscale:
alpha = mx.exp(alpha)
beta = mx.exp(beta)
return x + (1.0 / (beta + 1e-9)) * mx.power(mx.sin(x * alpha), 2)
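Review note: a scalar sketch of the two activations, assuming log-scale parameters as in the default `alpha_logscale=True`. With zero-initialised parameters both `exp(alpha)` and `exp(beta)` equal 1, so the two activations coincide at initialisation.

```python
import math


def snake(x, log_alpha):
    # x + (1/alpha) * sin^2(alpha * x), with alpha stored in log scale.
    alpha = math.exp(log_alpha)
    return x + (1.0 / (alpha + 1e-9)) * math.sin(x * alpha) ** 2


def snake_beta(x, log_alpha, log_beta):
    # Same shape, but the magnitude term uses a separate beta.
    alpha = math.exp(log_alpha)
    beta = math.exp(log_beta)
    return x + (1.0 / (beta + 1e-9)) * math.sin(x * alpha) ** 2


x = 0.5
at_init = snake(x, 0.0)
same_as_snake = snake_beta(x, 0.0, 0.0)
```

Because the added term is a squared sine divided by a positive number, the output is never below the input.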
# ---------------------------------------------------------------------------
# Anti-aliased resampling (Kaiser-sinc filters)
# ---------------------------------------------------------------------------
def _sinc(x: mx.array) -> mx.array:
return mx.where(
x == 0,
mx.ones_like(x),
mx.sin(mx.array(math.pi) * x) / (mx.array(math.pi) * x),
)
def kaiser_sinc_filter1d(
cutoff: float, half_width: float, kernel_size: int
) -> mx.array:
"""Compute a Kaiser-windowed sinc filter."""
even = kernel_size % 2 == 0
half_size = kernel_size // 2
delta_f = 4 * half_width
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if amplitude > 50.0:
beta = 0.1102 * (amplitude - 8.7)
elif amplitude >= 21.0:
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
else:
beta = 0.0
# Kaiser window via NumPy (np.kaiser returns a symmetric window)
import numpy as np
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
if even:
time = mx.arange(-half_size, half_size).astype(mx.float32) + 0.5
else:
time = mx.arange(kernel_size).astype(mx.float32) - half_size
if cutoff == 0:
filter_ = mx.zeros_like(time)
else:
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
filter_ = filter_ / mx.sum(filter_)
return filter_.reshape(1, 1, kernel_size)
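Review note: a dependency-free sketch of the same filter design, useful for checking the beta selection and the normalisation by hand. It is not the module's implementation. The Bessel series `bessel_i0` and the list-based window are stand-ins for `np.kaiser`, and the `cutoff == 0` branch is omitted.

```python
import math


def kaiser_beta(kernel_size, half_width):
    # Same attenuation-to-beta rule as kaiser_sinc_filter1d.
    half_size = kernel_size // 2
    delta_f = 4 * half_width
    amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if amplitude > 50.0:
        return 0.1102 * (amplitude - 8.7)
    if amplitude >= 21.0:
        return 0.5842 * (amplitude - 21.0) ** 0.4 + 0.07886 * (amplitude - 21.0)
    return 0.0


def bessel_i0(x, terms=30):
    # Power series for the modified Bessel function of order zero.
    total, term = 1.0, 1.0
    for k in range(1, terms):
        term *= (x / (2.0 * k)) ** 2
        total += term
    return total


def kaiser_window(n, beta):
    # Symmetric Kaiser window over n points.
    window = []
    for i in range(n):
        r = 2.0 * i / (n - 1) - 1.0
        window.append(bessel_i0(beta * math.sqrt(max(0.0, 1.0 - r * r))) / bessel_i0(beta))
    return window


def sinc(t):
    return 1.0 if t == 0 else math.sin(math.pi * t) / (math.pi * t)


def kaiser_sinc_filter(cutoff, half_width, kernel_size):
    window = kaiser_window(kernel_size, kaiser_beta(kernel_size, half_width))
    half_size = kernel_size // 2
    # Even kernels are centred between two samples, hence the 0.5 offset.
    offset = 0.5 if kernel_size % 2 == 0 else 0.0
    time = [i - half_size + offset for i in range(kernel_size)]
    taps = [2 * cutoff * w * sinc(2 * cutoff * t) for w, t in zip(window, time)]
    total = sum(taps)
    return [tap / total for tap in taps]


# Parameters DownSample1d would pass for ratio=2: cutoff 0.5/2, half_width 0.6/2.
beta = kaiser_beta(12, 0.3)
taps = kaiser_sinc_filter(cutoff=0.25, half_width=0.3, kernel_size=12)
```

Dividing by the sum gives the filter unit DC gain, so a constant signal passes through unchanged.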
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
import numpy as np
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
kernel_size = 2 * width * ratio + 1
pad = width
pad_left = 2 * width * ratio
pad_right = kernel_size - ratio
time = (np.arange(kernel_size) / ratio - width) * rolloff
time_clamped = np.clip(time, -lowpass_filter_width, lowpass_filter_width)
window = np.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
sinc_filter = np.sinc(time) * window * rolloff / ratio
filter_ = mx.array(sinc_filter.astype(np.float32)).reshape(1, 1, kernel_size)
return filter_, pad, pad_left, pad_right
class LowPassFilter1d(nn.Module):
"""Low-pass filter using depthwise convolution with Kaiser-sinc kernel."""
def __init__(
self,
cutoff: float = 0.5,
half_width: float = 0.6,
stride: int = 1,
kernel_size: int = 12,
) -> None:
super().__init__()
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
# Filter buffer - shape (1, 1, K) in PyTorch format, loaded from weights
self.filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, L, C) in MLX format
n, l, c = x.shape
# Pad with edge values: replicate first/last value
first = mx.repeat(x[:, :1, :], self.pad_left, axis=1)
last = mx.repeat(x[:, -1:, :], self.pad_right, axis=1)
x = mx.concatenate([first, x, last], axis=1)
# Apply one shared filter to every channel by folding channels into the
# batch axis; this matches a depthwise conv whose kernels are all identical.
# Filter is stored in PyTorch format (1, 1, K); MLX conv1d expects (1, K, 1).
filt = self.filter.astype(x.dtype) # (1, 1, K)
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
# Fold channels into batch: (N, L, C) -> (N*C, L, 1)
x = mx.transpose(x, (0, 2, 1)) # (N, C, L)
x = x.reshape(n * c, -1, 1) # (N*C, L, 1)
x = mx.conv1d(x, filt, stride=self.stride, groups=1) # (N*C, L', 1)
x = x.reshape(n, c, -1) # (N, C, L')
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
return x
class UpSample1d(nn.Module):
"""Anti-aliased upsampling using transposed convolution with sinc filter."""
def __init__(
self,
ratio: int = 2,
kernel_size: int = None,
window_type: str = "kaiser",
) -> None:
super().__init__()
self.ratio = ratio
self.stride = ratio
if window_type == "hann":
filt, self.pad, self.pad_left, self.pad_right = hann_sinc_filter1d(ratio)
self.kernel_size = filt.shape[2]
self.filter = filt
else:
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.pad = self.kernel_size // ratio - 1
self.pad_left = (
self.pad * self.stride + (self.kernel_size - self.stride) // 2
)
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
self.filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size,
)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, L, C) in MLX format
n, l, c = x.shape
# Pad with edge values
first = mx.repeat(x[:, :1, :], self.pad, axis=1)
last = mx.repeat(x[:, -1:, :], self.pad, axis=1)
x = mx.concatenate([first, x, last], axis=1)
# Process per-channel via reshape: (N, L, C) -> (N*C, L, 1)
x = mx.transpose(x, (0, 2, 1)) # (N, C, L)
x = x.reshape(n * c, -1, 1) # (N*C, L, 1)
# Transposed conv for upsampling
# Filter: (1, 1, K) PyTorch -> (1, K, 1) MLX format for conv_transpose1d
filt = self.filter.astype(x.dtype) # (1, 1, K)
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
x = self.ratio * mx.conv_transpose1d(
x, filt, stride=self.stride
) # (N*C, L', 1)
# Trim padding
x = x[:, self.pad_left : -self.pad_right, :]
x = x.reshape(n, c, -1) # (N, C, L')
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
return x
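Review note: the pad and trim constants in `UpSample1d` are chosen so the output is exactly `ratio` times the input length. The sketch below reproduces that arithmetic for both window types. It assumes an unpadded transposed convolution produces `(L_in - 1) * stride + kernel_size` samples. The helper names are hypothetical.

```python
import math


def kaiser_upsample_geometry(ratio, kernel_size=None):
    # Mirrors the non-Hann branch of UpSample1d.__init__.
    if kernel_size is None:
        kernel_size = int(6 * ratio // 2) * 2
    pad = kernel_size // ratio - 1
    pad_left = pad * ratio + (kernel_size - ratio) // 2
    pad_right = pad * ratio + (kernel_size - ratio + 1) // 2
    return kernel_size, pad, pad_left, pad_right


def hann_upsample_geometry(ratio, rolloff=0.99, lowpass_filter_width=6):
    # Mirrors the constants computed in hann_sinc_filter1d.
    width = math.ceil(lowpass_filter_width / rolloff)
    kernel_size = 2 * width * ratio + 1
    return kernel_size, width, 2 * width * ratio, kernel_size - ratio


def upsampled_length(length, ratio, geometry):
    kernel_size, pad, pad_left, pad_right = geometry
    padded = length + 2 * pad  # edge replication on both sides
    full = (padded - 1) * ratio + kernel_size  # unpadded transposed conv
    return full - pad_left - pad_right  # trim, as in x[:, pad_left:-pad_right]


kaiser_len = upsampled_length(100, 2, kaiser_upsample_geometry(2, 12))
hann_len = upsampled_length(100, 3, hann_upsample_geometry(3))
```

With these inputs the Kaiser path doubles the length and the Hann path, used by the 16 kHz to 48 kHz resampler, triples it.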
class DownSample1d(nn.Module):
"""Anti-aliased downsampling using low-pass filter."""
def __init__(self, ratio: int = 2, kernel_size: int = None) -> None:
super().__init__()
self.ratio = ratio
kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=kernel_size,
)
def __call__(self, x: mx.array) -> mx.array:
return self.lowpass(x)
class Activation1d(nn.Module):
"""Anti-aliased activation: upsample -> activate -> downsample."""
def __init__(
self,
activation: nn.Module,
up_ratio: int = 2,
down_ratio: int = 2,
up_kernel_size: int = 12,
down_kernel_size: int = 12,
) -> None:
super().__init__()
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
def __call__(self, x: mx.array) -> mx.array:
x = self.upsample(x)
x = self.act(x)
return self.downsample(x)
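Review note: `Activation1d` should return a sequence of the same length it received, since it upsamples by 2 and then low-pass filters with stride 2. A short check of the downsampling side, taking as given that the upsampler returns exactly `length * up_ratio` samples (the helper names are hypothetical):

```python
def lowpass_output_length(length, stride, kernel_size=12):
    # Mirrors LowPassFilter1d: asymmetric edge padding, then a strided conv.
    even = kernel_size % 2 == 0
    pad_left = kernel_size // 2 - int(even)
    pad_right = kernel_size // 2
    padded = length + pad_left + pad_right
    return (padded - kernel_size) // stride + 1


def activation1d_length(length, up_ratio=2, down_ratio=2):
    # Assumption: the upsampler yields exactly length * up_ratio samples.
    upsampled = length * up_ratio
    return lowpass_output_length(upsampled, stride=down_ratio)


round_trip = [activation1d_length(n) for n in range(1, 51)]
```

Running the nonlinearity at twice the sample rate is the anti-aliasing idea named in the class docstring: harmonics created by the activation are filtered out before the signal returns to its original rate.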
# ---------------------------------------------------------------------------
# AMPBlock1 (BigVGAN v2 residual block)
# ---------------------------------------------------------------------------
class AMPBlock1(nn.Module):
"""BigVGAN v2 residual block with anti-aliased Snake activations."""
def __init__(
self,
channels: int,
kernel_size: int = 3,
dilation: Tuple[int, int, int] = (1, 3, 5),
activation: str = "snakebeta",
) -> None:
super().__init__()
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.convs1 = {
i: nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=d,
padding=get_padding(kernel_size, d),
)
for i, d in enumerate(dilation)
}
self.convs2 = {
i: nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
for i in range(len(dilation))
}
self.acts1 = {i: Activation1d(act_cls(channels)) for i in range(len(dilation))}
self.acts2 = {i: Activation1d(act_cls(channels)) for i in range(len(dilation))}
def __call__(self, x: mx.array) -> mx.array:
for i in range(len(self.convs1)):
xt = self.acts1[i](x)
xt = self.convs1[i](xt)
xt = self.acts2[i](xt)
xt = self.convs2[i](xt)
x = x + xt
return x
# ---------------------------------------------------------------------------
# STFT and MelSTFT (for BWE)
# ---------------------------------------------------------------------------
class STFTFn(nn.Module):
"""STFT via conv1d with precomputed DFT x window bases (loaded from checkpoint)."""
def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
super().__init__()
self.hop_length = hop_length
self.win_length = win_length
n_freqs = filter_length // 2 + 1
# Buffers loaded from checkpoint - PyTorch format (n_freqs*2, 1, filter_length)
self.forward_basis = mx.zeros((n_freqs * 2, 1, filter_length))
self.inverse_basis = mx.zeros((n_freqs * 2, 1, filter_length))
def __call__(self, y: mx.array) -> Tuple[mx.array, mx.array]:
"""Compute magnitude and phase from waveform.
Args:
y: (B, T) waveform
Returns:
magnitude: (B, n_freqs, T_frames)
phase: (B, n_freqs, T_frames)
"""
if y.ndim == 2:
y = mx.expand_dims(y, -1) # (B, T, 1)
left_pad = max(0, self.win_length - self.hop_length)
if left_pad > 0:
first = mx.repeat(y[:, :1, :], left_pad, axis=1)
y = mx.concatenate([first, y], axis=1)
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
basis = mx.transpose(
self.forward_basis.astype(y.dtype), (0, 2, 1)
) # (514, K, 1)
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
spec = mx.conv1d(y, basis, stride=self.hop_length)
# Split real and imaginary
n_freqs = spec.shape[-1] // 2
real = spec[..., :n_freqs]
imag = spec[..., n_freqs:]
magnitude = mx.sqrt(real**2 + imag**2)
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(
real.dtype
)
# Output: (B, T_frames, n_freqs) in MLX channels-last
return magnitude, phase
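Review note: the left-only padding in `STFTFn.__call__` is what makes the transform causal, and it also makes the frame count come out to `T / hop_length` when `T` is a multiple of the hop. A sketch using the loader's fallback sizes (512-sample filter, hop 80); the helper names are hypothetical:

```python
import math


def stft_frames(num_samples, filter_length, hop_length, win_length):
    # Replicate-pad on the left only, then slide a filter_length window.
    left_pad = max(0, win_length - hop_length)
    padded = num_samples + left_pad
    return (padded - filter_length) // hop_length + 1


def magnitude_and_phase(real, imag):
    # Per-bin polar form, as computed from the two halves of the conv output.
    return math.hypot(real, imag), math.atan2(imag, real)


frames = stft_frames(16000, filter_length=512, hop_length=80, win_length=512)
mag, phase = magnitude_and_phase(3.0, 4.0)
```

With these sizes, frame `k` covers original samples up to index `(k + 1) * 80 - 1` and nothing later, so no frame looks ahead of its own hop.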
class MelSTFT(nn.Module):
"""Causal log-mel spectrogram from precomputed STFT bases."""
def __init__(
self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int
) -> None:
super().__init__()
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
n_freqs = filter_length // 2 + 1
self.mel_basis = mx.zeros((n_mel_channels, n_freqs))
def mel_spectrogram(self, y: mx.array) -> mx.array:
"""Compute log-mel spectrogram.
Args:
y: (B, T) waveform
Returns:
log_mel: (B, n_mels, T_frames) in channels-first for compatibility
"""
magnitude, phase = self.stft_fn(y)
# magnitude: (B, T_frames, n_freqs)
mel = (
magnitude @ self.mel_basis.astype(magnitude.dtype).T
) # (B, T_frames, n_mels)
log_mel = mx.log(mx.clip(mel, 1e-5, None))
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
return mx.transpose(log_mel, (0, 2, 1))
# ---------------------------------------------------------------------------
# Vocoder (supports both HiFi-GAN and BigVGAN v2)
# ---------------------------------------------------------------------------
class Vocoder(nn.Module):
"""Vocoder for mel-to-waveform synthesis.
Supports resblock="1" (HiFi-GAN / LTX-2) and resblock="AMP1" (BigVGAN v2 / LTX-2.3).
"""
def __init__(self, config: VocoderModelConfig) -> None:
super().__init__()
self.output_sampling_rate = config.output_sample_rate
self.num_kernels = len(config.resblock_kernel_sizes)
self.num_upsamples = len(config.upsample_rates)
self.upsample_rates = config.upsample_rates
self.is_amp = config.resblock == "AMP1"
self.use_tanh_at_final = config.use_tanh_at_final
self.apply_final_activation = config.apply_final_activation
in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d(
in_channels,
config.upsample_initial_channel,
kernel_size=7,
stride=1,
padding=3,
)
# Upsampling layers
self.ups = {}
for i, (stride, kernel_size) in enumerate(
zip(config.upsample_rates, config.upsample_kernel_sizes)
):
in_ch = config.upsample_initial_channel // (2**i)
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d(
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
# Residual blocks
if self.is_amp:
self.resblocks = {}
block_idx = 0
for i in range(len(self.ups)):
ch = config.upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(
config.resblock_kernel_sizes, config.resblock_dilation_sizes
):
self.resblocks[block_idx] = AMPBlock1(
ch,
kernel_size,
tuple(dilations),
activation=config.activation,
)
block_idx += 1
else:
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
self.resblocks = {}
block_idx = 0
for i in range(len(self.ups)):
ch = config.upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(
config.resblock_kernel_sizes, config.resblock_dilation_sizes
):
self.resblocks[block_idx] = resblock_class(
ch, kernel_size, tuple(dilations)
)
block_idx += 1
final_channels = config.upsample_initial_channel // (
2 ** len(config.upsample_rates)
)
# Post-activation
if self.is_amp:
act_cls = SnakeBeta if config.activation == "snakebeta" else Snake
self.act_post = Activation1d(act_cls(final_channels))
# Final conv
out_channels = 2 if config.stereo else 1
self.conv_post = nn.Conv1d(
final_channels,
out_channels,
kernel_size=7,
stride=1,
padding=3,
bias=config.use_bias_at_final,
)
self.upsample_factor = math.prod(config.upsample_rates)
def __call__(self, x: mx.array) -> mx.array:
"""Forward pass.
Args:
x: Mel spectrogram (B, C, T, mel_bins) for stereo or (B, T, mel_bins) mono.
Returns:
Waveform (B, out_channels, T_audio) in channels-first format.
"""
# (..., T, mel) -> (..., mel, T); works for 4-D stereo and 3-D mono input
x = mx.swapaxes(x, -1, -2)
if x.ndim == 4: # stereo: (B, 2, mel, T) -> (B, 2*mel, T)
b, s, c, t = x.shape
x = x.reshape(b, s * c, t)
# Channels-first (B, C, T) -> channels-last (B, T, C) for MLX conv
x = mx.transpose(x, (0, 2, 1))
x = self.conv_pre(x)
for i in range(self.num_upsamples):
if not self.is_amp:
x = leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
start = i * self.num_kernels
end = start + self.num_kernels
block_outputs = mx.stack(
[self.resblocks[idx](x) for idx in range(start, end)],
axis=0,
)
x = mx.mean(block_outputs, axis=0)
if self.is_amp:
x = self.act_post(x)
else:
x = nn.leaky_relu(x)
x = self.conv_post(x)
if self.apply_final_activation:
x = mx.tanh(x) if self.use_tanh_at_final else mx.clip(x, -1, 1)
# Back to channels-first (B, T, C) -> (B, C, T)
x = mx.transpose(x, (0, 2, 1))
return x
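Review note: each `ConvTranspose1d` stage multiplies the sequence length by its stride and halves the channel count, so the whole stack scales time by `upsample_factor`. The sketch below traces lengths and channels. The rates, kernel sizes and channel count are illustrative values and are not taken from this repository's configs.

```python
import math


def conv_transpose1d_out_len(length, kernel_size, stride, padding):
    # Standard output-length formula for a 1-D transposed convolution.
    return (length - 1) * stride - 2 * padding + kernel_size


# Hypothetical configuration, chosen only for illustration.
upsample_rates = [8, 8, 2, 2]
upsample_kernel_sizes = [16, 16, 4, 4]
initial_channels = 512

length, channels = 50, initial_channels
for stride, kernel_size in zip(upsample_rates, upsample_kernel_sizes):
    padding = (kernel_size - stride) // 2  # same rule as in Vocoder.__init__
    length = conv_transpose1d_out_len(length, kernel_size, stride, padding)
    channels //= 2

upsample_factor = math.prod(upsample_rates)
```

The exact `length * stride` result relies on `kernel_size - stride` being even at every stage. An odd difference would leave each stage one sample longer than `length * stride`.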
# ---------------------------------------------------------------------------
# VocoderWithBWE (Bandwidth Extension)
# ---------------------------------------------------------------------------
class VocoderWithBWE(nn.Module):
"""Vocoder + bandwidth extension upsampling (16kHz -> 48kHz).
Chains a base vocoder with a BWE generator that predicts a residual
added to a sinc-resampled skip connection.
"""
def __init__(
self,
vocoder: Vocoder,
bwe_generator: Vocoder,
mel_stft: MelSTFT,
input_sampling_rate: int = 16000,
output_sampling_rate: int = 48000,
hop_length: int = 80,
) -> None:
super().__init__()
self.vocoder = vocoder
self.bwe_generator = bwe_generator
self.mel_stft = mel_stft
self.input_sampling_rate = input_sampling_rate
self.output_sampling_rate = output_sampling_rate
self.hop_length = hop_length
# Hann-windowed sinc resampler (not stored in checkpoint)
self.resampler = UpSample1d(
ratio=output_sampling_rate // input_sampling_rate,
window_type="hann",
)
@property
def output_sample_rate(self) -> int:
return self.output_sampling_rate
def _compute_mel(self, audio: mx.array) -> mx.array:
"""Compute log-mel spectrogram from waveform.
Args:
audio: (B, C, T) waveform in channels-first
Returns:
mel: (B, C, n_mels, T_frames)
"""
batch, n_channels, _ = audio.shape
flat = audio.reshape(batch * n_channels, -1) # (B*C, T)
mel = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])
def __call__(self, mel_spec: mx.array) -> mx.array:
"""Run vocoder + BWE.
Args:
mel_spec: Mel spectrogram, same format as the Vocoder.__call__ input.
Returns:
Waveform (B, out_channels, T_audio) at output_sampling_rate.
"""
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
_, _, length_low_rate = x.shape
output_length = (
length_low_rate * self.output_sampling_rate // self.input_sampling_rate
)
# Pad to hop_length multiple
remainder = length_low_rate % self.hop_length
if remainder != 0:
pad_amount = self.hop_length - remainder
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_amount)])
# Compute mel from vocoder output: (B, C, n_mels, T_frames)
mel = self._compute_mel(x)
# BWE expects (B, C, T_frames, mel_bins) -> transpose last two dims
mel_for_bwe = mx.transpose(mel, (0, 1, 3, 2)) # (B, C, T_frames, n_mels)
residual = self.bwe_generator(mel_for_bwe) # (B, C, T_high)
# Sinc upsample skip connection
# resampler expects (N, L, C): transpose from (B, C, T) -> (B, T, C)
x_for_resample = mx.transpose(x, (0, 2, 1))
skip = self.resampler(x_for_resample)
skip = mx.transpose(skip, (0, 2, 1)) # back to (B, C, T)
return mx.clip(residual + skip, -1, 1)[..., :output_length]
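Review note: the length bookkeeping in `VocoderWithBWE.__call__` can be followed with integers alone. The output length is fixed before padding, the low-rate signal is padded up to a multiple of `hop_length`, and the final slice removes the extra samples. `bwe_lengths` is a hypothetical helper using the constructor defaults.

```python
def bwe_lengths(length_low_rate, input_sr=16000, output_sr=48000, hop_length=80):
    # Target length at the high rate, computed from the unpadded signal.
    output_length = length_low_rate * output_sr // input_sr
    # Right-pad so the mel STFT sees a whole number of hops.
    remainder = length_low_rate % hop_length
    pad_amount = 0 if remainder == 0 else hop_length - remainder
    padded = length_low_rate + pad_amount
    return output_length, pad_amount, padded


output_length, pad_amount, padded = bwe_lengths(16010)
high_rate_before_trim = padded * 48000 // 16000
```

In this example the padded signal yields more high-rate samples than needed, and the trailing `[..., :output_length]` slice drops the surplus.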
# ---------------------------------------------------------------------------
# Factory / from_pretrained
# ---------------------------------------------------------------------------
def load_vocoder(model_path: Path) -> nn.Module:
"""Load vocoder from pretrained model directory.
Automatically detects whether to load a simple Vocoder or VocoderWithBWE.
"""
import json
config_path = model_path / "config.json"
if not config_path.exists():
raise FileNotFoundError(f"No config.json found in {model_path}")
with open(config_path) as f:
config_dict = json.load(f)
weights = mx.load(str(model_path / "model.safetensors"))
has_bwe = config_dict.get("has_bwe_generator", False)
if has_bwe:
return _load_vocoder_with_bwe(config_dict, weights)
else:
config = VocoderModelConfig.from_dict(config_dict)
model = Vocoder(config)
model.load_weights(list(weights.items()), strict=True)
return model
def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
"""Load VocoderWithBWE from config and weights."""
# Build vocoder from config
vocoder_cfg = config_dict.get("vocoder", {})
vocoder_config = VocoderModelConfig.from_dict(vocoder_cfg)
vocoder = Vocoder(vocoder_config)
# Build BWE generator from config
bwe_cfg = config_dict.get("bwe", {})
bwe_config = VocoderModelConfig.from_dict(bwe_cfg)
bwe_config.apply_final_activation = False
bwe_generator = Vocoder(bwe_config)
# MelSTFT from weight shapes
stft_basis = weights.get("mel_stft.stft_fn.forward_basis")
filter_length = stft_basis.shape[2] if stft_basis is not None else 512
mel_basis = weights.get("mel_stft.mel_basis")
n_mel_channels = mel_basis.shape[0] if mel_basis is not None else 64
hop_length = bwe_cfg.get("hop_length", 80)
input_sr = bwe_cfg.get("input_sampling_rate", 16000)
output_sr = bwe_cfg.get("output_sampling_rate", 48000)
mel_stft = MelSTFT(
filter_length=filter_length,
hop_length=hop_length,
win_length=filter_length,
n_mel_channels=n_mel_channels,
)
model = VocoderWithBWE(
vocoder=vocoder,
bwe_generator=bwe_generator,
mel_stft=mel_stft,
input_sampling_rate=input_sr,
output_sampling_rate=output_sr,
hop_length=hop_length,
)
model.load_weights(list(weights.items()), strict=False)
return model
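Review note: the loader branches on a single config flag and reads the BWE settings with fallbacks. The sketch below mirrors only that dictionary handling. `describe_vocoder_config` is a hypothetical helper and the example dictionary is made up; neither comes from a real checkpoint.

```python
def describe_vocoder_config(config_dict):
    # Same key lookups and defaults as load_vocoder / _load_vocoder_with_bwe.
    if config_dict.get("has_bwe_generator", False):
        bwe_cfg = config_dict.get("bwe", {})
        return {
            "kind": "VocoderWithBWE",
            "hop_length": bwe_cfg.get("hop_length", 80),
            "input_sampling_rate": bwe_cfg.get("input_sampling_rate", 16000),
            "output_sampling_rate": bwe_cfg.get("output_sampling_rate", 48000),
        }
    return {"kind": "Vocoder"}


# Made-up config contents for illustration.
plain = describe_vocoder_config({"resblock": "1"})
with_bwe = describe_vocoder_config(
    {"has_bwe_generator": True, "vocoder": {}, "bwe": {"hop_length": 80}}
)
```

The two paths also differ in strictness. The plain vocoder loads with `strict=True`, while the BWE wrapper uses `strict=False`, presumably because the Hann resampler filter is built in code and is not stored in the checkpoint.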

View File

@@ -0,0 +1,6 @@
"""Conditioning modules for LTX-2 video generation."""
from mlx_video.models.ltx_2.conditioning.latent import (
VideoConditionByLatentIndex,
apply_conditioning,
)

View File

@@ -5,7 +5,7 @@ the video generation process at specific frame positions.
"""
from dataclasses import dataclass
from typing import Optional, List, Tuple
from typing import List, Optional, Tuple
import mlx.core as mx
@@ -22,6 +22,7 @@ class VideoConditionByLatentIndex:
frame_idx: Frame index to condition (0 = first frame)
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
"""
latent: mx.array
frame_idx: int = 0
strength: float = 1.0
@@ -41,6 +42,7 @@ class LatentState:
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
1.0 = full denoise, 0.0 = keep clean
"""
latent: mx.array
clean_latent: mx.array
denoise_mask: mx.array
@@ -130,15 +132,15 @@ def apply_conditioning(
if frame_idx <= i < end_idx:
# Use conditioning latent
cond_idx = i - frame_idx
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
# Set mask: 1.0 - strength means less denoising for conditioned frames
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
else:
# Keep original
latent_list.append(state.latent[:, :, i:i+1])
clean_list.append(state.clean_latent[:, :, i:i+1])
mask_list.append(state.denoise_mask[:, :, i:i+1])
latent_list.append(state.latent[:, :, i : i + 1])
clean_list.append(state.clean_latent[:, :, i : i + 1])
mask_list.append(state.denoise_mask[:, :, i : i + 1])
state.latent = mx.concatenate(latent_list, axis=2)
state.clean_latent = mx.concatenate(clean_list, axis=2)
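Review note: the per-frame mask logic in this hunk, reduced to a list of floats. `build_denoise_mask` is a hypothetical helper, and it assumes `end_idx` is `frame_idx` plus the number of conditioning frames, which the hunk does not show.

```python
def build_denoise_mask(num_frames, frame_idx, cond_frames, strength, default=1.0):
    # Assumption: conditioning covers frames [frame_idx, frame_idx + cond_frames).
    end_idx = frame_idx + cond_frames
    return [
        1.0 - strength if frame_idx <= i < end_idx else default
        for i in range(num_frames)
    ]


# Condition the first of five latent frames at full strength.
mask = build_denoise_mask(num_frames=5, frame_idx=0, cond_frames=1, strength=1.0)
```

With `strength=1.0` the conditioned frame gets a mask of 0.0, which by the `LatentState` docstring means "keep clean". The remaining frames keep their original mask value.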

View File

@@ -0,0 +1,393 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional, Tuple
class LTXModelType(Enum):
AudioVideo = "ltx av model"
VideoOnly = "ltx video only model"
AudioOnly = "ltx audio only model"
def is_video_enabled(self) -> bool:
return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)
def is_audio_enabled(self) -> bool:
return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)
class LTXRopeType(Enum):
INTERLEAVED = "interleaved"
SPLIT = "split"
TWO_D = "2d"
class AttentionType(Enum):
DEFAULT = "default"
@dataclass
class BaseModelConfig:
@classmethod
def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig":
"""Create config from dictionary, filtering only valid parameters."""
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
def to_dict(self) -> dict[str, Any]:
"""Export config to dictionary."""
result = {}
for k, v in self.__dict__.items():
if v is not None:
if isinstance(v, Enum):
result[k] = v.value
elif hasattr(v, "to_dict"):
result[k] = v.to_dict()
else:
result[k] = v
return result
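Review note: `from_dict` silently drops unknown keys by comparing against the dataclass signature, and `to_dict` unwraps enums. A self-contained sketch of the same pattern. `ExampleConfig` and `RopeType` are stand-ins written for this note, not classes from the repository.

```python
import inspect
from dataclasses import dataclass
from enum import Enum


class RopeType(Enum):
    INTERLEAVED = "interleaved"
    SPLIT = "split"


@dataclass
class ExampleConfig:
    dim: int = 64
    rope_type: RopeType = RopeType.INTERLEAVED

    @classmethod
    def from_dict(cls, params: dict) -> "ExampleConfig":
        # Keep only keys that match a constructor parameter.
        valid = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in valid})

    def __post_init__(self):
        # Accept the enum's string value when loading from JSON.
        if isinstance(self.rope_type, str):
            self.rope_type = RopeType(self.rope_type)

    def to_dict(self) -> dict:
        return {
            k: (v.value if isinstance(v, Enum) else v)
            for k, v in self.__dict__.items()
            if v is not None
        }


config = ExampleConfig.from_dict({"dim": 128, "rope_type": "split", "unknown_key": 1})
exported = config.to_dict()
```

Filtering by signature lets one `config.json` feed several config classes, at the cost that a misspelled key is ignored rather than reported.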
@dataclass
class TransformerConfig(BaseModelConfig):
dim: int
heads: int
d_head: int
context_dim: int
@dataclass
class VideoVAEConfig(BaseModelConfig):
convolution_dimensions: int = 3
in_channels: int = 3
out_channels: int = 128
latent_channels: int = 128
patch_size: int = 4
encoder_blocks: List[tuple] = field(
default_factory=lambda: [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
)
decoder_blocks: List[tuple] = field(
default_factory=lambda: [
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
]
)
@dataclass
class LTXModelConfig(BaseModelConfig):
    # Model type
    model_type: LTXModelType = LTXModelType.AudioVideo

    # Video transformer config
    num_attention_heads: int = 32
    attention_head_dim: int = 128
    in_channels: int = 128
    out_channels: int = 128
    num_layers: int = 48
    cross_attention_dim: int = 4096
    caption_channels: int = 3840

    # Audio transformer config
    audio_num_attention_heads: int = 32
    audio_attention_head_dim: int = 64
    audio_in_channels: int = 128
    audio_out_channels: int = 128
    audio_cross_attention_dim: int = 2048
    audio_caption_channels: int = (
        3840  # Input dim for audio text embeddings (same as video)
    )

    # Positional embedding config
    positional_embedding_theta: float = 10000.0
    positional_embedding_max_pos: Optional[List[int]] = None
    audio_positional_embedding_max_pos: Optional[List[int]] = None
    use_middle_indices_grid: bool = True
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED
    double_precision_rope: bool = False

    # Timestep config
    timestep_scale_multiplier: int = 1000
    av_ca_timestep_scale_multiplier: int = 1000

    # Normalization
    norm_eps: float = 1e-6

    # Attention type
    attention_type: AttentionType = AttentionType.DEFAULT

    # LTX-2.3: prompt-conditioned adaptive layer norm
    # Controls: gate_logits in attention, 9-param scale_shift_table,
    # prompt_adaln_single, per-block prompt_scale_shift_table,
    # removal of caption_projection
    has_prompt_adaln: bool = False

    # VAE config
    vae_config: Optional[VideoVAEConfig] = None

    def __post_init__(self):
        """Set default values after initialization."""
        if self.positional_embedding_max_pos is None:
            self.positional_embedding_max_pos = [20, 2048, 2048]
        if self.audio_positional_embedding_max_pos is None:
            self.audio_positional_embedding_max_pos = [20]

        # The PyTorch LTX-2 configurator reads "frequencies_precision" (not
        # "double_precision_rope") from the config. For LTX-2 (no prompt adaln)
        # the key is absent, so double_precision_rope is forced to False here.
        # For LTX-2.3 (has_prompt_adaln=True) the safetensors config has
        # frequencies_precision="float64", so double_precision_rope should be
        # True; this method does not set it and keeps the configured value.
        if not self.has_prompt_adaln:
            self.double_precision_rope = False

        # Convert string enum values if loading from dict
        if isinstance(self.model_type, str):
            self.model_type = LTXModelType(self.model_type)
        if isinstance(self.rope_type, str):
            self.rope_type = LTXRopeType(self.rope_type)
        if isinstance(self.attention_type, str):
            self.attention_type = AttentionType(self.attention_type)

    @property
    def inner_dim(self) -> int:
        """Video inner dimension."""
        return self.num_attention_heads * self.attention_head_dim

    @property
    def audio_inner_dim(self) -> int:
        """Audio inner dimension."""
        return self.audio_num_attention_heads * self.audio_attention_head_dim

    def get_video_config(self) -> Optional[TransformerConfig]:
        """Get video transformer configuration."""
        if not self.model_type.is_video_enabled():
            return None
        return TransformerConfig(
            dim=self.inner_dim,
            heads=self.num_attention_heads,
            d_head=self.attention_head_dim,
            context_dim=self.cross_attention_dim,
        )

    def get_audio_config(self) -> Optional[TransformerConfig]:
        """Get audio transformer configuration."""
        if not self.model_type.is_audio_enabled():
            return None
        return TransformerConfig(
            dim=self.audio_inner_dim,
            heads=self.audio_num_attention_heads,
            d_head=self.audio_attention_head_dim,
            context_dim=self.audio_cross_attention_dim,
        )
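The `__post_init__` pattern used by `LTXModelConfig` (fill in list defaults, coerce strings to enums, derive `inner_dim` from heads times head size) can be sketched in isolation. This is a minimal illustration, not the project's code: `RopeKind` and `TinyConfig` are hypothetical stand-ins, and the enum values are assumed for the example.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class RopeKind(Enum):  # hypothetical stand-in for LTXRopeType
    INTERLEAVED = "interleaved"
    SPLIT = "split"


@dataclass
class TinyConfig:  # hypothetical; mirrors only a few LTXModelConfig fields
    num_attention_heads: int = 32
    attention_head_dim: int = 128
    rope_type: RopeKind = RopeKind.INTERLEAVED
    max_pos: Optional[List[int]] = None

    def __post_init__(self):
        # Mutable defaults are filled in here rather than in the field declaration.
        if self.max_pos is None:
            self.max_pos = [20, 2048, 2048]
        # Values read from JSON arrive as strings; look them up by enum value.
        if isinstance(self.rope_type, str):
            self.rope_type = RopeKind(self.rope_type)

    @property
    def inner_dim(self) -> int:
        # Width of the transformer: number of heads times per-head size.
        return self.num_attention_heads * self.attention_head_dim


cfg = TinyConfig(rope_type="split")
print(cfg.rope_type, cfg.inner_dim, cfg.max_pos)
```

With the defaults shown in the class above, the video inner dimension is 32 * 128 = 4096 and the audio inner dimension is 32 * 64 = 2048.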
class CausalityAxis(Enum):
    """Enum for specifying the causality axis in causal convolutions."""

    NONE = None
    WIDTH = "width"
    HEIGHT = "height"
    WIDTH_COMPATIBILITY = "width-compatibility"


@dataclass
class AudioDecoderModelConfig(BaseModelConfig):
    ch: int = 128
    out_ch: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    # Enum-typed fields default to None, so they are annotated as Optional.
    norm_type: Optional[Enum] = None
    causality_axis: Optional[Enum] = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: Optional[int] = None
    resamp_with_conv: bool = True
    # Given as a string in config.json; converted to AttentionType in __post_init__.
    attn_type: Optional[str] = None
    give_pre_end: bool = False
    tanh_out: bool = False

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if self.attn_resolutions is not None:
            result["attn_resolutions"] = list(self.attn_resolutions)
        return result

    def __post_init__(self):
        """Convert string enum values to proper enum types."""
        # Import here to avoid circular imports
        from .audio_vae.attention import AttentionType
        from .audio_vae.normalization import NormType

        # Convert causality_axis string to enum
        if isinstance(self.causality_axis, str):
            self.causality_axis = CausalityAxis(self.causality_axis)
        # Convert norm_type string to enum
        if isinstance(self.norm_type, str):
            self.norm_type = NormType(self.norm_type)
        # Convert attn_type string to enum
        if isinstance(self.attn_type, str):
            self.attn_type = AttentionType(self.attn_type)
@dataclass
class AudioEncoderModelConfig(BaseModelConfig):
    ch: int = 128
    in_channels: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    double_z: bool = True
    n_fft: int = 1024
    # Enum-typed fields default to None, so they are annotated as Optional.
    norm_type: Optional[Enum] = None
    causality_axis: Optional[Enum] = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: int = 64
    resamp_with_conv: bool = True
    # Given as a string in config.json; converted to AttentionType in __post_init__.
    attn_type: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        result = super().to_dict()
        if self.attn_resolutions is not None:
            result["attn_resolutions"] = list(self.attn_resolutions)
        return result

    def __post_init__(self):
        """Convert string enum values to proper enum types."""
        # Import here to avoid circular imports
        from .audio_vae.attention import AttentionType
        from .audio_vae.normalization import NormType

        if isinstance(self.causality_axis, str):
            self.causality_axis = CausalityAxis(self.causality_axis)
        if isinstance(self.norm_type, str):
            self.norm_type = NormType(self.norm_type)
        if isinstance(self.attn_type, str):
            self.attn_type = AttentionType(self.attn_type)
@dataclass
class VocoderModelConfig(BaseModelConfig):
resblock_kernel_sizes: Optional[List[int]] = None
upsample_rates: Optional[List[int]] = None
upsample_kernel_sizes: Optional[List[int]] = None
resblock_dilation_sizes: Optional[List[List[int]]] = None
upsample_initial_channel: int = 1024
stereo: bool = True
resblock: str = "1"
output_sample_rate: int = 24000
activation: str = "snake"
use_tanh_at_final: bool = True
apply_final_activation: bool = True
use_bias_at_final: bool = True
def __post_init__(self):
if self.resblock_kernel_sizes is None:
self.resblock_kernel_sizes = [3, 7, 11]
if self.upsample_rates is None:
self.upsample_rates = [6, 5, 2, 2, 2]
if self.upsample_kernel_sizes is None:
self.upsample_kernel_sizes = [16, 15, 8, 4, 4]
if self.resblock_dilation_sizes is None:
self.resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
@dataclass
class VideoDecoderModelConfig(BaseModelConfig):
ch: int = 128
out_ch: int = 2
ch_mult: Tuple[int, ...] = (1, 2, 4)
num_res_blocks: int = 2
attn_resolutions: Optional[List[int]] = None
resolution: int = 256
z_channels: int = 8
norm_type: Enum = None
causality_axis: Enum = None
dropout: float = 0.0
timestep_conditioning: bool = False
@dataclass
class VideoEncoderModelConfig(BaseModelConfig):
convolution_dimensions: int = 3
in_channels: int = 3
out_channels: int = 128
patch_size: int = 4
norm_layer: Enum = None
latent_log_var: Enum = None
encoder_spatial_padding_mode: Enum = None
encoder_blocks: List[tuple] = field(
default_factory=lambda: [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
)
def __post_init__(self):
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
if self.norm_layer is None:
self.norm_layer = NormLayerType.PIXEL_NORM
if self.latent_log_var is None:
self.latent_log_var = LogVarianceType.UNIFORM
if self.encoder_spatial_padding_mode is None:
self.encoder_spatial_padding_mode = PaddingModeType.ZEROS
if isinstance(self.norm_layer, str):
self.norm_layer = NormLayerType(self.norm_layer)
if isinstance(self.latent_log_var, str):
self.latent_log_var = LogVarianceType(self.latent_log_var)
if isinstance(self.encoder_spatial_padding_mode, str):
self.encoder_spatial_padding_mode = PaddingModeType(
self.encoder_spatial_padding_mode
)
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if self.encoder_blocks is not None:
result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
return result

View File

@@ -0,0 +1,857 @@
"""Convert LTX-2/2.3 safetensors to MLX directory layout.

Converts from the single-file format (e.g. Lightricks/LTX-2/ltx-2-19b-distilled.safetensors
or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular directory structure:

    output/
    ├── transformer/              # DiT transformer weights (sharded)
    │   ├── config.json
    │   ├── model-00001-of-N.safetensors
    │   └── model.safetensors.index.json
    ├── vae/
    │   ├── decoder/              # Video VAE decoder
    │   │   ├── config.json
    │   │   └── model.safetensors
    │   └── encoder/              # Video VAE encoder
    │       ├── config.json
    │       └── model.safetensors
    ├── audio_vae/
    │   ├── decoder/              # Audio VAE decoder
    │   │   ├── config.json
    │   │   └── model.safetensors
    │   └── encoder/              # Audio VAE encoder
    │       ├── config.json
    │       └── model.safetensors
    ├── vocoder/                  # Audio vocoder
    │   ├── config.json
    │   └── model.safetensors
    └── text_projections/         # Text projection connectors
        └── model.safetensors

Usage:
    # From HF repo ID
    python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
    python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled

    # From local folder containing the monolithic safetensors
    python -m mlx_video.models.ltx_2.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled

    # From a direct safetensors file path
    python -m mlx_video.models.ltx_2.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
"""
import argparse
import json
import re
import shutil
from pathlib import Path
from typing import Dict
import mlx.core as mx
# ─── Key prefix routing ──────────────────────────────────────────────────────
TRANSFORMER_PREFIX = "model.diffusion_model."
VAE_DECODER_PREFIX = "vae.decoder."
VAE_ENCODER_PREFIX = "vae.encoder."
VAE_STATS_PREFIX = "vae.per_channel_statistics."
AUDIO_DECODER_PREFIX = "audio_vae.decoder."
AUDIO_ENCODER_PREFIX = "audio_vae.encoder."
AUDIO_STATS_PREFIX = "audio_vae.per_channel_statistics."
VOCODER_PREFIX = "vocoder."
TEXT_PROJ_PREFIX = "text_embedding_projection."
VIDEO_CONNECTOR_PREFIX = "model.diffusion_model.video_embeddings_connector."
AUDIO_CONNECTOR_PREFIX = "model.diffusion_model.audio_embeddings_connector."
# ─── Sanitization functions ──────────────────────────────────────────────────
def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Sanitize transformer keys: strip prefix, rename layers, cast to bfloat16."""
    sanitized = {}
    for key, value in weights.items():
        if not key.startswith(TRANSFORMER_PREFIX):
            continue
        # Skip connector weights (they go to text_projections)
        if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
            continue

        new_key = key[len(TRANSFORMER_PREFIX) :]
        new_key = new_key.replace(".to_out.0.", ".to_out.")
        new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
        new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
        new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
        new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
        new_key = new_key.replace(".linear_1.", ".linear1.")
        new_key = new_key.replace(".linear_2.", ".linear2.")

        # Cast all weights to bfloat16 (matches MLX model loading behavior)
        if value.dtype != mx.bfloat16:
            value = value.astype(mx.bfloat16)
        sanitized[new_key] = value
    return sanitized
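The key routing in `sanitize_transformer` can be tried on plain strings without loading any weights. The sketch below repeats the prefix and the substring replacements from the function above; the helper name `rename_transformer_key` and the example keys are made up for illustration.

```python
TRANSFORMER_PREFIX = "model.diffusion_model."

# (old substring, new substring) pairs, in the order sanitize_transformer applies them
RENAMES = [
    (".to_out.0.", ".to_out."),
    (".ff.net.0.proj.", ".ff.proj_in."),
    (".ff.net.2.", ".ff.proj_out."),
    (".audio_ff.net.0.proj.", ".audio_ff.proj_in."),
    (".audio_ff.net.2.", ".audio_ff.proj_out."),
    (".linear_1.", ".linear1."),
    (".linear_2.", ".linear2."),
]


def rename_transformer_key(key):
    """Return the MLX-side key, or None if the key is not routed to the transformer."""
    if not key.startswith(TRANSFORMER_PREFIX):
        return None
    # Connector weights share the prefix but are saved under text_projections.
    if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
        return None
    new_key = key[len(TRANSFORMER_PREFIX):]
    for old, new in RENAMES:
        new_key = new_key.replace(old, new)
    return new_key


# Hypothetical checkpoint keys, used only to exercise the renames.
for k in [
    "model.diffusion_model.transformer_blocks.0.ff.net.0.proj.weight",
    "model.diffusion_model.transformer_blocks.3.attn1.to_out.0.bias",
    "model.diffusion_model.video_embeddings_connector.proj.weight",
    "vae.decoder.conv_in.conv.weight",
]:
    print(k, "->", rename_transformer_key(k))
```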
def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Sanitize VAE decoder keys: strip prefix, transpose Conv3d, wrap .conv."""
    sanitized = {}
    for key, value in weights.items():
        new_key = None
        if key.startswith(VAE_STATS_PREFIX):
            if key == "vae.per_channel_statistics.mean-of-means":
                new_key = "per_channel_statistics.mean"
            elif key == "vae.per_channel_statistics.std-of-means":
                new_key = "per_channel_statistics.std"
            else:
                continue
        elif key.startswith(VAE_DECODER_PREFIX):
            new_key = key[len(VAE_DECODER_PREFIX) :]
        else:
            continue

        # Conv3d weight transpose: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
        if ".conv.weight" in key and value.ndim == 5:
            value = mx.transpose(value, (0, 2, 3, 4, 1))

        # Wrap .conv.weight -> .conv.conv.weight (CausalConv3d wrapper)
        if ".conv.weight" in new_key or ".conv.bias" in new_key:
            if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
                new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
                new_key = new_key.replace(".conv.bias", ".conv.conv.bias")

        sanitized[new_key] = value
    return sanitized
def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Sanitize VAE encoder keys: strip prefix, transpose Conv3d/Conv2d."""
    sanitized = {}
    for key, value in weights.items():
        new_key = None
        if "position_ids" in key:
            continue
        if key.startswith(VAE_STATS_PREFIX):
            if key == "vae.per_channel_statistics.mean-of-means":
                new_key = "per_channel_statistics.mean"
            elif key == "vae.per_channel_statistics.std-of-means":
                new_key = "per_channel_statistics.std"
            else:
                continue
            # Per-channel statistics must stay float32 for precision
            if value.dtype != mx.float32:
                value = value.astype(mx.float32)
        elif key.startswith(VAE_ENCODER_PREFIX):
            new_key = key[len(VAE_ENCODER_PREFIX) :]
        else:
            continue

        # Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
        if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
            value = mx.transpose(value, (0, 2, 3, 4, 1))
        # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
        if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
            value = mx.transpose(value, (0, 2, 3, 1))

        sanitized[new_key] = value
    return sanitized
def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE decoder keys: strip prefix, transpose Conv2d."""
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith(AUDIO_DECODER_PREFIX):
new_key = key[len(AUDIO_DECODER_PREFIX) :]
elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue
else:
continue
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE encoder keys: strip prefix, transpose Conv2d."""
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith(AUDIO_ENCODER_PREFIX):
new_key = key[len(AUDIO_ENCODER_PREFIX) :]
elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue
elif key == "latents_mean":
new_key = "per_channel_statistics.mean_of_means"
elif key == "latents_std":
new_key = "per_channel_statistics.std_of_means"
else:
continue
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Sanitize vocoder keys: strip prefix, transpose Conv1d/ConvTranspose1d."""
    sanitized = {}
    for key, value in weights.items():
        if not key.startswith(VOCODER_PREFIX):
            continue
        new_key = key[len(VOCODER_PREFIX) :]

        # Handle Conv1d/ConvTranspose1d weight shape conversion
        if "weight" in new_key and value.ndim == 3:
            if "ups" in new_key:
                # ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
                value = mx.transpose(value, (1, 2, 0))
            else:
                # Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
                value = mx.transpose(value, (0, 2, 1))

        sanitized[new_key] = value
    return sanitized
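The sanitize functions above use four axis permutations to move from PyTorch's channels-first weight layouts to MLX's channels-last ones. A small shape-only sketch makes the permutations easy to check by hand. It uses plain tuples rather than `mx.array`, and the helper `transposed_shape` and the example shapes are assumptions made for illustration.

```python
# Axis permutations copied from the sanitize_* functions above.
LAYOUTS = {
    "conv3d": (0, 2, 3, 4, 1),       # (O, I, D, H, W) -> (O, D, H, W, I)
    "conv2d": (0, 2, 3, 1),          # (O, I, H, W)    -> (O, H, W, I)
    "conv1d": (0, 2, 1),             # (O, I, K)       -> (O, K, I)
    "conv_transpose1d": (1, 2, 0),   # (I, O, K)       -> (O, K, I)
}


def transposed_shape(shape, axes):
    """Shape after a transpose: output axis i takes its size from input axis axes[i]."""
    if len(shape) != len(axes):
        raise ValueError(f"expected {len(axes)} dims, got {len(shape)}")
    return tuple(shape[a] for a in axes)


# Example shapes are invented; only the axis order matters.
print(transposed_shape((128, 64, 3, 3, 3), LAYOUTS["conv3d"]))
print(transposed_shape((512, 256, 16), LAYOUTS["conv_transpose1d"]))
```

Note that the transposed-convolution case starts from `(in_ch, out_ch, kernel)`, which is why its permutation differs from the regular Conv1d one even though both end in `(out_ch, kernel, in_ch)`.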
def sanitize_connector_key(key: str) -> str:
"""Sanitize connector sub-key names."""
key = key.replace(".ff.net.0.proj.", ".ff.proj_in.")
key = key.replace(".ff.net.2.", ".ff.proj_out.")
key = key.replace(".to_out.0.", ".to_out.")
return key
def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Extract text projection weights (aggregate_embed + connectors).
Handles both LTX-2 (aggregate_embed.weight) and LTX-2.3
(video_aggregate_embed.*, audio_aggregate_embed.*) formats.
"""
extracted = {}
# aggregate_embed weights (text_embedding_projection.*)
for key, value in weights.items():
if key.startswith(TEXT_PROJ_PREFIX):
new_key = key[len(TEXT_PROJ_PREFIX) :]
extracted[new_key] = value
# video_embeddings_connector
for key, value in weights.items():
if key.startswith(VIDEO_CONNECTOR_PREFIX):
suffix = key[len(VIDEO_CONNECTOR_PREFIX) :]
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
extracted[new_key] = value
# audio_embeddings_connector
for key, value in weights.items():
if key.startswith(AUDIO_CONNECTOR_PREFIX):
suffix = key[len(AUDIO_CONNECTOR_PREFIX) :]
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
extracted[new_key] = value
return extracted
# ─── Saving utilities ─────────────────────────────────────────────────────────
def save_sharded(
    weights: Dict[str, mx.array],
    output_dir: Path,
    max_shard_size_bytes: int = 5 * 1024 * 1024 * 1024,  # 5GB per shard
):
    """Save weights as sharded safetensors with an index file."""
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sort keys for deterministic output
    sorted_keys = sorted(weights.keys())

    # Calculate total size
    total_size = sum(weights[k].nbytes for k in sorted_keys)

    # Determine sharding
    shards = []
    current_shard = {}
    current_size = 0
    for key in sorted_keys:
        tensor = weights[key]
        tensor_size = tensor.nbytes
        if current_size + tensor_size > max_shard_size_bytes and current_shard:
            shards.append(current_shard)
            current_shard = {}
            current_size = 0
        current_shard[key] = tensor
        current_size += tensor_size
    if current_shard:
        shards.append(current_shard)

    num_shards = len(shards)
    weight_map = {}
    for i, shard in enumerate(shards):
        if num_shards == 1:
            filename = "model.safetensors"
        else:
            filename = f"model-{i+1:05d}-of-{num_shards:05d}.safetensors"
        mx.save_safetensors(str(output_dir / filename), shard)
        for key in shard:
            weight_map[key] = filename

    # Write index
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    with open(output_dir / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2, sort_keys=True)
    return num_shards
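The greedy sharding in `save_sharded` can be followed with byte sizes alone. The sketch below reproduces the same loop on a dictionary of sizes instead of tensors; `plan_shards`, `shard_filename`, and the sizes are illustrative assumptions.

```python
def plan_shards(sizes, max_shard_size_bytes):
    """Greedily pack keys (in sorted order) into shards of at most the given size.

    A tensor larger than the limit still gets a shard of its own, because a
    shard is only closed when it already holds at least one key.
    """
    shards, current, current_size = [], [], 0
    for key in sorted(sizes):
        size = sizes[key]
        if current_size + size > max_shard_size_bytes and current:
            shards.append(current)
            current, current_size = [], 0
        current.append(key)
        current_size += size
    if current:
        shards.append(current)
    return shards


def shard_filename(i, num_shards):
    """File name for shard i (0-based), matching the naming used in save_sharded."""
    if num_shards == 1:
        return "model.safetensors"
    return f"model-{i + 1:05d}-of-{num_shards:05d}.safetensors"


sizes = {"a": 4, "b": 4, "c": 10, "d": 3}  # made-up byte counts
shards = plan_shards(sizes, max_shard_size_bytes=8)
for i, shard in enumerate(shards):
    print(shard_filename(i, len(shards)), shard)
```

Here `a` and `b` fill the first shard exactly, `c` exceeds the limit on its own and is stored alone, and `d` starts a third shard.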
def save_single(weights: Dict[str, mx.array], output_dir: Path):
"""Save weights as a single safetensors file with an index."""
output_dir.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(output_dir / "model.safetensors"), weights)
# Also write index for consistency
total_size = sum(v.nbytes for v in weights.values())
weight_map = {k: "model.safetensors" for k in sorted(weights.keys())}
index = {
"metadata": {"total_size": total_size},
"weight_map": weight_map,
}
with open(output_dir / "model.safetensors.index.json", "w") as f:
json.dump(index, f, indent=2, sort_keys=True)
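Both `save_sharded` and `save_single` write a `model.safetensors.index.json` with a `metadata.total_size` entry and a `weight_map` from tensor name to file name. A reader can invert that map to see which tensors live in which file. The following is a sketch using only the standard library; `keys_per_shard` and the tensor names are hypothetical.

```python
import json
import tempfile
from collections import defaultdict
from pathlib import Path


def keys_per_shard(index_path):
    """Group tensor names by the shard file listed in weight_map."""
    with open(index_path) as f:
        index = json.load(f)
    grouped = defaultdict(list)
    for key, filename in index["weight_map"].items():
        grouped[filename].append(key)
    return {name: sorted(keys) for name, keys in grouped.items()}


# Build a tiny index in a temporary directory (tensor names are made up).
with tempfile.TemporaryDirectory() as tmp:
    index_path = Path(tmp) / "model.safetensors.index.json"
    example_index = {
        "metadata": {"total_size": 24},
        "weight_map": {
            "proj_out.weight": "model-00002-of-00002.safetensors",
            "adaln_single.linear.weight": "model-00001-of-00002.safetensors",
            "proj_out.bias": "model-00002-of-00002.safetensors",
        },
    }
    index_path.write_text(json.dumps(example_index, indent=2, sort_keys=True))
    grouped = keys_per_shard(index_path)

print(grouped)
```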
def save_config(config: dict, output_dir: Path):
"""Save config.json to a directory."""
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / "config.json", "w") as f:
json.dump(config, f, indent=4)
f.write("\n")
# ─── Source resolution ─────────────────────────────────────────────────────────
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
MONOLITHIC_PATTERN = re.compile(
    r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
)

# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
UPSCALER_PATTERN = re.compile(
    r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
)
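The two patterns above decide which files are treated as monolithic checkpoints and which as upscalers. A quick way to see what they accept is to classify a few file names. The regular expressions are copied from above; the `classify` helper and the `-fp8` file name are hypothetical and only serve to show a non-matching case.

```python
import re

MONOLITHIC_PATTERN = re.compile(
    r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
)
UPSCALER_PATTERN = re.compile(
    r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
)


def classify(name):
    """Return (kind, variant) for a file name; variant is None unless monolithic."""
    m = MONOLITHIC_PATTERN.match(name)
    if m:
        return ("monolithic", m.group("variant"))
    if UPSCALER_PATTERN.match(name):
        return ("upscaler", None)
    return ("other", None)


for name in [
    "ltx-2-19b-distilled.safetensors",
    "ltx-2.3-22b-dev.safetensors",
    "ltx-2-spatial-upscaler-x2-1.0.safetensors",
    "ltx-2-19b-dev-fp8.safetensors",  # hypothetical name with an extra suffix
]:
    print(name, "->", classify(name))
```

Because the monolithic pattern is anchored on both ends, a name with anything between the variant and `.safetensors` is not picked up as a checkpoint.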
def resolve_source(source: str, variant: str) -> Path:
    """Resolve source to a monolithic safetensors file path.

    Args:
        source: HF repo ID (e.g. "Lightricks/LTX-2"), local directory, or direct file path.
        variant: Model variant ("distilled" or "dev") to select the right file.

    Returns:
        Path to the monolithic safetensors file.
    """
    source_path = Path(source)

    # Direct file path
    if source_path.is_file():
        return source_path

    # Local directory: find the variant's safetensors file
    if source_path.is_dir():
        matches = []
        for f in sorted(source_path.glob("ltx-*b-*.safetensors")):
            m = MONOLITHIC_PATTERN.match(f.name)
            if m and m.group("variant") == variant:
                matches.append(f)
        if matches:
            return matches[0]
        # Broader fallback
        all_mono = sorted(source_path.glob("ltx-*.safetensors"))
        for f in all_mono:
            if variant in f.name and MONOLITHIC_PATTERN.match(f.name):
                return f
        raise FileNotFoundError(
            f"No monolithic *-{variant}.safetensors found in {source_path}. "
            f"Files found: {[f.name for f in all_mono]}"
        )

    # HF repo ID: download via huggingface_hub
    if "/" in source and not source_path.exists():
        from huggingface_hub import hf_hub_download, list_repo_files

        # Find the right file in the repo
        repo_files = list_repo_files(source)
        target = None
        for f in repo_files:
            m = MONOLITHIC_PATTERN.match(f)
            if m and m.group("variant") == variant:
                target = f
                break
        if not target:
            raise FileNotFoundError(
                f"No *-{variant}.safetensors found in {source}. "
                f"Available: {[f for f in repo_files if f.endswith('.safetensors')]}"
            )
        print(f"Downloading {target} from {source}...")
        local_path = hf_hub_download(repo_id=source, filename=target)
        return Path(local_path)

    raise FileNotFoundError(
        f"Source not found: {source}. Provide an HF repo ID, local directory, or file path."
    )
# ─── Config inference ─────────────────────────────────────────────────────────
def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
    """Infer transformer config from weight shapes."""
    # Count transformer layers
    max_layer = -1
    for key in weights:
        if "transformer_blocks." in key:
            parts = key.split(".")
            try:
                idx = parts.index("transformer_blocks") + 1
                if idx < len(parts) and parts[idx].isdigit():
                    max_layer = max(max_layer, int(parts[idx]))
            except ValueError:
                pass
    num_layers = max_layer + 1 if max_layer >= 0 else 48

    # Detect cross_attention_dim from attn2.to_k (cross-attention input dim)
    cross_attention_dim = 4096
    for key, value in weights.items():
        if "transformer_blocks.0.attn2.to_k.weight" in key:
            cross_attention_dim = value.shape[-1]
            break

    # Check for prompt_adaln_single (LTX-2.3 feature)
    has_prompt_adaln = any("prompt_adaln_single" in k for k in weights)

    config = {
        "attention_head_dim": 128,
        "attention_type": "default",
        "audio_attention_head_dim": 64,
        "audio_caption_channels": 3840,
        "audio_cross_attention_dim": 2048,
        "audio_in_channels": 128,
        "audio_num_attention_heads": 32,
        "audio_out_channels": 128,
        "audio_positional_embedding_max_pos": [20],
        "av_ca_timestep_scale_multiplier": 1000,
        "caption_channels": 3840,
        "cross_attention_dim": cross_attention_dim,
        "double_precision_rope": True,
        "in_channels": 128,
        "model_type": "ltx av model",
        "norm_eps": 1e-06,
        "num_attention_heads": 32,
        "num_layers": num_layers,
        "out_channels": 128,
        "positional_embedding_max_pos": [20, 2048, 2048],
        "positional_embedding_theta": 10000.0,
        "rope_type": "split",
        "timestep_scale_multiplier": 1000,
        "use_middle_indices_grid": True,
    }
    if has_prompt_adaln:
        config["has_prompt_adaln"] = True
    return config
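The layer count in `infer_transformer_config` comes from the largest block index found in the key names, plus one. The same idea is sketched below as a standalone helper that works on key strings only; `count_layers` and the example keys are assumptions made for illustration.

```python
def count_layers(keys, block_name="transformer_blocks", default=48):
    """Infer the number of blocks from keys such as 'transformer_blocks.<i>.<...>'.

    Falls back to `default` when no indexed block key is present.
    """
    max_layer = -1
    for key in keys:
        parts = key.split(".")
        if block_name in parts:
            idx = parts.index(block_name) + 1
            if idx < len(parts) and parts[idx].isdigit():
                max_layer = max(max_layer, int(parts[idx]))
    return max_layer + 1 if max_layer >= 0 else default


# Hypothetical key names, already stripped of the checkpoint prefix.
example_keys = [
    "transformer_blocks.0.attn1.to_q.weight",
    "transformer_blocks.11.ff.proj_in.weight",
    "proj_out.weight",
]
print(count_layers(example_keys))
```

Indices are zero-based, so a highest index of 11 corresponds to 12 blocks.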
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
    """Infer VAE decoder config from weights."""
    # Check for timestep conditioning keys
    has_timestep = any(
        "last_time_embedder" in k or "last_scale_shift_table" in k for k in weights
    )

    # Count channel multipliers from up_blocks.
    # NOTE: max_block and the variant argument are not used when building the
    # config below; only timestep_conditioning is inferred from the weights.
    max_block = -1
    for key in weights:
        if "up_blocks." in key:
            parts = key.split(".")
            try:
                idx = parts.index("up_blocks") + 1
                if idx < len(parts) and parts[idx].isdigit():
                    max_block = max(max_block, int(parts[idx]))
            except ValueError:
                pass

    # Default config
    config = {
        "ch": 128,
        "ch_mult": [1, 2, 4],
        "dropout": 0.0,
        "num_res_blocks": 2,
        "out_ch": 2,
        "resolution": 256,
        "timestep_conditioning": has_timestep,
        "z_channels": 8,
    }
    return config
def infer_vae_encoder_config(weights: Dict[str, mx.array]) -> dict:
"""Return VAE encoder config (architecture is consistent across versions)."""
return {
"convolution_dimensions": 3,
"encoder_blocks": [
["res_x", {"num_layers": 4}],
["compress_space_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_time_res", {"multiplier": 2}],
["res_x", {"num_layers": 6}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}],
["compress_all_res", {"multiplier": 2}],
["res_x", {"num_layers": 2}],
],
"encoder_spatial_padding_mode": "zeros",
"in_channels": 3,
"latent_log_var": "uniform",
"norm_layer": "pixel_norm",
"out_channels": 128,
"patch_size": 4,
}
def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict:
"""Return audio VAE config."""
return {
"attn_resolutions": [],
"attn_type": "vanilla",
"causality_axis": "height",
"ch": 128,
"ch_mult": [1, 2, 4],
"dropout": 0.0,
"give_pre_end": False,
"is_causal": True,
"mel_bins": 64,
"mel_hop_length": 160,
"mid_block_add_attention": False,
"norm_type": "pixel",
"num_res_blocks": 2,
"out_ch": 2,
"resamp_with_conv": True,
"resolution": 256,
"sample_rate": 16000,
"tanh_out": False,
"z_channels": 8,
}
def infer_audio_encoder_config(weights: Dict[str, mx.array]) -> dict:
"""Return audio encoder config (mirrors decoder but with encoder-specific fields)."""
return {
"attn_resolutions": [],
"attn_type": "vanilla",
"causality_axis": "height",
"ch": 128,
"ch_mult": [1, 2, 4],
"dropout": 0.0,
"in_channels": 2,
"double_z": True,
"is_causal": True,
"mel_bins": 64,
"mel_hop_length": 160,
"mid_block_add_attention": False,
"n_fft": 1024,
"norm_type": "pixel",
"num_res_blocks": 2,
"resamp_with_conv": True,
"resolution": 256,
"sample_rate": 16000,
"z_channels": 8,
}
def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
"""Infer vocoder config from weights."""
# Check for bwe_generator (LTX-2.3 BigVGAN vocoder)
has_bwe = any(k.startswith("bwe_generator") for k in weights)
if has_bwe:
return {
"type": "bigvgan",
"has_bwe_generator": True,
}
return {
"output_sample_rate": 24000,
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"resblock_kernel_sizes": [3, 7, 11],
"stereo": True,
"upsample_initial_channel": 1024,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_rates": [6, 5, 2, 2, 2],
}
# ─── Main ─────────────────────────────────────────────────────────────────────
def convert(source: str, output_path: Path, variant: str = "distilled"):
"""Convert monolithic safetensors to modular directory layout.
Args:
source: HF repo ID (e.g. "Lightricks/LTX-2"), local directory, or file path.
output_path: Output directory for the modular layout.
variant: "distilled" or "dev".
"""
source_path = resolve_source(source, variant)
print(f"Loading monolithic weights from {source_path.name}...")
all_weights = mx.load(str(source_path))
total_keys = len(all_weights)
print(f" Loaded {total_keys} keys")
# Route keys to components
print("\nExtracting components...")
# 1. Transformer
print(" [1/7] Transformer...")
transformer_weights = sanitize_transformer(all_weights)
num_shards = save_sharded(transformer_weights, output_path / "transformer")
config = infer_transformer_config(transformer_weights)
save_config(config, output_path / "transformer")
t_params = sum(v.size for v in transformer_weights.values())
print(
f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards"
)
# 2. VAE Decoder
print(" [2/7] VAE Decoder...")
vae_decoder_weights = sanitize_vae_decoder(all_weights)
save_single(vae_decoder_weights, output_path / "vae" / "decoder")
config = infer_vae_decoder_config(vae_decoder_weights, variant)
save_config(config, output_path / "vae" / "decoder")
d_params = sum(v.size for v in vae_decoder_weights.values())
print(f" {len(vae_decoder_weights)} keys, {d_params:,} params")
# 3. VAE Encoder
print(" [3/7] VAE Encoder...")
vae_encoder_weights = sanitize_vae_encoder(all_weights)
save_single(vae_encoder_weights, output_path / "vae" / "encoder")
config = infer_vae_encoder_config(vae_encoder_weights)
save_config(config, output_path / "vae" / "encoder")
e_params = sum(v.size for v in vae_encoder_weights.values())
print(f" {len(vae_encoder_weights)} keys, {e_params:,} params")
# 4. Audio VAE Decoder
print(" [4/7] Audio VAE Decoder...")
audio_decoder_weights = sanitize_audio_decoder(all_weights)
save_single(audio_decoder_weights, output_path / "audio_vae" / "decoder")
config = infer_audio_vae_config(audio_decoder_weights)
save_config(config, output_path / "audio_vae" / "decoder")
a_params = sum(v.size for v in audio_decoder_weights.values())
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
# 5. Audio VAE Encoder
print(" [5/7] Audio VAE Encoder...")
audio_encoder_weights = sanitize_audio_encoder(all_weights)
save_single(audio_encoder_weights, output_path / "audio_vae" / "encoder")
config = infer_audio_encoder_config(audio_encoder_weights)
save_config(config, output_path / "audio_vae" / "encoder")
ae_params = sum(v.size for v in audio_encoder_weights.values())
print(f" {len(audio_encoder_weights)} keys, {ae_params:,} params")
# 6. Vocoder
print(" [6/7] Vocoder...")
vocoder_weights = sanitize_vocoder(all_weights)
save_single(vocoder_weights, output_path / "vocoder")
config = infer_vocoder_config(vocoder_weights)
save_config(config, output_path / "vocoder")
v_params = sum(v.size for v in vocoder_weights.values())
print(f" {len(vocoder_weights)} keys, {v_params:,} params")
# 7. Text Projections
print(" [7/7] Text Projections...")
text_proj_weights = extract_text_projections(all_weights)
tp_dir = output_path / "text_projections"
tp_dir.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(tp_dir / "model.safetensors"), text_proj_weights)
tp_params = sum(v.size for v in text_proj_weights.values())
print(f" {len(text_proj_weights)} keys, {tp_params:,} params")
# Copy upscaler files
print("\nCopying upscaler files...")
source_dir = source_path.parent
is_hf_repo = "/" in source and not Path(source).exists()
upscaler_files = []
if is_hf_repo:
from huggingface_hub import list_repo_files
upscaler_files = [
f for f in list_repo_files(source) if UPSCALER_PATTERN.match(f)
]
else:
upscaler_files = [
f.name
for f in source_dir.iterdir()
if f.is_file() and UPSCALER_PATTERN.match(f.name)
]
if not upscaler_files:
print(" No upscaler files found")
for upscaler_file in sorted(upscaler_files):
dest = output_path / upscaler_file
if dest.exists():
print(f" {upscaler_file}: already exists, skipping")
continue
local_candidate = source_dir / upscaler_file
if local_candidate.is_file():
shutil.copy2(str(local_candidate), str(dest))
print(f" {upscaler_file}: copied")
elif is_hf_repo:
from huggingface_hub import hf_hub_download
print(f" {upscaler_file}: downloading from {source}...")
downloaded = hf_hub_download(repo_id=source, filename=upscaler_file)
shutil.copy2(downloaded, str(dest))
print(f" {upscaler_file}: done")
else:
print(f" {upscaler_file}: not found, skipping")
# Link text_encoder and tokenizer directories
print("\nLinking text encoder & tokenizer...")
for subdir in ["text_encoder", "tokenizer"]:
dest = output_path / subdir
if dest.exists():
print(f" {subdir}/: already exists, skipping")
continue
local_candidate = source_dir / subdir
if local_candidate.is_dir():
# Resolve through symlinks to get the real directory
real_path = local_candidate.resolve()
dest.symlink_to(real_path)
print(f" {subdir}/: symlinked to {real_path}")
elif is_hf_repo:
from huggingface_hub import list_repo_files, snapshot_download
# Only download if the subdir exists in the repo
repo_files = list_repo_files(source)
if any(f.startswith(f"{subdir}/") for f in repo_files):
print(f" {subdir}/: downloading from {source}...")
snapshot_download(
repo_id=source,
allow_patterns=f"{subdir}/*",
local_dir=str(output_path),
)
print(f" {subdir}/: done")
else:
print(f" {subdir}/: not in repo, skipping")
else:
print(f" {subdir}/: not found in source, skipping")
    # Summary. Per-channel statistics are written to both the encoder and the
    # decoder of each VAE, so the number of written tensors can exceed the
    # number of distinct source keys. Skipped keys are therefore reported
    # independently of that count.
    all_converted = (
        len(transformer_weights)
        + len(vae_decoder_weights)
        + len(vae_encoder_weights)
        + len(audio_decoder_weights)
        + len(audio_encoder_weights)
        + len(vocoder_weights)
        + len(text_proj_weights)
    )
    print(f"\nDone! Wrote {all_converted} tensors from {total_keys} source keys")
    known_prefixes = (
        TRANSFORMER_PREFIX,
        VAE_DECODER_PREFIX,
        VAE_ENCODER_PREFIX,
        VAE_STATS_PREFIX,
        AUDIO_DECODER_PREFIX,
        AUDIO_ENCODER_PREFIX,
        AUDIO_STATS_PREFIX,
        VOCODER_PREFIX,
        TEXT_PROJ_PREFIX,
        VIDEO_CONNECTOR_PREFIX,
        AUDIO_CONNECTOR_PREFIX,
    )
    skipped = [
        k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)
    ]
    if skipped:
        # Only the first 20 skipped keys are listed.
        print(f"  Skipped {len(skipped)} keys:")
        for k in sorted(skipped)[:20]:
            print(f"    {k}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert monolithic LTX-2/2.3 safetensors to modular MLX layout"
)
parser.add_argument(
"--source",
type=str,
required=True,
help="HF repo ID (e.g. Lightricks/LTX-2, Lightricks/LTX-2.3), local directory, or direct safetensors file path",
)
parser.add_argument(
"--output",
type=str,
required=True,
help="Output directory for modular layout",
)
parser.add_argument(
"--variant",
type=str,
choices=["distilled", "dev"],
default="distilled",
help="Model variant (affects VAE decoder config and which file to download)",
)
args = parser.parse_args()
convert(args.source, Path(args.output), variant=args.variant)
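
The source-resolution heuristic used above can be sketched in isolation. This is a minimal illustration, not the script's actual code: `looks_like_hf_repo`, `select_upscalers`, and `EXAMPLE_UPSCALER_PATTERN` are hypothetical names, and the regular expression is an assumed stand-in for the real `UPSCALER_PATTERN`, which is defined outside this hunk.

```python
import re
import tempfile
from pathlib import Path

# Hypothetical stand-in: the real UPSCALER_PATTERN is defined elsewhere in the script.
EXAMPLE_UPSCALER_PATTERN = re.compile(r".*upscaler.*\.safetensors$")


def looks_like_hf_repo(source: str) -> bool:
    """Mirror the check in the diff: a slash in the string and no such local path."""
    return "/" in source and not Path(source).exists()


def select_upscalers(file_names):
    """Keep only names matching the (assumed) upscaler pattern, in sorted order."""
    return sorted(n for n in file_names if EXAMPLE_UPSCALER_PATTERN.match(n))


# A string with a slash that is not a local path is treated as a repo ID.
print(looks_like_hf_repo("example-org-zz9/not-a-local-path"))

# An existing local directory is never treated as a repo ID.
with tempfile.TemporaryDirectory() as tmp:
    print(looks_like_hf_repo(tmp))

print(select_upscalers(["b_upscaler.safetensors", "config.json", "a_upscaler.safetensors"]))
```

One consequence of this heuristic is that a mistyped relative path such as `models/ltx` is classified as a repo ID, so the failure surfaces as a Hub lookup error rather than a "directory not found" message.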

File diff suppressed because it is too large


@@ -1,18 +1,17 @@
from pathlib import Path
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import (
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
LTXModelType,
LTXRopeType,
TransformerConfig,
)
from mlx_video.models.ltx.adaln import AdaLayerNormSingle
from mlx_video.models.ltx.rope import precompute_freqs_cis
from mlx_video.models.ltx.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx.transformer import (
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx_2.transformer import (
BasicAVTransformerBlock,
Modality,
TransformerArgs,
@@ -26,7 +25,7 @@ class TransformerArgsPreprocessor:
self,
patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
caption_projection: Optional[PixArtAlphaTextProjection],
inner_dim: int,
max_pos: List[int],
num_attention_heads: int,
@@ -35,10 +34,12 @@ class TransformerArgsPreprocessor:
positional_embedding_theta: float,
rope_type: LTXRopeType,
double_precision_rope: bool = False,
prompt_adaln: Optional[AdaLayerNormSingle] = None,
):
self.patchify_proj = patchify_proj
self.adaln = adaln
self.caption_projection = caption_projection
self.prompt_adaln = prompt_adaln
self.inner_dim = inner_dim
self.max_pos = max_pos
self.num_attention_heads = num_attention_heads
@@ -56,14 +57,39 @@ class TransformerArgsPreprocessor:
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
timestep_emb, embedded_timestep = self.adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
# Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
timestep_emb = mx.reshape(
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
)
embedded_timestep = mx.reshape(
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
)
return timestep_emb, embedded_timestep
def _prepare_timestep_with_adaln(
self,
adaln: AdaLayerNormSingle,
timestep: mx.array,
batch_size: int,
hidden_dtype: Optional[mx.Dtype] = None,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
timestep_emb = mx.reshape(
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
)
embedded_timestep = mx.reshape(
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
)
return timestep_emb, embedded_timestep
def _prepare_context(
self,
context: mx.array,
@@ -72,9 +98,8 @@ class TransformerArgsPreprocessor:
) -> Tuple[mx.array, Optional[mx.array]]:
batch_size = x.shape[0]
# Context is already processed through embeddings connector in text encoder
# Here we just apply the caption projection
context = self.caption_projection(context)
if self.caption_projection is not None:
context = self.caption_projection(context)
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
return context, attention_mask
@@ -93,7 +118,9 @@ class TransformerArgsPreprocessor:
# Convert boolean/int mask to float mask
# 0 -> -1e9 (large negative, effectively -inf, masked), 1 -> 0 (not masked)
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
mask = mx.reshape(
mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
)
return mask
def _prepare_positional_embeddings(
@@ -118,16 +145,40 @@ class TransformerArgsPreprocessor:
def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
timestep, embedded_timestep = self._prepare_timestep(
modality.timesteps, x.shape[0], hidden_dtype=x.dtype
)
context, attention_mask = self._prepare_context(
modality.context, x, modality.context_mask
)
attention_mask = self._prepare_attention_mask(
attention_mask, modality.latent.dtype
)
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
if modality.positional_embeddings is not None:
pe = modality.positional_embeddings
else:
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
# Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps
prompt_timestep = None
prompt_embedded_timestep = None
if self.prompt_adaln is not None and modality.sigma is not None:
prompt_timestep, prompt_embedded_timestep = (
self._prepare_timestep_with_adaln(
self.prompt_adaln,
modality.sigma,
x.shape[0],
hidden_dtype=x.dtype,
)
)
return TransformerArgs(
x=x,
@@ -140,6 +191,8 @@ class TransformerArgsPreprocessor:
cross_scale_shift_timestep=None,
cross_gate_timestep=None,
enabled=modality.enabled,
prompt_timesteps=prompt_timestep,
prompt_embedded_timestep=prompt_embedded_timestep,
)
@@ -149,7 +202,7 @@ class MultiModalTransformerArgsPreprocessor:
self,
patchify_proj: nn.Linear,
adaln: AdaLayerNormSingle,
caption_projection: PixArtAlphaTextProjection,
caption_projection: Optional[PixArtAlphaTextProjection],
cross_scale_shift_adaln: AdaLayerNormSingle,
cross_gate_adaln: AdaLayerNormSingle,
inner_dim: int,
@@ -163,6 +216,7 @@ class MultiModalTransformerArgsPreprocessor:
rope_type: LTXRopeType,
av_ca_timestep_scale_multiplier: int,
double_precision_rope: bool = False,
prompt_adaln: Optional[AdaLayerNormSingle] = None,
):
self.simple_preprocessor = TransformerArgsPreprocessor(
patchify_proj=patchify_proj,
@@ -176,6 +230,7 @@ class MultiModalTransformerArgsPreprocessor:
positional_embedding_theta=positional_embedding_theta,
rope_type=rope_type,
double_precision_rope=double_precision_rope,
prompt_adaln=prompt_adaln,
)
self.cross_scale_shift_adaln = cross_scale_shift_adaln
self.cross_gate_adaln = cross_gate_adaln
@@ -198,11 +253,13 @@ class MultiModalTransformerArgsPreprocessor:
)
# Prepare cross-attention timestep embeddings
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
cross_scale_shift_timestep, cross_gate_timestep = (
self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
)
)
return replace(
@@ -223,17 +280,25 @@ class MultiModalTransformerArgsPreprocessor:
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
scale_shift_timestep, _ = self.cross_scale_shift_adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
scale_shift_timestep = mx.reshape(
scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
)
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
gate_timestep, _ = self.cross_gate_adaln(
timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype
)
gate_timestep = mx.reshape(
gate_timestep, (batch_size, -1, gate_timestep.shape[-1])
)
return scale_shift_timestep, gate_timestep
class LTXModel(nn.Module):
def __init__(self, config: LTXModelConfig):
super().__init__()
@@ -254,18 +319,25 @@ class LTXModel(nn.Module):
self._init_video(config)
if config.model_type.is_audio_enabled():
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
self.audio_positional_embedding_max_pos = (
config.audio_positional_embedding_max_pos
)
self.audio_num_attention_heads = config.audio_num_attention_heads
self.audio_inner_dim = config.audio_inner_dim
self._init_audio(config)
# Initialize cross-modal components
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
if (
config.model_type.is_video_enabled()
and config.model_type.is_audio_enabled()
):
cross_pe_max_pos = max(
config.positional_embedding_max_pos[0],
config.audio_positional_embedding_max_pos[0],
)
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
self.av_ca_timestep_scale_multiplier = (
config.av_ca_timestep_scale_multiplier
)
self.audio_cross_attention_dim = config.audio_cross_attention_dim
self._init_audio_video(config)
@@ -275,29 +347,51 @@ class LTXModel(nn.Module):
def _init_video(self, config: LTXModelConfig) -> None:
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
self.adaln_single = AdaLayerNormSingle(self.inner_dim)
self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
hidden_size=self.inner_dim,
adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=adaln_coefficient
)
if config.has_prompt_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=2
)
else:
self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
hidden_size=self.inner_dim,
)
self.scale_shift_table = mx.zeros((2, self.inner_dim))
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels,
hidden_size=self.audio_inner_dim,
self.audio_patchify_proj = nn.Linear(
config.audio_in_channels, self.audio_inner_dim, bias=True
)
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient
)
if config.has_prompt_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim, embedding_coefficient=2
)
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels,
hidden_size=self.audio_inner_dim,
)
# Output components
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
self.audio_norm_out = nn.LayerNorm(
self.audio_inner_dim, eps=config.norm_eps, affine=False
)
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
def _init_audio_video(self, config: LTXModelConfig) -> None:
@@ -320,13 +414,18 @@ class LTXModel(nn.Module):
embedding_coefficient=1,
)
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
def _init_preprocessors(
self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]
) -> None:
if (
config.model_type.is_video_enabled()
and config.model_type.is_audio_enabled()
):
# Multi-modal preprocessors
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
caption_projection=getattr(self, "caption_projection", None),
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
inner_dim=self.inner_dim,
@@ -340,11 +439,12 @@ class LTXModel(nn.Module):
rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "prompt_adaln_single", None),
)
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
caption_projection=getattr(self, "audio_caption_projection", None),
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
inner_dim=self.audio_inner_dim,
@@ -358,12 +458,13 @@ class LTXModel(nn.Module):
rope_type=config.rope_type,
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
)
elif config.model_type.is_video_enabled():
self.video_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
adaln=self.adaln_single,
caption_projection=self.caption_projection,
caption_projection=getattr(self, "caption_projection", None),
inner_dim=self.inner_dim,
max_pos=config.positional_embedding_max_pos,
num_attention_heads=self.num_attention_heads,
@@ -372,12 +473,13 @@ class LTXModel(nn.Module):
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "prompt_adaln_single", None),
)
elif config.model_type.is_audio_enabled():
self.audio_args_preprocessor = TransformerArgsPreprocessor(
patchify_proj=self.audio_patchify_proj,
adaln=self.audio_adaln_single,
caption_projection=self.audio_caption_projection,
caption_projection=getattr(self, "audio_caption_projection", None),
inner_dim=self.audio_inner_dim,
max_pos=config.audio_positional_embedding_max_pos,
num_attention_heads=self.audio_num_attention_heads,
@@ -386,13 +488,13 @@ class LTXModel(nn.Module):
positional_embedding_theta=config.positional_embedding_theta,
rope_type=config.rope_type,
double_precision_rope=config.double_precision_rope,
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
)
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
video_config = config.get_video_config()
audio_config = config.get_audio_config()
self.transformer_blocks = {
idx: BasicAVTransformerBlock(
idx=idx,
@@ -400,6 +502,7 @@ class LTXModel(nn.Module):
audio=audio_config,
rope_type=config.rope_type,
norm_eps=config.norm_eps,
has_prompt_adaln=config.has_prompt_adaln,
)
for idx in range(config.num_layers)
}
@@ -408,10 +511,27 @@ class LTXModel(nn.Module):
self,
video: Optional[TransformerArgs],
audio: Optional[TransformerArgs],
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks."""
for block in self.transformer_blocks.values():
video, audio = block(video=video, audio=audio)
"""Process through all transformer blocks.
Args:
stg_video_blocks: Block indices where video self-attention is skipped for spatio-temporal guidance (STG).
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
"""
stg_v_set = set(stg_video_blocks) if stg_video_blocks else set()
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
for idx, block in self.transformer_blocks.items():
video, audio = block(
video=video,
audio=audio,
skip_video_self_attn=(idx in stg_v_set),
skip_audio_self_attn=(idx in stg_a_set),
skip_cross_modal=skip_cross_modal,
)
return video, audio
def _process_output(
@@ -422,7 +542,7 @@ class LTXModel(nn.Module):
x: mx.array,
embedded_timestep: mx.array,
) -> mx.array:
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
@@ -445,8 +565,19 @@ class LTXModel(nn.Module):
self,
video: Optional[Modality] = None,
audio: Optional[Modality] = None,
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
"""Forward pass.
Args:
video: Video modality input.
audio: Audio modality input.
stg_video_blocks: Block indices where video self-attention is skipped (STG).
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
"""
# Validate inputs
if not self.model_type.is_video_enabled() and video is not None:
raise ValueError("Video is not enabled for this model")
@@ -454,13 +585,20 @@ class LTXModel(nn.Module):
raise ValueError("Audio is not enabled for this model")
# Preprocess arguments
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
video_args = (
self.video_args_preprocessor.prepare(video) if video is not None else None
)
audio_args = (
self.audio_args_preprocessor.prepare(audio) if audio is not None else None
)
# Process transformer blocks
video_out, audio_out = self._process_transformer_blocks(
video=video_args,
audio=audio_args,
stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal,
)
# Process outputs
@@ -492,24 +630,70 @@ class LTXModel(nn.Module):
def sanitize(self, weights: dict) -> dict:
sanitized = {}
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
if not has_raw_prefix:
return weights
for key, value in weights.items():
new_key = key
# Handle common remappings
# transformer_blocks.X -> transformer_blocks[X]
if "transformer_blocks." in new_key:
# Keep as-is for now, MLX handles this
pass
if not key.startswith("model.diffusion_model."):
continue
if (
"audio_embeddings_connector" in key
or "video_embeddings_connector" in key
):
continue
# Remove 'model.diffusion_model.' prefix
new_key = new_key.replace("model.diffusion_model.", "")
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTXModel":
import json
config_dict = {}
with open(model_path / "config.json", "r") as f:
config_dict = json.load(f)
config = LTXModelConfig(**config_dict)
model = cls(config)
weights = {}
for weight_file in model_path.glob("*.safetensors"):
weights.update(mx.load(str(weight_file)))
sanitized = model.sanitize(weights)
sanitized = {
k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
for k, v in sanitized.items()
}
model.load_weights(list(sanitized.items()), strict=strict)
mx.eval(model.parameters())
model.eval()
return model
class X0Model(nn.Module):
def __init__(self, velocity_model: LTXModel):
super().__init__()
self.velocity_model = velocity_model
@@ -517,11 +701,24 @@ class X0Model(nn.Module):
self,
video: Optional[Modality] = None,
audio: Optional[Modality] = None,
stg_video_blocks: Optional[List[int]] = None,
stg_audio_blocks: Optional[List[int]] = None,
skip_cross_modal: bool = False,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
vx, ax = self.velocity_model(video, audio)
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
vx, ax = self.velocity_model(
video,
audio,
stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal,
)
denoised_video = (
to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
)
denoised_audio = (
to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
)
return denoised_video, denoised_audio
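
The key remapping performed by the new `sanitize` method can be tried out without MLX. The sketch below is a pure-Python approximation under stated assumptions: `remap_keys` is a hypothetical helper name, the example keys are made up for illustration, and plain strings stand in for weight arrays. The prefix, the skipped connector names, and the rename pairs are copied from the diff.

```python
RAW_PREFIX = "model.diffusion_model."
SKIPPED_SUBSTRINGS = ("audio_embeddings_connector", "video_embeddings_connector")
RENAMES = [
    (".to_out.0.", ".to_out."),
    (".ff.net.0.proj.", ".ff.proj_in."),
    (".ff.net.2.", ".ff.proj_out."),
    (".audio_ff.net.0.proj.", ".audio_ff.proj_in."),
    (".audio_ff.net.2.", ".audio_ff.proj_out."),
    (".linear_1.", ".linear1."),
    (".linear_2.", ".linear2."),
]


def remap_keys(weights: dict) -> dict:
    """Strip the raw checkpoint prefix and rename sub-module paths."""
    # Already-converted checkpoints have no raw prefix and pass through untouched.
    if not any(k.startswith(RAW_PREFIX) for k in weights):
        return weights
    remapped = {}
    for key, value in weights.items():
        if not key.startswith(RAW_PREFIX):
            continue  # not a transformer weight
        if any(s in key for s in SKIPPED_SUBSTRINGS):
            continue  # connector weights are excluded from the transformer
        new_key = key.replace(RAW_PREFIX, "")
        for old, new in RENAMES:
            new_key = new_key.replace(old, new)
        remapped[new_key] = value
    return remapped


# Illustrative keys only; real checkpoints contain many more entries.
example = {
    "model.diffusion_model.transformer_blocks.0.attn1.to_out.0.weight": "w0",
    "model.diffusion_model.transformer_blocks.0.ff.net.0.proj.weight": "w1",
    "model.diffusion_model.video_embeddings_connector.proj.weight": "w2",
    "vae.decoder.conv_in.weight": "w3",
}
print(remap_keys(example))
```

Because each rename pattern starts with a dot, `.ff.net.0.proj.` does not match inside `.audio_ff.net.0.proj.`, so the order of the rename pairs does not affect the result.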


@@ -1,9 +1,10 @@
import numpy as np
from typing import Optional
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
def bilateral_filter(
image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75
) -> np.ndarray:
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
Args:
@@ -17,6 +18,7 @@ def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sig
"""
try:
import cv2
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
except ImportError:
# Fallback to simple Gaussian blur if cv2 not available
@@ -35,14 +37,20 @@ def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
"""
try:
import cv2
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
except ImportError:
# Simple box blur fallback
from scipy.ndimage import uniform_filter
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(
np.uint8
)
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
def unsharp_mask(
image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0
) -> np.ndarray:
"""Apply unsharp masking to enhance edges after blur.
Args:
@@ -56,6 +64,7 @@ def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, am
"""
try:
import cv2
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
return np.clip(sharpened, 0, 255).astype(np.uint8)
@@ -81,23 +90,23 @@ def reduce_grid_artifacts(
if method == "bilateral":
d = max(3, int(5 * strength))
sigma = 50 + 50 * strength
processed = np.stack([
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
for frame in video
])
processed = np.stack(
[
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
for frame in video
]
)
elif method == "gaussian":
kernel_size = max(3, int(3 + 4 * strength))
if kernel_size % 2 == 0:
kernel_size += 1
processed = np.stack([
gaussian_blur(frame, kernel_size=kernel_size)
for frame in video
])
processed = np.stack(
[gaussian_blur(frame, kernel_size=kernel_size) for frame in video]
)
elif method == "frequency":
processed = np.stack([
remove_grid_frequency(frame, grid_size=8)
for frame in video
])
processed = np.stack(
[remove_grid_frequency(frame, grid_size=8) for frame in video]
)
else:
raise ValueError(f"Unknown method: {method}")
@@ -160,6 +169,3 @@ def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
return result
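
The arithmetic behind `unsharp_mask` (`image * (1 + amount) - blurred * amount`, as expressed by the `cv2.addWeighted` call) can be illustrated with NumPy alone. This is a sketch under assumptions, not the module's fallback path: `box_blur` and `unsharp_mask_np` are hypothetical names, and a plain box blur is swapped in for the Gaussian blur so that the example has no OpenCV or SciPy dependency.

```python
import numpy as np


def box_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Mean filter over a kernel_size x kernel_size window with edge padding (H, W, C input)."""
    pad = kernel_size // 2
    padded = np.pad(
        image.astype(np.float32), ((pad, pad), (pad, pad), (0, 0)), mode="edge"
    )
    height, width = image.shape[:2]
    total = np.zeros(image.shape, dtype=np.float32)
    # Sum the shifted copies of the padded image, one per kernel offset.
    for dy in range(kernel_size):
        for dx in range(kernel_size):
            total += padded[dy : dy + height, dx : dx + width]
    return total / (kernel_size * kernel_size)


def unsharp_mask_np(
    image: np.ndarray, kernel_size: int = 3, amount: float = 1.0
) -> np.ndarray:
    """Sharpen by adding back the difference between the image and its blur."""
    blurred = box_blur(image, kernel_size)
    sharpened = (1 + amount) * image.astype(np.float32) - amount * blurred
    return np.clip(np.rint(sharpened), 0, 255).astype(np.uint8)


# A vertical step edge: dark columns on the left, bright columns on the right.
frame = np.full((4, 6, 1), 50, dtype=np.uint8)
frame[:, 3:, :] = 200
print(unsharp_mask_np(frame)[0, :, 0])
```

Flat regions are left unchanged because the blur of a constant equals the constant. Only pixels next to an edge are pushed away from their neighbours, which is why the final clip to the 0 to 255 range is needed.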


@@ -1,11 +1,9 @@
import math
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.config import LTXRopeType
from mlx_video.models.ltx_2.config import LTXRopeType
def apply_rotary_emb(
@@ -87,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array:
"""
# x: (..., dim) where dim is even
x_even = x[..., 0::2] # [x0, x2, x4, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
rotated = mx.stack([-x_odd, x_even], axis=-1)
return mx.reshape(rotated, x.shape)
def apply_rotary_emb_1d(
q: mx.array,
k: mx.array,
@@ -229,9 +228,9 @@ def get_fractional_positions(
Fractional positions in range [-1, 1] after scaling
"""
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), (
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
)
assert n_pos_dims == len(
max_pos
), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
# Divide each dimension by its max position
fractional_positions = []
@@ -393,13 +392,25 @@ def precompute_freqs_cis(
if max_pos is None:
max_pos = [20, 2048, 2048]
if double_precision:
return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
num_attention_heads, rope_type
indices_grid,
dim,
theta,
max_pos,
use_middle_indices_grid,
num_attention_heads,
rope_type,
)
# Keep positions in float32 for RoPE computation.
# Even though PyTorch nominally casts positions to model dtype (bfloat16),
# empirical comparison shows float32 positions produce RoPE values matching
# PyTorch exactly (cosine similarity 1.000). bfloat16 loses precision in the
# fractional position computation, and high-frequency indices (up to 15708)
# amplify that error, causing cos/sin sign flips and a cosine similarity of
# only 0.88.
indices_grid = indices_grid.astype(mx.float32)
# Generate frequency indices
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
@@ -429,66 +440,77 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
warnings.warn(
"Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. "
"Use float32 for position grids to avoid quality degradation. "
"See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss",
UserWarning,
stacklevel=2
)
Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
frequency grid computation (log-spaced values), then converts to float32.
The position grid is cast to float32 rather than kept in the model dtype,
because bfloat16 positions lose precision (see precompute_freqs_cis).
"""
import numpy as np
# Convert to numpy float64 (first to float32 for numpy compatibility)
# Note: If input is bfloat16, precision is already lost at this step
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# Keep positions in float32 — same reasoning as the non-double-precision path.
indices_grid_f32 = indices_grid.astype(mx.float32)
# Generate frequency indices in float64
n_pos_dims = indices_grid_np.shape[1]
n_pos_dims = indices_grid_f32.shape[1]
n_elem = 2 * n_pos_dims
# Compute log-spaced frequencies
log_start = math.log(1.0) / math.log(theta)
log_end = math.log(theta) / math.log(theta)
# Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
# This is the critical precision step - PyTorch uses np.float64 here
log_start = np.log(1.0) / np.log(theta)
log_end = np.log(theta) / np.log(theta) # = 1.0
num_indices = dim // n_elem
if num_indices == 0:
num_indices = 1
lin_space = np.linspace(log_start, log_end, num_indices)
indices_np = np.power(theta, lin_space) * (math.pi / 2)
# Use numpy float64 for the linspace computation (matches PyTorch)
pow_indices = np.power(
theta,
np.linspace(log_start, log_end, num_indices, dtype=np.float64),
)
# Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
# Handle middle indices grid
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
if use_middle_indices_grid:
assert len(indices_grid_np.shape) == 4
assert indices_grid_np.shape[-1] == 2
indices_grid_start = indices_grid_np[..., 0]
indices_grid_end = indices_grid_np[..., 1]
indices_grid_np = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid_np.shape) == 4:
indices_grid_np = indices_grid_np[..., 0]
# After handling: indices_grid_np shape is (B, n_dims, T)
assert len(indices_grid_f32.shape) == 4
assert indices_grid_f32.shape[-1] == 2
indices_grid_start = indices_grid_f32[..., 0]
indices_grid_end = indices_grid_f32[..., 1]
indices_grid_f32 = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid_f32.shape) == 4:
indices_grid_f32 = indices_grid_f32[..., 0]
# After handling: indices_grid_f32 shape is (B, n_dims, T)
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
batch_size = indices_grid_np.shape[0]
seq_len = indices_grid_np.shape[2]
fractional_positions = np.zeros((batch_size, seq_len, n_pos_dims), dtype=np.float64)
# Compute fractional positions for each dimension
fractional_list = []
for i in range(n_pos_dims):
# indices_grid_np[:, i, :] has shape (B, T)
fractional_positions[:, :, i] = indices_grid_np[:, i, :] / max_pos[i]
frac = indices_grid_f32[:, i, :] / max_pos[i] # (B, T)
fractional_list.append(frac)
# Stack: (B, T, n_dims)
fractional_positions = mx.stack(fractional_list, axis=-1)
# Scale to [-1, 1]
scaled_positions = fractional_positions * 2 - 1
# Compute frequencies: outer product
freqs = np.expand_dims(scaled_positions, axis=-1) * indices_np.reshape(1, 1, 1, -1)
freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(freqs.shape[:-2] + (-1,))
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(
freq_indices, (1, 1, 1, -1)
)
# freqs: (B, T, n_dims, num_indices)
# Compute cos/sin in float64
cos_freq = np.cos(freqs)
sin_freq = np.sin(freqs)
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
freqs = mx.swapaxes(freqs, -1, -2)
freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1))
# Compute cos/sin
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
# Prepare based on rope type
if rope_type == LTXRopeType.SPLIT:
@@ -498,31 +520,27 @@ def _precompute_freqs_cis_double_precision(
# Add padding
if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
b, t = cos_freq.shape[0], cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, num_attention_heads, -1)
sin_freq = sin_freq.reshape(b, t, num_attention_heads, -1)
cos_freq = np.swapaxes(cos_freq, 1, 2)
sin_freq = np.swapaxes(sin_freq, 1, 2)
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
cos_freq = mx.swapaxes(cos_freq, 1, 2)
sin_freq = mx.swapaxes(sin_freq, 1, 2)
else:
# Interleaved
cos_freq = np.repeat(cos_freq, 2, axis=-1)
sin_freq = np.repeat(sin_freq, 2, axis=-1)
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
pad_size = dim % n_elem
if pad_size > 0:
cos_padding = np.ones((*cos_freq.shape[:-1], pad_size), dtype=np.float64)
sin_padding = np.zeros((*sin_freq.shape[:-1], pad_size), dtype=np.float64)
cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1)
# Convert back to MLX (float32 for GPU compatibility)
cos_freq = mx.array(cos_freq.astype(np.float32))
sin_freq = mx.array(sin_freq.astype(np.float32))
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
return cos_freq, sin_freq
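
The frequency grid built above can be traced for a single token in plain Python. This is a minimal sketch, not the MLX code path: the helper name `rope_angles` and the sample values for `positions`, `max_pos`, and `freq_indices` are hypothetical. They only illustrate the scaling to [-1, 1] and the flattening order produced by the swapaxes/reshape pair.

```python
import math

def rope_angles(positions, max_pos, freq_indices):
    """Return flattened angles for one token, ordered (frequency index, position dim)."""
    # Fractional position per dimension, then scale from [0, 1] to [-1, 1].
    scaled = [p / m * 2 - 1 for p, m in zip(positions, max_pos)]
    # The outer product has shape (n_dims, num_indices); swapping the last two
    # axes before flattening makes the frequency index the slow axis.
    return [s * f for f in freq_indices for s in scaled]

# Hypothetical 3D position (frame, row, column) and two frequency indices.
angles = rope_angles(positions=[5.0, 0.0, 32.0], max_pos=[20, 64, 64], freq_indices=[1.0, 2.0])
cos_freq = [math.cos(a) for a in angles]
sin_freq = [math.sin(a) for a in angles]
print(angles)
```

Each group of `n_dims` consecutive entries shares one frequency index, which is the layout the later padding and head reshape assume.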


@@ -0,0 +1,185 @@
"""Second-order res_2s sampler for diffusion models.
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
noise injection, ported from the LTX-2 PyTorch implementation.
"""
import math
import mlx.core as mx
# ---------------------------------------------------------------------------
# Phi functions and RK coefficients (pure Python math, no MLX needed)
# ---------------------------------------------------------------------------
def phi(j: int, neg_h: float) -> float:
"""Compute phi_j(z) where z = -h (negative step size in log-space).
phi_1(z) = (e^z - 1) / z
phi_2(z) = (e^z - 1 - z) / z^2
phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j
"""
if abs(neg_h) < 1e-10:
return 1.0 / math.factorial(j)
remainder = sum(neg_h**k / math.factorial(k) for k in range(j))
return (math.exp(neg_h) - remainder) / (neg_h**j)
def get_res2s_coefficients(
h: float,
phi_cache: dict,
c2: float = 0.5,
) -> tuple[float, float, float]:
"""Compute res_2s Runge-Kutta coefficients for a given step size.
Args:
h: Step size in log-space = log(sigma / sigma_next)
phi_cache: Dictionary to cache phi function results.
c2: Substep position (default 0.5 = midpoint)
Returns:
(a21, b1, b2): RK coefficients.
"""
def get_phi(j: int, neg_h: float) -> float:
cache_key = (j, neg_h)
if cache_key in phi_cache:
return phi_cache[cache_key]
result = phi(j, neg_h)
phi_cache[cache_key] = result
return result
neg_h_c2 = -h * c2
phi_1_c2 = get_phi(1, neg_h_c2)
a21 = c2 * phi_1_c2
neg_h_full = -h
phi_2_full = get_phi(2, neg_h_full)
b2 = phi_2_full / c2
phi_1_full = get_phi(1, neg_h_full)
b1 = phi_1_full - b2
return a21, b1, b2
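
The coefficient formulas above can be sanity-checked without MLX. The sketch below restates `phi` and the coefficient computation in a self-contained form (the name `res2s_coefficients` and the omission of the cache are simplifications for illustration) and checks two properties that follow from the definitions: `b1 + b2` equals `phi_1(-h)`, and the small-step limit gives the explicit midpoint rule.

```python
import math

def phi(j: int, z: float) -> float:
    """phi_j(z) = (e^z - sum_{k<j} z^k / k!) / z^j, with the z -> 0 limit 1/j!."""
    if abs(z) < 1e-10:
        return 1.0 / math.factorial(j)
    remainder = sum(z**k / math.factorial(k) for k in range(j))
    return (math.exp(z) - remainder) / z**j

def res2s_coefficients(h: float, c2: float = 0.5):
    """Same formulas as get_res2s_coefficients, without the phi cache."""
    a21 = c2 * phi(1, -h * c2)
    b2 = phi(2, -h) / c2
    b1 = phi(1, -h) - b2
    return a21, b1, b2

h = math.log(1.0 / 0.5)  # one step from sigma = 1.0 to sigma_next = 0.5
a21, b1, b2 = res2s_coefficients(h)
# By construction the two output weights sum to phi_1(-h), whatever c2 is.
assert math.isclose(b1 + b2, phi(1, -h))
# As h -> 0 the weights approach the midpoint rule: a21 = 1/2, b1 = 0, b2 = 1.
print(res2s_coefficients(1e-12))
```

For step sizes only slightly above the `1e-10` threshold, the direct formula subtracts nearly equal numbers, so `phi_2` may lose precision there. Sampler step sizes are normally much larger, so this is unlikely to matter in practice.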
# ---------------------------------------------------------------------------
# SDE noise injection
# ---------------------------------------------------------------------------
def get_sde_coeff(
sigma_next: float,
) -> tuple[float, float, float]:
"""Compute SDE coefficients for variance-preserving noise injection.
Uses sigma_up = sigma_next * 0.5 (hardcoded in PyTorch Res2sDiffusionStep).
Returns:
(alpha_ratio, sigma_down, sigma_up)
"""
sigma_up = sigma_next * 0.5
# Clamp sigma_up to avoid sqrt(negative)
sigma_up = min(sigma_up, sigma_next * 0.9999)
sigma_signal = 1.0 - sigma_next # sigma_max=1
sigma_residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0))
alpha_ratio = sigma_signal + sigma_residual
if alpha_ratio == 0:
sigma_down = sigma_next
else:
sigma_down = sigma_residual / alpha_ratio
# Handle NaN edge cases
if math.isnan(sigma_up):
sigma_up = 0.0
if math.isnan(sigma_down):
sigma_down = sigma_next
if math.isnan(alpha_ratio):
alpha_ratio = 1.0
return alpha_ratio, sigma_down, sigma_up
def sde_noise_step(
sample: mx.array,
denoised_sample: mx.array,
sigma: float,
sigma_next: float,
noise: mx.array,
) -> mx.array:
"""Apply SDE noise injection step.
Advances sample from sigma to sigma_next with stochastic noise injection.
Args:
sample: Current sample (anchor point)
denoised_sample: Denoised prediction at this step
sigma: Current noise level
sigma_next: Next noise level
noise: Pre-generated noise tensor (channel-wise normalized)
Returns:
Noised sample at sigma_next
"""
alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)
if sigma_up == 0 or sigma_next == 0:
return denoised_sample
# Float32 arithmetic
sample_f32 = sample.astype(mx.float32)
denoised_f32 = denoised_sample.astype(mx.float32)
noise_f32 = noise.astype(mx.float32)
# Extract epsilon prediction
eps_next = (sample_f32 - denoised_f32) / (sigma - sigma_next)
denoised_next = sample_f32 - sigma * eps_next
# Mix deterministic and stochastic components
x_noised = (
alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
)
return x_noised
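
A short numeric check helps show what the SDE coefficients encode. The sketch below restates the coefficient computation without the NaN guards. The `eta` parameter is hypothetical and only names the hardcoded 0.5 factor. The quadrature identity it checks is an algebraic consequence of the formulas above, not a claim taken from the PyTorch source.

```python
import math

def sde_coeff(sigma_next: float, eta: float = 0.5):
    """Restates get_sde_coeff; eta = 0.5 mirrors sigma_up = sigma_next * 0.5."""
    sigma_up = min(sigma_next * eta, sigma_next * 0.9999)
    sigma_signal = 1.0 - sigma_next
    sigma_residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0))
    alpha_ratio = sigma_signal + sigma_residual
    sigma_down = sigma_next if alpha_ratio == 0 else sigma_residual / alpha_ratio
    return alpha_ratio, sigma_down, sigma_up

sigma_next = 0.5
alpha_ratio, sigma_down, sigma_up = sde_coeff(sigma_next)
# The retained noise (alpha_ratio * sigma_down) and the injected noise (sigma_up)
# combine in quadrature back to sigma_next.
total = math.sqrt((alpha_ratio * sigma_down) ** 2 + sigma_up**2)
assert math.isclose(total, sigma_next)
print(round(alpha_ratio, 4), round(sigma_down, 4), sigma_up)
```

Note that `sde_noise_step` divides by `sigma - sigma_next`, so callers are expected to pass two distinct noise levels.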
# ---------------------------------------------------------------------------
# Noise generation
# ---------------------------------------------------------------------------
def channelwise_normalize(x: mx.array) -> mx.array:
"""Normalize each channel to zero mean and unit variance over spatial dims.
Operates on the last 2 dimensions (spatial H, W or time, freq).
"""
mean = mx.mean(x, axis=(-2, -1), keepdims=True)
x = x - mean
std = mx.sqrt(mx.mean(x * x, axis=(-2, -1), keepdims=True) + 1e-8)
x = x / std
return x
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
"""Generate channel-wise normalized Gaussian noise.
PyTorch uses float64; we use float32 (MLX does not support float64 on the GPU).
The channel-wise normalization is the key quality-affecting step.
Args:
shape: Shape of the noise tensor
key: MLX random key for deterministic generation
Returns:
Channel-wise normalized noise in float32
"""
noise = mx.random.normal(shape, dtype=mx.float32, key=key)
# Global normalization
noise = (noise - mx.mean(noise)) / (mx.sqrt(mx.mean(noise * noise)) + 1e-8)
# Channel-wise normalization
noise = channelwise_normalize(noise)
return noise
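
The per-channel normalization can be illustrated on a single small channel. This is a pure-Python sketch using nested lists in place of `mx.array`. The helper name `normalize_channel` is hypothetical, and the epsilon sits inside the square root as in `channelwise_normalize`.

```python
import math

def normalize_channel(channel, eps=1e-8):
    """Zero-mean, unit-RMS normalization of one 2D channel given as nested lists."""
    flat = [v for row in channel for v in row]
    mean = sum(flat) / len(flat)
    centered = [v - mean for v in flat]
    # eps is added to the variance before the square root, as in the MLX version.
    std = math.sqrt(sum(v * v for v in centered) / len(centered) + eps)
    normed = [v / std for v in centered]
    width = len(channel[0])
    return [normed[i:i + width] for i in range(0, len(normed), width)]

out = normalize_channel([[1.0, 2.0], [3.0, 6.0]])
flat_out = [v for row in out for v in row]
mean_out = sum(flat_out) / len(flat_out)
power_out = sum(v * v for v in flat_out) / len(flat_out)
print(mean_out, power_out)
```

After normalization the channel has mean close to 0 and mean square close to 1, regardless of the scale of the input.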


@@ -11,7 +11,7 @@ class PixArtAlphaTextProjection(nn.Module):
out_features: int | None = None,
bias: bool = True,
):
super().__init__()
out_features = out_features or hidden_size


@@ -4,34 +4,41 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx.attention import Attention
from mlx_video.models.ltx.feed_forward import FeedForward
from mlx_video.models.ltx_2.attention import Attention
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx_2.feed_forward import FeedForward
from mlx_video.utils import rms_norm
@dataclass(frozen=True)
class Modality:
latent: mx.array
timesteps: mx.array
positions: mx.array
context: mx.array
latent: mx.array
timesteps: mx.array
positions: mx.array
context: mx.array
enabled: bool = True
context_mask: Optional[mx.array] = None
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
# Raw sigma value (scalar per batch) for prompt adaln (LTX-2.3)
sigma: Optional[mx.array] = None
@dataclass(frozen=True)
class TransformerArgs:
x: mx.array
context: mx.array
context_mask: Optional[mx.array]
timesteps: mx.array
embedded_timestep: mx.array
positional_embeddings: Tuple[mx.array, mx.array]
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
cross_scale_shift_timestep: Optional[mx.array]
cross_gate_timestep: Optional[mx.array]
x: mx.array
context: mx.array
context_mask: Optional[mx.array]
timesteps: mx.array
embedded_timestep: mx.array
positional_embeddings: Tuple[mx.array, mx.array]
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
cross_scale_shift_timestep: Optional[mx.array]
cross_gate_timestep: Optional[mx.array]
enabled: bool
# LTX-2.3: prompt-conditioned timestep embeddings for cross-attention
prompt_timesteps: Optional[mx.array] = None
prompt_embedded_timestep: Optional[mx.array] = None
class BasicAVTransformerBlock(nn.Module):
@@ -48,20 +55,13 @@ class BasicAVTransformerBlock(nn.Module):
audio: Optional[TransformerConfig] = None,
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
norm_eps: float = 1e-6,
has_prompt_adaln: bool = False,
):
"""Initialize transformer block.
Args:
idx: Block index
video: Video modality configuration
audio: Audio modality configuration
rope_type: Type of rotary position embedding
norm_eps: Epsilon for normalization
"""
super().__init__()
self.idx = idx
self.norm_eps = norm_eps
self.has_prompt_adaln = has_prompt_adaln
# Video components
if video is not None:
@@ -72,6 +72,7 @@ class BasicAVTransformerBlock(nn.Module):
context_dim=None, # Self-attention
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
self.attn2 = Attention(
query_dim=video.dim,
@@ -80,10 +81,15 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=video.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
self.ff = FeedForward(video.dim, dim_out=video.dim)
# 6 scale-shift parameters: 3 for attention, 3 for MLP
self.scale_shift_table = mx.zeros((6, video.dim))
# 9 params for LTX-2.3 (self-attn + cross-attn + FFN), 6 for LTX-2
num_ada_params = 9 if has_prompt_adaln else 6
self.scale_shift_table = mx.zeros((num_ada_params, video.dim))
if has_prompt_adaln:
self.prompt_scale_shift_table = mx.zeros((2, video.dim))
# Audio components
if audio is not None:
@@ -94,6 +100,7 @@ class BasicAVTransformerBlock(nn.Module):
context_dim=None,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
self.audio_attn2 = Attention(
query_dim=audio.dim,
@@ -102,9 +109,14 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
self.audio_scale_shift_table = mx.zeros((6, audio.dim))
num_audio_ada_params = 9 if has_prompt_adaln else 6
self.audio_scale_shift_table = mx.zeros((num_audio_ada_params, audio.dim))
if has_prompt_adaln:
self.audio_prompt_scale_shift_table = mx.zeros((2, audio.dim))
# Cross-modal attention (when both video and audio are enabled)
if audio is not None and video is not None:
@@ -116,6 +128,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
# Video-to-Audio: Q from audio, K/V from video
self.video_to_audio_attn = Attention(
@@ -125,6 +138,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=audio.d_head,
rope_type=rope_type,
norm_eps=norm_eps,
has_gate_logits=has_prompt_adaln,
)
# Scale-shift tables for cross-attention
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
@@ -157,8 +171,7 @@ class BasicAVTransformerBlock(nn.Module):
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
timestep_reshaped = mx.reshape(
timestep,
(batch_size, timestep.shape[1], num_ada_params, -1)
timestep, (batch_size, timestep.shape[1], num_ada_params, -1)
)
# Extract the relevant indices
@@ -211,8 +224,12 @@ class BasicAVTransformerBlock(nn.Module):
)
# Squeeze the sequence dimension if it's 1
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
scale_shift_squeezed = tuple(
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada
)
gate_squeezed = tuple(
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada
)
return (*scale_shift_squeezed, *gate_squeezed)
@@ -220,12 +237,18 @@ class BasicAVTransformerBlock(nn.Module):
self,
video: Optional[TransformerArgs] = None,
audio: Optional[TransformerArgs] = None,
skip_video_self_attn: bool = False,
skip_audio_self_attn: bool = False,
skip_cross_modal: bool = False,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Forward pass through transformer block.
Args:
video: Video modality arguments
audio: Audio modality arguments
skip_video_self_attn: Skip video self-attention (for STG perturbation)
skip_audio_self_attn: Skip audio self-attention (for STG perturbation)
skip_cross_modal: Skip all cross-modal attention (for modality isolation)
Returns:
Tuple of (updated_video, updated_audio) TransformerArgs
@@ -238,8 +261,16 @@ class BasicAVTransformerBlock(nn.Module):
# Check which modalities to run
run_vx = video is not None and video.enabled and vx.size > 0
run_ax = audio is not None and audio.enabled and ax.size > 0
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0)
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0)
run_a2v = (
run_vx
and (audio is not None and audio.enabled and ax.size > 0)
and not skip_cross_modal
)
run_v2a = (
run_ax
and (video is not None and video.enabled and vx.size > 0)
and not skip_cross_modal
)
# Process video self-attention and cross-attention with text
if run_vx:
@@ -247,16 +278,49 @@ class BasicAVTransformerBlock(nn.Module):
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
)
# Self-attention with RoPE
# Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings) * vgate_msa
vx = (
vx
+ self.attn1(
norm_vx,
pe=video.positional_embeddings,
skip_attention=skip_video_self_attn,
)
* vgate_msa
)
# Cross-attention with text context
vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps),
context=video.context,
mask=video.context_mask,
)
if self.has_prompt_adaln:
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
vshift_q, vscale_q, vgate_q = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
)
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
self.prompt_scale_shift_table,
vx.shape[0],
video.prompt_timesteps,
slice(0, 2),
)
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
encoder_hidden_states = (
video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
)
vx = (
vx
+ self.attn2(
attn_input,
context=encoder_hidden_states,
mask=video.context_mask,
)
* vgate_q
)
else:
vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps),
context=video.context,
mask=video.context_mask,
)
# Process audio self-attention and cross-attention with text
if run_ax:
@@ -264,16 +328,54 @@ class BasicAVTransformerBlock(nn.Module):
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
)
# Self-attention with RoPE
# Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings) * agate_msa
ax = (
ax
+ self.audio_attn1(
norm_ax,
pe=audio.positional_embeddings,
skip_attention=skip_audio_self_attn,
)
* agate_msa
)
# Cross-attention with text context
ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps),
context=audio.context,
mask=audio.context_mask,
)
if self.has_prompt_adaln:
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
ashift_q, ascale_q, agate_q = self.get_ada_values(
self.audio_scale_shift_table,
ax.shape[0],
audio.timesteps,
slice(6, 9),
)
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
self.audio_prompt_scale_shift_table,
ax.shape[0],
audio.prompt_timesteps,
slice(0, 2),
)
attn_input_a = (
rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
)
encoder_hidden_states_a = (
audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
)
ax = (
ax
+ self.audio_attn2(
attn_input_a,
context=encoder_hidden_states_a,
mask=audio.context_mask,
)
* agate_q
)
else:
ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps),
context=audio.context,
mask=audio.context_mask,
)
# Audio-Video cross-modal attention
if run_a2v or run_v2a:
@@ -339,7 +441,7 @@ class BasicAVTransformerBlock(nn.Module):
# Process video feed-forward
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, None)
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
)
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
vx = vx + self.ff(vx_scaled) * vgate_mlp
@@ -347,7 +449,7 @@ class BasicAVTransformerBlock(nn.Module):
# Process audio feed-forward
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, None)
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
)
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
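
The change from `slice(3, None)` to `slice(3, 6)` matters once the table grows to 9 rows. The sketch below uses a plain list as a stand-in for the table rows. The helper `ada_slices` and its key names are hypothetical, and the sketch assumes `get_ada_values` returns one value per selected row, which the three-way unpacking above suggests.

```python
def ada_slices(has_prompt_adaln: bool):
    """Rows of scale_shift_table read by each sub-layer, as (shift, scale, gate)."""
    layout = {"self_attn": slice(0, 3), "feed_forward": slice(3, 6)}
    if has_prompt_adaln:
        # LTX-2.3 adds three rows that modulate the cross-attention query.
        layout["cross_attn_query"] = slice(6, 9)
    return layout

rows = list(range(9))  # stand-in for the 9 rows of an LTX-2.3 table
# An open-ended slice would select six rows, which cannot unpack into three names.
open_ended = rows[slice(3, None)]
feed_forward = rows[ada_slices(True)["feed_forward"]]
cross_query = rows[ada_slices(True)["cross_attn_query"]]
print(len(open_ended), feed_forward, cross_query)
```

With the 6-row LTX-2 table both slices select the same rows, so the explicit upper bound keeps the two variants on one code path.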


@@ -1,4 +1,5 @@
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -36,11 +37,20 @@ class Conv3d(nn.Module):
self.groups = groups
# Weight shape: (C_out, KD, KH, KW, C_in)
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
scale = (
1.0
/ (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
shape=(
out_channels,
kernel_size[0],
kernel_size[1],
kernel_size[2],
in_channels,
),
)
if bias:
@@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module):
n, d, h, w, c = x.shape
input_dtype = x.dtype
x = x.astype(mx.float32)
# Reshape to (N, D*H*W, num_groups, C//num_groups)
@@ -115,65 +124,138 @@ class GroupNorm3d(nn.Module):
class PixelShuffle2D(nn.Module):
"""Pixel shuffle for 2D spatial upsampling."""
"""Pixel shuffle for 2D spatial upsampling with per-axis factors."""
def __init__(self, upscale_factor: int = 2):
def __init__(self, upscale_factor_h: int = 2, upscale_factor_w: int = 2):
super().__init__()
self.upscale_factor = upscale_factor
self.rh = upscale_factor_h
self.rw = upscale_factor_w
def __call__(self, x: mx.array) -> mx.array:
# x: (N, H, W, C) where C = out_channels * upscale_factor^2
# x: (N, H, W, C) where C = out_channels * rh * rw
n, h, w, c = x.shape
r = self.upscale_factor
out_c = c // (r * r)
rh, rw = self.rh, self.rw
out_c = c // (rh * rw)
# Reshape: (N, H, W, out_c, r, r)
x = mx.reshape(x, (n, h, w, out_c, r, r))
# Reshape: (N, H, W, out_c, rh, rw)
x = mx.reshape(x, (n, h, w, out_c, rh, rw))
# Permute: (N, H, r, W, r, out_c)
# Permute: (N, H, rh, W, rw, out_c)
x = mx.transpose(x, (0, 1, 4, 2, 5, 3))
# Reshape: (N, H*r, W*r, out_c)
x = mx.reshape(x, (n, h * r, w * r, out_c))
# Reshape: (N, H*rh, W*rw, out_c)
x = mx.reshape(x, (n, h * rh, w * rw, out_c))
return x
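
The reshape/transpose/reshape sequence in `PixelShuffle2D` amounts to a fixed index mapping. The following pure-Python sketch spells that mapping out on nested lists for a single image (no batch axis). The function name `pixel_shuffle_2d` is hypothetical and the loop form is for illustration only.

```python
def pixel_shuffle_2d(x, rh, rw):
    """x: nested lists shaped (H, W, C), C = out_c * rh * rw -> (H*rh, W*rw, out_c)."""
    h, w, c = len(x), len(x[0]), len(x[0][0])
    out_c = c // (rh * rw)
    out = [[[None] * out_c for _ in range(w * rw)] for _ in range(h * rh)]
    for row in range(h):
        for col in range(w):
            for oc in range(out_c):
                for i in range(rh):
                    for j in range(rw):
                        # The channel index decomposes as (oc, i, j), matching the
                        # reshape to (N, H, W, out_c, rh, rw).
                        src = oc * rh * rw + i * rw + j
                        out[row * rh + i][col * rw + j][oc] = x[row][col][src]
    return out

pixel = [[[0, 1, 2, 3]]]  # H = W = 1, four channels, so out_c = 1
shuffled = pixel_shuffle_2d(pixel, 2, 2)
print(shuffled)
```

The four input channels of the single pixel become a 2x2 spatial block, filled in row-major order.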
class BlurDownsample(nn.Module):
"""Anti-aliased downsampling with a fixed 5x5 binomial blur kernel.
PyTorch source uses a depthwise conv with the binomial kernel.
The kernel weight is stored as (1, 1, 5, 5) and loaded via safetensors.
"""
def __init__(self, stride: int = 2):
super().__init__()
self.stride = stride
# 5x5 binomial (1,4,6,4,1) kernel, normalized
# This will be overwritten by loaded weights if available
k = mx.array([1.0, 4.0, 6.0, 4.0, 1.0])
kernel_2d = mx.outer(k, k)
kernel_2d = kernel_2d / kernel_2d.sum()
# MLX conv2d weight: (O, H, W, I) — we use (1, 5, 5, 1) for per-channel
self.kernel = kernel_2d.reshape(1, 5, 5, 1)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, H, W, C) channels-last
n, h, w, c = x.shape
# Pad with edge replication (2 on each side for 5x5 kernel)
x = mx.pad(x, [(0, 0), (2, 2), (2, 2), (0, 0)], mode="edge")
# Apply blur per-channel: reshape so each channel is a separate "batch"
# (N, H+4, W+4, C) -> (N*C, H+4, W+4, 1)
x = mx.transpose(x, (0, 3, 1, 2)) # (N, C, H+4, W+4)
x = mx.reshape(x, (n * c, h + 4, w + 4, 1))
# Depthwise conv: (N*C, H+4, W+4, 1) * (1, 5, 5, 1) -> (N*C, H_out, W_out, 1)
x = mx.conv2d(x, self.kernel, stride=(self.stride, self.stride))
_, h_out, w_out, _ = x.shape
# Reshape back: (N*C, H_out, W_out, 1) -> (N, C, H_out, W_out) -> (N, H_out, W_out, C)
x = mx.reshape(x, (n, c, h_out, w_out))
x = mx.transpose(x, (0, 2, 3, 1))
return x
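
The default blur kernel is the outer product of the binomial taps (1, 4, 6, 4, 1) with themselves, normalized to sum to 1. A plain-Python sketch of that construction follows. It mirrors the initialization in `__init__` only and says nothing about the values stored in a checkpoint.

```python
taps = [1.0, 4.0, 6.0, 4.0, 1.0]
norm = sum(taps) ** 2  # the 2D kernel sums to 16 * 16 = 256 before normalization
kernel = [[a * b / norm for b in taps] for a in taps]

center = kernel[2][2]            # largest weight, at the kernel center
total = sum(map(sum, kernel))    # should be 1 so the blur preserves mean brightness
print(center, total)
```

Because the weights sum to 1, a constant input stays constant after the blur, which is what an anti-aliasing filter needs before striding.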
class SpatialUpsampler2x(nn.Module):
"""Standard 2x spatial upsampler: Conv2d + PixelShuffle(2)."""
def __init__(self, mid_channels: int = 1024):
super().__init__()
self.scale = 2.0
# Sequential: conv (index 0) + pixel shuffle
# Weight key: upsampler.0.weight -> mapped to upsampler.conv.weight in sanitize
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
self.pixel_shuffle = PixelShuffle2D(2, 2)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, D, H, W, C)
n, d, h, w, c = x.shape
x = mx.reshape(x, (n * d, h, w, c))
x = self.conv(x)
x = self.pixel_shuffle(x)
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
return x
class SpatialRationalResampler(nn.Module):
"""Rational spatial resampler for non-integer scale factors (e.g., 1.5x).
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
For scale=1.5: upsample 3x via PixelShuffle, then downsample 2x via BlurDownsample.
Rational fraction: 1.5 = 3/2.
"""
def __init__(self, mid_channels: int = 1024, scale: float = 1.5):
super().__init__()
self.scale = scale
# 2D conv: mid_channels -> 4*mid_channels for pixel shuffle
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
# Rational fraction for 1.5: numerator=3, denominator=2
num, den = _rational_for_scale(scale)
self.num = num
self.den = den
# Blur kernel for antialiasing
self.blur_down_kernel = mx.ones((1, 1, 5, 5)) / 25.0
self.pixel_shuffle = PixelShuffle2D(2)
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
self.conv = nn.Conv2d(
mid_channels, num * num * mid_channels, kernel_size=3, padding=1
)
self.pixel_shuffle = PixelShuffle2D(num, num)
self.blur_down = BlurDownsample(stride=den)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, D, H, W, C) - channels last 3D format
# x: (N, D, H, W, C)
n, d, h, w, c = x.shape
# Process frame by frame
# Reshape to (N*D, H, W, C) for 2D operations
x = mx.reshape(x, (n * d, h, w, c))
# Apply 2D conv
x = self.conv(x)
x = self.pixel_shuffle(x) # H*num, W*num
x = self.blur_down(x) # H*num/den, W*num/den
# Pixel shuffle for 2x upscaling
x = self.pixel_shuffle(x)
# Reshape back to (N, D, H*2, W*2, C)
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
_, h_out, w_out, _ = x.shape
x = mx.reshape(x, (n, d, h_out, w_out, c))
return x
def _rational_for_scale(scale: float) -> Tuple[int, int]:
"""Convert a float scale to a rational fraction (numerator, denominator)."""
from fractions import Fraction
frac = Fraction(scale).limit_denominator(10)
return frac.numerator, frac.denominator
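
The rational decomposition and the resulting output size can be checked with the standard library alone. In this sketch `resampled_size` is a hypothetical helper that applies the usual convolution output-size formula to the edge padding of 2, the 5x5 kernel, and the stride `den` used by `BlurDownsample`.

```python
from fractions import Fraction

def rational_for_scale(scale: float):
    """Same logic as _rational_for_scale: nearest fraction with denominator <= 10."""
    frac = Fraction(scale).limit_denominator(10)
    return frac.numerator, frac.denominator

def resampled_size(size: int, scale: float) -> int:
    num, den = rational_for_scale(scale)
    up = size * num                  # PixelShuffle2D(num, num) multiplies H and W by num
    return (up + 4 - 5) // den + 1   # edge-pad 2 per side, 5x5 blur, stride den

print(rational_for_scale(1.5), rational_for_scale(2.0), resampled_size(16, 1.5))
```

For even inputs the result is exactly `size * 1.5`. An odd input such as 5 yields 8 rather than 7.5, so callers that precompute target shapes may want to read the size back from the output tensor, as `SpatialRationalResampler.__call__` does.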
class ResBlock3D(nn.Module):
def __init__(self, channels: int):
@@ -201,48 +283,62 @@ class ResBlock3D(nn.Module):
class LatentUpsampler(nn.Module):
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
spatial_scale: float = 2.0,
rational_resampler: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.spatial_scale = spatial_scale
# Initial projection
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = GroupNorm3d(32, mid_channels)
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
self.res_blocks = {
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
}
# Upsampler: 2D spatial upsampling (frame-by-frame)
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)
if rational_resampler:
self.upsampler = SpatialRationalResampler(
mid_channels=mid_channels, scale=spatial_scale
)
else:
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
self.post_upsample_res_blocks = {
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
}
# Final projection
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
"""Upsample latents by 2x spatially.
"""Upsample latents spatially.
Args:
latent: Input tensor of shape (B, C, F, H, W) - channels first
debug: If True, print intermediate values for debugging
Returns:
Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
"""
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
print(
f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}"
)
if debug:
print(" [DEBUG] LatentUpsampler forward pass:")
@@ -250,41 +346,27 @@ class LatentUpsampler(nn.Module):
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
x = mx.transpose(latent, (0, 2, 3, 4, 1))
if debug:
debug_stats("After transpose to channels-last", x)
# Initial conv
x = self.initial_conv(x)
if debug:
debug_stats("After initial_conv", x)
x = self.initial_norm(x)
if debug:
debug_stats("After initial_norm", x)
x = nn.silu(x)
if debug:
debug_stats("After silu", x)
# Pre-upsample blocks
for i in sorted(self.res_blocks.keys()):
x = self.res_blocks[i](x)
if debug:
debug_stats(f"After res_blocks[{i}]", x)
# Upsample (2D spatial, frame-by-frame)
x = self.upsampler(x)
if debug:
debug_stats("After upsampler (spatial 2x)", x)
debug_stats(f"After upsampler (spatial {self.spatial_scale}x)", x)
# Post-upsample blocks
for i in sorted(self.post_upsample_res_blocks.keys()):
x = self.post_upsample_res_blocks[i](x)
if debug:
debug_stats(f"After post_upsample_res_blocks[{i}]", x)
# Final conv
x = self.final_conv(x)
if debug:
debug_stats("After final_conv", x)
# Convert back to channels first (B, C, F, H, W)
x = mx.transpose(x, (0, 4, 1, 2, 3))
@@ -301,48 +383,73 @@ def upsample_latents(
latent_std: mx.array,
debug: bool = False,
) -> mx.array:
# Un-normalize: latent * std + mean
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
latent_std = latent_std.reshape(1, -1, 1, 1, 1)
latent = latent * latent_std + latent_mean
# Upsample
latent = upsampler(latent, debug=debug)
# Re-normalize: (latent - mean) / std
latent = (latent - latent_mean) / latent_std
return latent
def load_upsampler(weights_path: str) -> LatentUpsampler:
def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
"""Load upsampler from safetensors weights.
Auto-detects whether the weights are for x2 or x1.5 upscaling based on
the upsampler conv output channels:
- x2: upsampler.0.weight shape [4*mid, mid, 3, 3] (4096 out channels)
- x1.5: upsampler.conv.weight shape [9*mid, mid, 3, 3] (9216 out channels)
Args:
weights_path: Path to upsampler weights file
Returns:
Loaded LatentUpsampler model
Tuple of (LatentUpsampler model, spatial_scale)
"""
print(f"Loading spatial upsampler from {weights_path}...")
raw_weights = mx.load(weights_path)
# Check weight shapes to determine mid_channels
# res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
# Detect mid_channels from res_blocks
sample_key = "res_blocks.0.conv1.weight"
if sample_key in raw_weights:
mid_channels = raw_weights[sample_key].shape[0]
else:
mid_channels = 1024 # default
mid_channels = 1024
print(f" Detected mid_channels: {mid_channels}")
# Detect upsampler type from conv output channels
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
# Both formats may have upsampler.blur_down.kernel, so use channel count
conv_key = (
"upsampler.conv.weight"
if "upsampler.conv.weight" in raw_weights
else "upsampler.0.weight"
)
if conv_key in raw_weights:
out_channels = raw_weights[conv_key].shape[0]
ratio = out_channels // mid_channels
rational_resampler = ratio == 9 # 3^2 for PixelShuffle(3) + blur downsample
spatial_scale = 1.5 if rational_resampler else 2.0
else:
rational_resampler = False
spatial_scale = 2.0
print(
f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}"
)
# Create model
upsampler = LatentUpsampler(
in_channels=128,
mid_channels=mid_channels,
num_blocks_per_stage=4,
spatial_scale=spatial_scale,
rational_resampler=rational_resampler,
)
# Sanitize weights - convert from PyTorch to MLX format
@@ -350,19 +457,18 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
for key, value in raw_weights.items():
new_key = key
# x2 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
if key.startswith("upsampler.0."):
new_key = key.replace("upsampler.0.", "upsampler.conv.")
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 5:
if "weight" in new_key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in key and "weight" in key and value.ndim == 4:
if ("weight" in new_key or "kernel" in new_key) and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
# Map upsampler.conv to upsampler.conv (SpatialRationalResampler)
# Keys: upsampler.conv.weight, upsampler.conv.bias, upsampler.blur_down.kernel
if key.startswith("upsampler."):
new_key = key # Keep as is for SpatialRationalResampler
sanitized[new_key] = value
# Load weights
@@ -370,4 +476,4 @@ def load_upsampler(weights_path: str) -> LatentUpsampler:
print(f" Loaded {len(sanitized)} weights")
return upsampler
return upsampler, spatial_scale
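
The auto-detection in `load_upsampler` depends only on weight shapes, so it can be exercised without loading a checkpoint. The sketch below takes a dict of shapes instead of arrays. The function name `detect_upsampler` and the two sample shape dicts are hypothetical, built from the shapes listed in the docstring.

```python
def detect_upsampler(shapes: dict, default_mid: int = 1024):
    """Infer (mid_channels, spatial_scale, rational_resampler) from weight shapes."""
    mid = shapes.get("res_blocks.0.conv1.weight", (default_mid,))[0]
    conv_key = (
        "upsampler.conv.weight"
        if "upsampler.conv.weight" in shapes
        else "upsampler.0.weight"
    )
    if conv_key not in shapes:
        return mid, 2.0, False
    # 9 = 3^2 output channels per input channel means PixelShuffle(3) + blur downsample.
    rational = shapes[conv_key][0] // mid == 9
    return mid, (1.5 if rational else 2.0), rational

x2 = {
    "res_blocks.0.conv1.weight": (1024, 1024, 3, 3, 3),
    "upsampler.0.weight": (4096, 1024, 3, 3),
}
x15 = {
    "res_blocks.0.conv1.weight": (1024, 1024, 3, 3, 3),
    "upsampler.conv.weight": (9216, 1024, 3, 3),
}
print(detect_upsampler(x2), detect_upsampler(x15))
```

Since `load_upsampler` now returns a tuple, call sites written against the old signature need to unpack `(upsampler, spatial_scale)`.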


@@ -0,0 +1,162 @@
"""Shared utilities for LTX-2 model loading and conversion."""
import json
from pathlib import Path
from typing import Any, Dict, Optional
import mlx.core as mx
from huggingface_hub import snapshot_download
def get_model_path(
path_or_hf_repo: str,
revision: Optional[str] = None,
) -> Path:
"""Get local path to model, downloading if necessary.
Args:
path_or_hf_repo: Local path or HuggingFace repo ID
revision: Git revision for HF repo
Returns:
Path to model directory
"""
model_path = Path(path_or_hf_repo)
if model_path.exists():
return model_path
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.safetensors",
"*.json",
"config.json",
],
)
)
return model_path
def load_safetensors(path: Path) -> Dict[str, mx.array]:
"""Load weights from safetensors file(s) using MLX.
Args:
path: Path to model directory or single safetensors file
Returns:
Dictionary of weights
"""
if path.is_file():
return mx.load(str(path))
weights = {}
for sf_path in path.glob("*.safetensors"):
weights.update(mx.load(str(sf_path)))
return weights
def load_config(model_path: Path) -> Dict[str, Any]:
"""Load model configuration from config.json.
Args:
model_path: Path to model directory
Returns:
Configuration dictionary
"""
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
return json.load(f)
return {}
def save_weights(path: Path, weights: Dict[str, mx.array]) -> None:
"""Save weights in safetensors format.
Args:
path: Output directory
weights: Dictionary of weights
"""
path.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(path / "model.safetensors"), weights)
def convert_audio_encoder(
model_path,
source_repo: str = "Lightricks/LTX-2",
) -> Path:
"""Convert and save audio encoder weights from original HF checkpoint.
Extracts encoder weights from the combined audio VAE safetensors,
transposes Conv2d for MLX, and saves for AudioEncoder.from_pretrained().
Args:
model_path: Local model directory (output location).
source_repo: HF repo containing audio_vae/diffusion_pytorch_model.safetensors.
Returns:
Path to the audio_vae/encoder directory.
"""
model_path = Path(model_path)
encoder_dir = model_path / "audio_vae" / "encoder"
if (encoder_dir / "model.safetensors").exists():
return encoder_dir
from huggingface_hub import hf_hub_download
vae_path = hf_hub_download(
source_repo,
"audio_vae/diffusion_pytorch_model.safetensors",
)
raw_weights = mx.load(vae_path)
from mlx_video.models.ltx_2.audio_vae import AudioEncoder
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
# Build config from the decoder config (same audio VAE architecture)
decoder_config_path = model_path / "audio_vae" / "decoder" / "config.json"
if decoder_config_path.exists():
with open(decoder_config_path) as f:
dec_cfg = json.load(f)
enc_config = {
"ch": dec_cfg.get("ch", 128),
"in_channels": dec_cfg.get("out_ch", 2),
"ch_mult": dec_cfg.get("ch_mult", [1, 2, 4]),
"num_res_blocks": dec_cfg.get("num_res_blocks", 2),
"attn_resolutions": dec_cfg.get("attn_resolutions", []),
"resolution": dec_cfg.get("resolution", 256),
"z_channels": dec_cfg.get("z_channels", 8),
"double_z": True,
"n_fft": 1024,
"norm_type": dec_cfg.get("norm_type", "pixel"),
"causality_axis": dec_cfg.get("causality_axis", "height"),
"dropout": dec_cfg.get("dropout", 0.0),
"mid_block_add_attention": dec_cfg.get("mid_block_add_attention", False),
"sample_rate": dec_cfg.get("sample_rate", 16000),
"mel_hop_length": dec_cfg.get("mel_hop_length", 160),
"is_causal": dec_cfg.get("is_causal", True),
"mel_bins": dec_cfg.get("mel_bins", 64) or 64,
"resamp_with_conv": dec_cfg.get("resamp_with_conv", True),
"attn_type": dec_cfg.get("attn_type", "vanilla"),
}
else:
enc_config = {"in_channels": 2, "double_z": True, "n_fft": 1024, "mel_bins": 64}
config = AudioEncoderModelConfig.from_dict(enc_config)
encoder = AudioEncoder(config)
sanitized = encoder.sanitize(raw_weights)
encoder_dir.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(encoder_dir / "model.safetensors"), sanitized)
with open(encoder_dir / "config.json", "w") as f:
json.dump(enc_config, f, indent=2)
print(f"Audio encoder weights saved to {encoder_dir}")
return encoder_dir
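
Review note: `convert_audio_encoder` derives the encoder config from the decoder config with per-key fallbacks. The sketch below reproduces that pattern on a reduced set of keys, to show why `mel_bins` uses `or 64` in addition to the `.get()` default. `derive_encoder_config` is a hypothetical name and the key subset is chosen for illustration:

```python
def derive_encoder_config(dec_cfg):
    """Build a partial encoder config from a decoder config dict."""
    return {
        "ch": dec_cfg.get("ch", 128),
        # The encoder consumes what the decoder emits.
        "in_channels": dec_cfg.get("out_ch", 2),
        "z_channels": dec_cfg.get("z_channels", 8),
        # .get() only falls back when the key is missing; `or 64` also
        # covers an explicit null (None) stored in config.json.
        "mel_bins": dec_cfg.get("mel_bins", 64) or 64,
        "double_z": True,
    }


cfg = derive_encoder_config({"ch": 64, "out_ch": 2, "mel_bins": None})
empty_cfg = derive_encoder_config({})
```

One consequence of `or 64` is that a falsy value such as `0` is also replaced by the default, which seems harmless for a mel-bin count.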


@@ -0,0 +1,8 @@
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.tiling import (
SpatialTilingConfig,
TemporalTilingConfig,
TilingConfig,
)
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder


@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
# Height padding (axis 2)
if pad_h > 0:
# Get reflection indices - exclude boundary
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][
:, :, ::-1, :, :
] # Flip bottom portion
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
# Width padding (axis 3)
if pad_w > 0:
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w - 1 : -1, :][
:, :, :, ::-1, :
] # Flip right portion
x = mx.concatenate([left_pad, x, right_pad], axis=3)
return x
@@ -50,7 +54,7 @@ def make_conv_nd(
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
if dims == 2:
return CausalConv2d(
in_channels=in_channels,
@@ -118,15 +122,17 @@ class CausalConv3d(nn.Module):
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
use_causal = causal if causal is not None else self.causal
# Apply temporal padding via frame replication
# Apply temporal padding via frame replication
# Only apply if kernel_size > 1
if self.time_kernel_size > 1:
if use_causal:
# Causal: replicate first frame kernel_size-1 times at the beginning
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
first_frame_pad = mx.repeat(
x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
)
x = mx.concatenate([first_frame_pad, x], axis=2)
else:
# Non-causal: replicate first frame at start, last frame at end
@@ -176,7 +182,6 @@ class CausalConv3d(nn.Module):
"""
b, d, h, w, c = x.shape
total_elements = d * h * w * c
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
@@ -191,11 +196,10 @@ class CausalConv3d(nn.Module):
overlap = kernel_t - 1
expected_output_frames = d - overlap
outputs = []
out_idx = 0
out_idx = 0
# Process chunks
in_start = 0
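
Review note: the two padding schemes touched in this file can be checked on a 1-D sequence. This is a sketch in plain Python lists rather than `mx.array`, and both function names are hypothetical. It mirrors the slicing in `reflect_pad_2d` (reflection that excludes the boundary element) and the causal branch of `CausalConv3d` (first frame repeated `kernel_size - 1` times):

```python
def reflect_pad_1d(seq, pad):
    """Reflect-pad a sequence on both sides, excluding the boundary element."""
    if pad <= 0:
        return list(seq)
    if pad >= len(seq):
        raise ValueError("pad must be smaller than the sequence length")
    head = seq[1 : pad + 1][::-1]    # flipped leading portion
    tail = seq[-pad - 1 : -1][::-1]  # flipped trailing portion
    return list(head) + list(seq) + list(tail)


def causal_pad_frames(frames, time_kernel_size):
    """Prepend the first frame kernel_size - 1 times (causal padding)."""
    if time_kernel_size <= 1:
        return list(frames)
    return [frames[0]] * (time_kernel_size - 1) + list(frames)


padded = reflect_pad_1d([0, 1, 2, 3, 4], 2)
causal = causal_pad_frames(["f0", "f1", "f2"], 3)
```

With causal padding the output at time t depends only on frames at or before t, which is why nothing is appended at the end.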


@@ -15,15 +15,16 @@ Architecture (from PyTorch weights):
"""
import math
from typing import Optional
from pathlib import Path
from typing import Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
def get_timestep_embedding(
@@ -76,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
def __init__(self, embedding_dim: int):
super().__init__()
self.timestep_embedder = TimestepEmbedding(
in_channels=256,
time_embed_dim=embedding_dim
in_channels=256, time_embed_dim=embedding_dim
)
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
def __call__(
self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
) -> mx.array:
timesteps_proj = get_timestep_embedding(
timestep,
embedding_dim=256,
flip_sin_to_cos=True,
downscale_freq_shift=0
timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
)
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
return timesteps_emb
@@ -118,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
class ConvWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
@@ -129,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
padding=1,
spatial_padding_mode=padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
return ConvWrapper()
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__(
self,
@@ -152,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
if self.timestep_conditioning and timestep_embed is not None:
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
# Combine table with timestep embedding
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
ada_values = self.scale_shift_table[
None, :, :, None, None, None
] # (1, 4, C, 1, 1, 1)
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
channels = self.scale_shift_table.shape[1]
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
@@ -198,16 +202,14 @@ class ResBlockGroup(nn.Module):
# Time embedder for this block group: embed_dim = 4 * channels
if timestep_conditioning:
self.time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=channels * 4
)
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
# Use dict with int keys for MLX to track parameters properly
self.res_blocks = {
i: ResnetBlock3DSimple(
channels,
spatial_padding_mode,
timestep_conditioning=timestep_conditioning
timestep_conditioning=timestep_conditioning,
)
for i in range(num_layers)
}
@@ -223,8 +225,7 @@ class ResBlockGroup(nn.Module):
if self.timestep_conditioning and timestep is not None:
batch_size = x.shape[0]
timestep_embed = self.time_embedder(
timestep.flatten(),
hidden_dtype=x.dtype
timestep.flatten(), hidden_dtype=x.dtype
)
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
@@ -249,6 +250,18 @@ class LTX2VideoDecoder(nn.Module):
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
"""
# Block definitions: ("res", channels, num_layers) or
# ("d2s", in_channels, reduction, stride[, residual])
# stride is a (D, H, W) tuple; residual is optional and defaults to True
DEFAULT_BLOCKS = [
("res", 1024, 5),
("d2s", 1024, 2, (2, 2, 2)),
("res", 512, 5),
("d2s", 512, 2, (2, 2, 2)),
("res", 256, 5),
("d2s", 256, 2, (2, 2, 2)),
("res", 128, 5),
]
def __init__(
self,
in_channels: int = 128,
@@ -257,6 +270,7 @@ class LTX2VideoDecoder(nn.Module):
num_layers_per_block: int = 5,
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
timestep_conditioning: bool = True,
decoder_blocks: Optional[list] = None,
):
super().__init__()
@@ -269,72 +283,72 @@ class LTX2VideoDecoder(nn.Module):
self.decode_timestep = 0.05
# Per-channel statistics for denormalization (loaded from weights)
self.latents_mean = mx.zeros((in_channels,))
self.latents_std = mx.ones((in_channels,))
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Initial conv: 128 -> 1024
blocks = decoder_blocks or self.DEFAULT_BLOCKS
first_ch = blocks[0][1]
last_ch = blocks[-1][1]
# Initial conv: in_channels -> first block channels
class ConvInWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=in_channels,
out_channels=1024,
out_channels=first_ch,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper()
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
# Use dict with int keys for MLX to track parameters properly
self.up_blocks = {
0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
1: DepthToSpaceUpsample(
dims=3,
in_channels=1024,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
3: DepthToSpaceUpsample(
dims=3,
in_channels=512,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
5: DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
residual=True,
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
}
# Build up blocks from config
self.up_blocks = {}
for idx, block_def in enumerate(blocks):
block_type = block_def[0]
ch = block_def[1]
if block_type == "res":
num_layers = (
block_def[2] if len(block_def) > 2 else num_layers_per_block
)
self.up_blocks[idx] = ResBlockGroup(
ch, num_layers, spatial_padding_mode, timestep_conditioning
)
elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
residual = block_def[4] if len(block_def) > 4 else True
self.up_blocks[idx] = DepthToSpaceUpsample(
dims=3,
in_channels=ch,
stride=stride,
residual=residual,
out_channels_reduction_factor=reduction,
spatial_padding_mode=spatial_padding_mode,
)
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
self_inner.conv = CausalConv3d(
in_channels=128,
in_channels=last_ch,
out_channels=final_out_channels,
kernel_size=3,
stride=1,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_out = ConvOutWrapper()
self.act = nn.SiLU()
@@ -342,21 +356,202 @@ class LTX2VideoDecoder(nn.Module):
if timestep_conditioning:
self.timestep_scale_multiplier = mx.array(1000.0)
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=128 * 2 # 256, matches (2, 128) table
embedding_dim=last_ch * 2
)
self.last_scale_shift_table = mx.zeros((2, 128))
self.last_scale_shift_table = mx.zeros((2, last_ch))
def denormalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics."""
dtype = x.dtype
# Cast to float32 for precision (statistics may be in bfloat16)
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return (x * std + mean).astype(dtype)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping
sanitized = {}
if "per_channel_statistics.mean" in weights:
return weights
for key, value in weights.items():
new_key = key
if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
if key.startswith("vae.per_channel_statistics."):
# Map per-channel statistics (use exact key matching)
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue # Skip other statistics keys
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if (
".conv.conv.weight" not in new_key
and ".conv.conv.bias" not in new_key
):
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(
cls, model_path: Path, strict: bool = True
) -> "LTX2VideoDecoder":
"""Load a pretrained decoder from a directory with config.json and weights.
Args:
model_path: Path to a directory containing config.json and one or more
safetensors files.
strict: Whether to require all weight keys to match.
Returns:
Loaded LTX2VideoDecoder instance
"""
import json
model_path = Path(model_path)
config_dict = {}
# Load config from directory
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Load weights from directory
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights)
# Determine spatial padding mode from config
spatial_padding_mode_str = config_dict.get("spatial_padding_mode", "reflect")
spatial_padding_mode = PaddingModeType(spatial_padding_mode_str)
model = cls(
timestep_conditioning=config_dict.get("timestep_conditioning", False),
decoder_blocks=decoder_blocks,
spatial_padding_mode=spatial_padding_mode,
)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
@staticmethod
def _infer_blocks(weights: dict) -> Optional[list]:
"""Infer decoder block structure from weight keys."""
block_indices = set()
for k in weights:
if "up_blocks." in k:
idx_str = k.split("up_blocks.")[1].split(".")[0]
if idx_str.isdigit():
block_indices.add(int(idx_str))
if not block_indices:
return None
# First pass: collect block info
raw_blocks = []
for idx in sorted(block_indices):
has_conv = any(f"up_blocks.{idx}.conv." in k for k in weights)
res_indices = set()
for k in weights:
prefix = f"up_blocks.{idx}.res_blocks."
if prefix in k:
res_idx = k.split(prefix)[1].split(".")[0]
if res_idx.isdigit():
res_indices.add(int(res_idx))
if has_conv and not res_indices:
# D2S block - get conv shape
for k, v in weights.items():
if f"up_blocks.{idx}.conv." in k and "weight" in k:
in_ch = v.shape[-1] if v.ndim == 5 else v.shape[1]
conv_out_ch = v.shape[0]
raw_blocks.append(("d2s", in_ch, conv_out_ch))
break
elif res_indices:
num_res = max(res_indices) + 1
for k, v in weights.items():
if f"up_blocks.{idx}.res_blocks.0.conv1" in k and "weight" in k:
ch = v.shape[0]
raw_blocks.append(("res", ch, num_res))
break
# Second pass: determine d2s strides using the channel progression
# For each d2s block, the next res block tells us the expected output channels
blocks = []
d2s_strides = []
for i, block in enumerate(raw_blocks):
if block[0] == "res":
blocks.append(block)
elif block[0] == "d2s":
in_ch, conv_out_ch = block[1], block[2]
# Find next res block's channels
next_ch = None
for j in range(i + 1, len(raw_blocks)):
if raw_blocks[j][0] == "res":
next_ch = raw_blocks[j][1]
break
if next_ch is None:
next_ch = in_ch // 2 # fallback
# out_ch = in_ch // reduction
reduction = in_ch // next_ch if next_ch > 0 else 2
# conv_out = next_ch * multiplier → multiplier = conv_out / next_ch
multiplier = conv_out_ch // next_ch if next_ch > 0 else 8
# Determine stride from multiplier
if multiplier == 8:
stride = (2, 2, 2)
elif multiplier == 4:
stride = (1, 2, 2)
elif multiplier == 2:
stride = (2, 1, 1)
else:
stride = (2, 2, 2)
d2s_strides.append(stride)
blocks.append(("d2s", in_ch, reduction, stride))
if not blocks:
return None
# Determine residual flag: LTX-2 has uniform (2,2,2) strides with reduction=2 → residual=True
# LTX-2.3 has mixed strides or reduction=1 → residual=False
has_mixed_strides = len(set(d2s_strides)) > 1
has_non_standard_reduction = any(b[2] != 2 for b in blocks if b[0] == "d2s")
use_residual = not has_mixed_strides and not has_non_standard_reduction
# Apply residual flag to all d2s blocks
final_blocks = []
for block in blocks:
if block[0] == "d2s":
final_blocks.append(("d2s", block[1], block[2], block[3], use_residual))
else:
final_blocks.append(block)
return final_blocks
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__(
self,
@@ -366,29 +561,15 @@ class LTX2VideoDecoder(nn.Module):
debug: bool = False,
chunked_conv: bool = False,
) -> mx.array:
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
batch_size = sample.shape[0]
if debug:
debug_stats("Input", sample)
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
if debug:
debug_stats("After noise", sample)
if debug:
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
sample = self.denormalize(sample)
if debug:
debug_stats("After denormalize", sample)
sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
@@ -398,8 +579,6 @@ class LTX2VideoDecoder(nn.Module):
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
if debug:
debug_stats("After conv_in", x)
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
@@ -408,22 +587,18 @@ class LTX2VideoDecoder(nn.Module):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
if debug:
block_type = type(block).__name__
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
x = self.pixel_norm(x)
if debug:
debug_stats("After pixel_norm", x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
scaled_timestep.flatten(),
hidden_dtype=x.dtype
scaled_timestep.flatten(), hidden_dtype=x.dtype
)
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
ada_values = self.last_scale_shift_table[
None, :, :, None, None, None
] # (1, 2, C, 1, 1, 1)
# Infer C from the embedding rather than hardcoding 128: last_ch is now configurable
ts_reshaped = embedded_timestep.reshape(batch_size, 2, -1, 1, 1, 1)
ada_values = ada_values + ts_reshaped
@@ -431,21 +606,13 @@ class LTX2VideoDecoder(nn.Module):
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
if debug:
debug_stats("After timestep modulation", x)
x = self.act(x)
if debug:
debug_stats("After activation", x)
x = self.conv_out(x, causal=causal)
if debug:
debug_stats("After conv_out", x)
# Unpatchify: (B, 48, F, H, W) -> (B, 3, F, H*4, W*4); patch_size_t=1 leaves F unchanged
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
if debug:
debug_stats("After unpatchify", x)
return x
@@ -502,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
# Auto-enable chunked conv for modes where it helps (larger tiles)
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
use_chunked_conv = tiling_mode in (
"conservative",
"none",
"auto",
"default",
"spatial",
)
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
return self(
sample,
causal=causal,
timestep=timestep,
debug=debug,
chunked_conv=use_chunked_conv,
)
return decode_with_tiling(
decoder_fn=self,
@@ -521,101 +700,5 @@ class LTX2VideoDecoder(nn.Module):
)
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path
import json
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...")
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.decoder."
stats_prefix = "vae.per_channel_statistics."
elif has_decoder_prefix:
prefix = "decoder."
stats_prefix = ""
else:
prefix = ""
stats_prefix = ""
# Load per-channel statistics for denormalization
# Note: use std-of-means (not mean-of-stds) for proper denormalization
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
if mean_key in weights:
decoder.latents_mean = weights[mean_key]
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
if std_key in weights:
decoder.latents_std = weights[std_key]
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
# Build decoder weights dict with key remapping
decoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
decoder_weights[new_key] = value
print(f" Found {len(decoder_weights)} decoder weights")
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
print(f" Found {len(ts_keys)} timestep conditioning weights")
# Load weights
decoder.load_weights(list(decoder_weights.items()), strict=False)
print("VAE decoder loaded successfully")
return decoder
# Backward-compatible alias
VideoDecoder = LTX2VideoDecoder
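
Review note: the second pass of `_infer_blocks` reduces to a small amount of integer arithmetic per depth-to-space block. The sketch below isolates it in plain Python. `infer_d2s` and `use_residual` are hypothetical names, and the channel counts in the example are illustrative values chosen to satisfy `conv_out = next_ch * multiplier`, not numbers read from a checkpoint:

```python
# Stride implied by the channel multiplier of the d2s conv (st * sh * sw).
STRIDE_BY_MULTIPLIER = {8: (2, 2, 2), 4: (1, 2, 2), 2: (2, 1, 1)}


def infer_d2s(in_ch, conv_out_ch, next_ch):
    """Return (reduction, stride) for one depth-to-space block."""
    reduction = in_ch // next_ch if next_ch > 0 else 2
    multiplier = conv_out_ch // next_ch if next_ch > 0 else 8
    # Unknown multipliers fall back to (2, 2, 2), as in the diff.
    stride = STRIDE_BY_MULTIPLIER.get(multiplier, (2, 2, 2))
    return reduction, stride


def use_residual(d2s_blocks):
    """Residual only when all strides agree and every reduction is 2."""
    strides = {stride for _, stride in d2s_blocks}
    return len(strides) <= 1 and all(r == 2 for r, _ in d2s_blocks)


uniform = [infer_d2s(1024, 4096, 512), infer_d2s(512, 2048, 256)]
mixed = [infer_d2s(1024, 4096, 512), infer_d2s(512, 2048, 512)]
```

Because the stride is recovered from an integer division, multipliers 8, 4, and 2 each map to a single stride. A checkpoint whose stride is, say, `(1, 1, 2)` would be read as `(2, 1, 1)`, so the mapping is worth revisiting if new block layouts are added.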


@@ -0,0 +1,44 @@
"""Video VAE Encoder for LTX-2 Image-to-Video.
The encoder compresses input images/videos to latent representations.
Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation.
"""
import mlx.core as mx
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
def encode_image(
image: mx.array,
encoder: VideoEncoder,
) -> mx.array:
"""Encode a single image to latent space.
Args:
image: Image tensor of shape (H, W, 3) or (B, H, W, 3), with values in [0, 1];
values above 1.0 are treated as 0-255 and rescaled
encoder: Loaded VAE encoder
Returns:
Latent tensor of shape (B, 128, 1, H//32, W//32), with B = 1 for a single image
"""
# Add batch dimension if needed
if image.ndim == 3:
image = mx.expand_dims(image, axis=0) # (1, H, W, 3)
# Convert from (B, H, W, C) to (B, C, H, W)
image = mx.transpose(image, (0, 3, 1, 2)) # (B, 3, H, W)
# Normalize to [-1, 1]
if image.max() > 1.0:
image = image / 255.0
image = image * 2.0 - 1.0
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
image = mx.expand_dims(image, axis=2) # (B, 3, 1, H, W)
# Encode
latent = encoder(image)
return latent
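
Review note: the preprocessing in `encode_image` can be traced without loading the encoder. The sketch below follows the shape at each step and the pixel normalization, using plain Python. Both function names are hypothetical, and the latent shape (128 channels, 32x spatial downsampling) is taken from the docstring rather than verified against the encoder:

```python
def trace_encode_image_shapes(shape, latent_channels=128, spatial_factor=32):
    """Return the tensor shape after each step of encode_image."""
    if len(shape) == 3:               # (H, W, 3): add a batch dimension
        shape = (1, *shape)
    b, h, w, c = shape
    if c != 3:
        raise ValueError(f"expected 3 channels in the last axis, got {c}")
    return {
        "batched": (b, h, w, c),
        "channels_first": (b, c, h, w),
        "with_time": (b, c, 1, h, w),
        # Assumption: output shape as stated in the encode_image docstring.
        "latent": (b, latent_channels, 1, h // spatial_factor, w // spatial_factor),
    }


def normalize_pixels(values):
    """Map pixel values to [-1, 1]; values above 1.0 are treated as 0-255."""
    if max(values) > 1.0:
        values = [v / 255.0 for v in values]
    return [v * 2.0 - 1.0 for v in values]


shapes = trace_encode_image_shapes((512, 768, 3))
```

Note that the range check is a heuristic: a 0-255 image whose maximum happens to be 0 or 1 would be treated as already normalized.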


@@ -1,6 +1,5 @@
"""Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
x = mx.reshape(
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
)
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
@@ -101,7 +102,7 @@ class PerChannelStatistics(nn.Module):
Normalized tensor
"""
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
dtype = x.dtype
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
@@ -117,7 +118,7 @@ class PerChannelStatistics(nn.Module):
Returns:
Denormalized tensor
"""
dtype = x.dtype
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
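
Review note: the decoder now delegates denormalization to `PerChannelStatistics.un_normalize`, which should match the removed `denormalize` method (`x * std + mean` per channel). A minimal sketch in plain Python, treating a tensor as a list of channels with flat value lists. The `normalize` formula is assumed to be the exact inverse, since its body is not shown in this diff:

```python
def un_normalize(channels, mean, std):
    """Per-channel x * std + mean, as in the removed decoder.denormalize."""
    return [[v * s + m for v in ch] for ch, m, s in zip(channels, mean, std)]


def normalize(channels, mean, std):
    """Assumed inverse of un_normalize: per-channel (x - mean) / std."""
    return [[(v - m) / s for v in ch] for ch, m, s in zip(channels, mean, std)]


latents = [[0.0, 1.0], [2.0, -2.0]]  # two channels, two values each
mean = [1.0, 0.5]
std = [2.0, 4.0]

restored = un_normalize(latents, mean, std)
round_trip = normalize(restored, mean, std)
```

The real module additionally casts the statistics to float32 before the arithmetic and casts the result back to the input dtype.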


@@ -6,7 +6,7 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.utils import PixelNorm
@@ -44,7 +44,7 @@ class ResnetBlock3D(nn.Module):
timestep_conditioning: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
out_channels = out_channels or in_channels
@@ -96,7 +96,7 @@ class ResnetBlock3D(nn.Module):
causal: bool = True,
generator: Optional[int] = None,
) -> mx.array:
residual = x
# First block
@@ -136,7 +136,7 @@ class UNetMidBlock3D(nn.Module):
attention_head_dim: Optional[int] = None,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.num_layers = num_layers


@@ -5,7 +5,7 @@ from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
class SpaceToDepthDownsample(nn.Module):
@@ -104,7 +104,7 @@ class SpaceToDepthDownsample(nn.Module):
class DepthToSpaceUpsample(nn.Module):
def __init__(
self,
dims: int,
@@ -114,7 +114,7 @@ class DepthToSpaceUpsample(nn.Module):
out_channels_reduction_factor: int = 1,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
return x
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
def __call__(
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
def _chunked_conv_depth_to_space(
self, x: mx.array, causal: bool = True
) -> mx.array:
"""Chunked conv + depth_to_space that processes in temporal chunks.
This reduces peak memory by avoiding the full high-channel intermediate tensor.
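
Review note: the memory pressure that the chunked path addresses comes from the channel expansion before depth-to-space. The sketch below computes only the shape bookkeeping in plain Python. `depth_to_space_shape` is a hypothetical helper, and it ignores any frame trimming the real `DepthToSpaceUpsample` may apply after temporal upsampling:

```python
def depth_to_space_shape(shape, stride):
    """Shape after rearranging channels into (time, height, width).

    (B, C * st * sh * sw, D, H, W) -> (B, C, D * st, H * sh, W * sw)
    """
    b, c, d, h, w = shape
    st, sh, sw = stride
    factor = st * sh * sw
    if c % factor != 0:
        raise ValueError(f"channels {c} not divisible by stride product {factor}")
    return (b, c // factor, d * st, h * sh, w * sw)


# Illustrative sizes: a conv output with 4096 channels and stride (2, 2, 2).
out_shape = depth_to_space_shape((1, 4096, 4, 8, 8), (2, 2, 2))
```

The element count is unchanged by the rearrangement, so the high-channel conv output is the peak allocation. Processing it in temporal chunks bounds that peak.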


@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
# Apply right ramp (fade out)
if ramp_right > 0:
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
fade_out = [
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
]
for i in range(ramp_right):
mask[length - ramp_right + i] *= fade_out[i]
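
Review note: the reformatted `fade_out` comprehension is meant to equal `linspace(1, 0, ramp_right + 2)[1:-1]`, per the comment in this hunk. A quick check in plain Python, with both function names hypothetical:

```python
def fade_out_weights(ramp_right):
    """Fade-out weights exactly as computed in compute_trapezoidal_mask_1d."""
    return [
        (ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
    ]


def linspace_interior(ramp_right):
    """linspace(1, 0, ramp_right + 2) without its two endpoints."""
    n = ramp_right + 2
    points = [1.0 - k / (n - 1) for k in range(n)]
    return points[1:-1]


weights = fade_out_weights(3)
reference = linspace_interior(3)
```

Dropping the endpoints means no position in the ramp gets weight exactly 0 or 1, so every overlapping pixel receives a contribution from both tiles.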
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
def __post_init__(self) -> None:
if self.tile_size_in_pixels < 64:
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
raise ValueError(
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
)
if self.tile_size_in_pixels % 32 != 0:
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
raise ValueError(
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
)
if self.tile_overlap_in_pixels % 32 != 0:
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
raise ValueError(
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
)
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
def __post_init__(self) -> None:
if self.tile_size_in_frames < 16:
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
raise ValueError(
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
)
if self.tile_size_in_frames % 8 != 0:
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
raise ValueError(
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
)
if self.tile_overlap_in_frames % 8 != 0:
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
raise ValueError(
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
)
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
@@ -113,15 +127,21 @@ class TilingConfig:
def default(cls) -> "TilingConfig":
"""Default tiling: 512px spatial, 64 frame temporal."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
),
)
@classmethod
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
"""Spatial tiling only (for short videos with large resolution)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
),
temporal_config=None,
)
@@ -130,23 +150,33 @@ class TilingConfig:
"""Temporal tiling only (for long videos with small resolution)."""
return cls(
spatial_config=None,
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
),
)
@classmethod
def aggressive(cls) -> "TilingConfig":
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=256, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=32, tile_overlap_in_frames=8
),
)
@classmethod
def conservative(cls) -> "TilingConfig":
"""Conservative tiling (larger tiles, less memory savings but faster)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=768, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=96, tile_overlap_in_frames=24
),
)
@classmethod
@@ -160,6 +190,9 @@ class TilingConfig:
) -> Optional["TilingConfig"]:
"""Automatically determine tiling config based on video dimensions.
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
enough context for CausalConv3d and sufficient overlap for clean blending.
Args:
height: Video height in pixels
width: Video width in pixels
@@ -176,37 +209,21 @@ class TilingConfig:
if not needs_spatial and not needs_temporal:
return None
# Estimate memory requirement (rough heuristic)
# Output size in bytes (float32): B * 3 * F * H * W * 4
estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3)
# For very large videos, use aggressive tiling
if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100):
return cls.aggressive()
# Use the same defaults as PyTorch (512px spatial, 64f temporal).
# Smaller tiles cause quality degradation because CausalConv3d needs
# sufficient temporal context and overlap for clean blending.
spatial_config = None
temporal_config = None
if needs_spatial:
# Choose tile size based on resolution
max_dim = max(height, width)
if max_dim > 1024:
tile_size = 384 # Smaller tiles for very large resolutions
elif max_dim > 768:
tile_size = 512
else:
tile_size = 384
spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64)
spatial_config = SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
)
if needs_temporal:
# Choose tile size based on frame count
if num_frames > 200:
tile_size, overlap = 32, 8 # Aggressive for very long videos
elif num_frames > 100:
tile_size, overlap = 48, 16
else:
tile_size, overlap = 64, 24
temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
temporal_config = TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
)
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
@@ -214,16 +231,21 @@ class TilingConfig:
@dataclass
class DimensionIntervals:
"""Intervals for splitting a single dimension."""
starts: List[int]
ends: List[int]
left_ramps: List[int]
right_ramps: List[int]
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
def split_in_spatial(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a spatial dimension into intervals."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
starts = [i * (size - overlap) for i in range(amount)]
@@ -232,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
left_ramps = [0] + [overlap] * (amount - 1)
right_ramps = [overlap] * (amount - 1) + [0]
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
return DimensionIntervals(
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
)
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
def split_in_temporal(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a temporal dimension into intervals with causal adjustment."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
# Start with spatial split
intervals = split_in_spatial(size, overlap, dimension_size)
@@ -251,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
starts[i] = starts[i] - 1
left_ramps[i] = left_ramps[i] + 1
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
return DimensionIntervals(
starts=starts,
ends=intervals.ends,
left_ramps=left_ramps,
right_ramps=intervals.right_ramps,
)
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
def map_temporal_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map temporal latent interval to output coordinates and mask."""
start = begin * scale
stop = 1 + (end - 1) * scale
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, True
)
return slice(start, stop), mask
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
def map_spatial_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map spatial latent interval to output coordinates and mask."""
start = begin * scale
stop = end * scale
left_ramp_scaled = left_ramp * scale
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, False
)
return slice(start, stop), mask
@@ -332,7 +373,9 @@ def decode_with_tiling(
temporal_overlap = 0
# Compute intervals for each dimension
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_temporal(
temporal_tile_size, temporal_overlap, f_latent
)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -355,7 +398,9 @@ def decode_with_tiling(
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_temporal_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -364,7 +409,9 @@ def decode_with_tiling(
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
out_h_slice, h_mask = map_spatial_slice(
h_start, h_end, h_left, h_right, spatial_scale
)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
@@ -373,13 +420,23 @@ def decode_with_tiling(
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
out_w_slice, w_mask = map_spatial_slice(
w_start, w_end, w_left, w_right, spatial_scale
)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
tile_latents = latents[
:, :, t_start:t_end, h_start:h_end, w_start:w_end
]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
tile_output = decoder_fn(
tile_latents,
causal=causal,
timestep=timestep,
debug=False,
chunked_conv=chunked_conv,
)
mx.eval(tile_output)
# Clear tile_latents reference
@@ -402,13 +459,15 @@ def decode_with_tiling(
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
t_mask_slice.reshape(1, 1, -1, 1, 1)
* h_mask_slice.reshape(1, 1, 1, -1, 1)
* w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
tile_output_slice = tile_output[
:, :, :actual_t, :actual_h, :actual_w
].astype(mx.float32)
# Clear full tile_output
del tile_output
@@ -426,11 +485,37 @@ def decode_with_tiling(
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ blend_mask
)
# Force evaluation to free memory
@@ -462,10 +547,12 @@ def decode_with_tiling(
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
next_tile_start_out = (
1 + (next_tile_start_latent - 1) * temporal_scale
)
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
if not hasattr(decode_with_tiling, "_emitted_frames"):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
@@ -473,7 +560,10 @@ def decode_with_tiling(
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = (
output[:, :, emitted:next_tile_start_out, :, :]
/ finalized_weights
)
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
@@ -490,7 +580,7 @@ def decode_with_tiling(
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
@@ -498,7 +588,7 @@ def decode_with_tiling(
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
if hasattr(decode_with_tiling, "_emitted_frames"):
del decode_with_tiling._emitted_frames
# Clean up weights


@@ -1,20 +1,24 @@
"""Video VAE Encoder and Decoder for LTX-2."""
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx.video_vae.resnet import (
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import (
PerChannelStatistics,
patchify,
unpatchify,
)
from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx.video_vae.sampling import (
from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample,
SpaceToDepthDownsample,
)
@@ -23,6 +27,7 @@ from mlx_video.utils import PixelNorm
class LogVarianceType(Enum):
"""Log variance mode for VAE."""
PER_CHANNEL = "per_channel"
UNIFORM = "uniform"
CONSTANT = "constant"
@@ -221,46 +226,31 @@ class VideoEncoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(
self,
convolution_dimensions: int = 3,
in_channels: int = 3,
out_channels: int = 128,
encoder_blocks: List[Tuple[str, Any]] = None,
patch_size: int = 4,
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
latent_log_var: LogVarianceType = LogVarianceType.UNIFORM,
encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
"""Initialize VideoEncoder.
def __init__(self, config: "VideoEncoderModelConfig"):
"""Initialize VideoEncoder from config.
Args:
convolution_dimensions: Number of dimensions (3 for video)
in_channels: Input channels (3 for RGB)
out_channels: Output latent channels
encoder_blocks: List of (block_name, config) tuples
patch_size: Spatial patch size
norm_layer: Normalization layer type
latent_log_var: Log variance mode
encoder_spatial_padding_mode: Padding mode
config: VideoEncoderModelConfig with encoder parameters
"""
super().__init__()
if encoder_blocks is None:
encoder_blocks = []
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self.patch_size = config.patch_size
self.norm_layer = config.norm_layer
self.latent_channels = config.out_channels
self.latent_log_var = config.latent_log_var
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
encoder_blocks = config.encoder_blocks if config.encoder_blocks else []
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
# Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels)
self.per_channel_statistics = PerChannelStatistics(
latent_channels=config.out_channels
)
# After patchify, channels increase by patch_size^2
in_channels = in_channels * patch_size ** 2
feature_channels = out_channels
in_channels = config.in_channels * config.patch_size**2
feature_channels = config.out_channels
# Initial convolution
self.conv_in = CausalConv3d(
@@ -273,39 +263,47 @@ class VideoEncoder(nn.Module):
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks - use dict with int keys for MLX parameter tracking
# Build encoder blocks
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {}
for i, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_encoder_block(
block_name=block_name,
block_config=block_config,
in_channels=feature_channels,
convolution_dimensions=convolution_dimensions,
norm_layer=norm_layer,
convolution_dimensions=config.convolution_dimensions,
norm_layer=config.norm_layer,
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks[i] = block
self.down_blocks[idx] = block
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
if config.norm_layer == NormLayerType.GROUP_NORM:
self.conv_norm_out = nn.GroupNorm(
num_groups=self._norm_num_groups,
dims=feature_channels,
eps=1e-6,
)
elif norm_layer == NormLayerType.PIXEL_NORM:
elif config.norm_layer == NormLayerType.PIXEL_NORM:
self.conv_norm_out = PixelNorm()
self.conv_act = nn.SiLU()
# Calculate output convolution channels
conv_out_channels = out_channels
if latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels = config.out_channels
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2
elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
elif config.latent_log_var in {
LogVarianceType.UNIFORM,
LogVarianceType.CONSTANT,
}:
conv_out_channels += 1
self.conv_out = CausalConv3d(
@@ -341,7 +339,8 @@ class VideoEncoder(nn.Module):
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for down_block in self.down_blocks.values():
for i in range(len(self.down_blocks)):
down_block = self.down_blocks[i]
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else:
@@ -362,15 +361,99 @@ class VideoEncoder(nn.Module):
elif self.latent_log_var == LogVarianceType.CONSTANT:
sample = sample[:, :-1, ...]
approx_ln_0 = -30
sample = mx.concatenate([
sample,
mx.full_like(sample, approx_ln_0),
], axis=1)
sample = mx.concatenate(
[
sample,
mx.full_like(sample, approx_ln_0),
],
axis=1,
)
# Split into means and logvar, normalize means
means = sample[:, :self.latent_channels, ...]
means = sample[:, : self.latent_channels, ...]
return self.per_channel_statistics.normalize(means)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize VAE encoder weights from PyTorch format to MLX format."""
sanitized = {}
if "per_channel_statistics.mean" in weights:
return weights
for key, value in weights.items():
new_key = key
if "position_ids" in key:
continue
# Only process VAE encoder weights
if not key.startswith("vae."):
continue
# Handle per-channel statistics
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue
elif key.startswith("vae.encoder."):
new_key = key.replace("vae.encoder.", "")
else:
continue
# Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
"""Load a pretrained VideoEncoder from a directory with weights and config.
Args:
model_path: Path to directory containing safetensors weights and config.json
Returns:
Loaded VideoEncoder instance
"""
import json
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
# Load config
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
config = VideoEncoderModelConfig(**config_dict)
else:
config = VideoEncoderModelConfig()
# Load weights
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
if model_path.is_file():
weights = mx.load(str(model_path))
else:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
else:
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Create model, sanitize and load weights
model = cls(config)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=False)
return model
class VideoDecoder(nn.Module):
@@ -407,7 +490,7 @@ class VideoDecoder(nn.Module):
decoder_blocks = []
self.patch_size = patch_size
out_channels = out_channels * patch_size ** 2
out_channels = out_channels * patch_size**2
self.causal = causal
self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
@@ -440,9 +523,14 @@ class VideoDecoder(nn.Module):
)
# Build decoder blocks (reversed order)
self.up_blocks = []
for block_name, block_params in list(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_decoder_block(
block_name=block_name,
@@ -454,7 +542,7 @@ class VideoDecoder(nn.Module):
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=decoder_spatial_padding_mode,
)
self.up_blocks.append(block)
self.up_blocks[idx] = block
# Output normalization
if norm_layer == NormLayerType.GROUP_NORM:
@@ -509,7 +597,8 @@ class VideoDecoder(nn.Module):
sample = self.conv_in(sample, causal=self.causal)
# Process through decoder blocks
for up_block in self.up_blocks:
for i in range(len(self.up_blocks)):
up_block = self.up_blocks[i]
if isinstance(up_block, UNetMidBlock3D):
sample = up_block(sample, causal=self.causal)
elif isinstance(up_block, ResnetBlock3D):


@@ -1,2 +0,0 @@
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel


@@ -1,394 +0,0 @@
# Wan2.2 I2V-14B Diagnostic Report
This document records the systematic diagnostic methodology used to debug the Wan2.2 I2V-14B (Image-to-Video, 14 billion parameter) pipeline in mlx-video, along with every bug found, its root cause, and fix.
## Table of Contents
- [Overview](#overview)
- [Architecture Summary](#architecture-summary)
- [Diagnostic Methodology](#diagnostic-methodology)
- [Bug 1: Text Embedding Cross-Contamination](#bug-1-text-embedding-cross-contamination)
- [Bug 2: VAE Encoder Weights Excluded from Conversion](#bug-2-vae-encoder-weights-excluded-from-conversion)
- [Bug 3: RoPE Frequency Computation (original)](#bug-3-rope-frequency-computation-original)
- [Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)](#bug-6-rope-frequency-distribution-bug-3-fix-was-wrong)
- [Bug 4: VAE Encoder Temporal Downsample Order](#bug-4-vae-encoder-temporal-downsample-order)
- [Bug 5: Non-Chunked VAE Encoding](#bug-5-non-chunked-vae-encoding)
- [Verified Correct Components](#verified-correct-components)
- [Performance Optimizations](#performance-optimizations)
- [Resolved: CFG Effectiveness](#resolved-cfg-effectiveness-was-open-investigation)
- [Reference Implementation](#reference-implementation)
- [Useful Diagnostic Commands](#useful-diagnostic-commands)
---
## Overview
The I2V-14B pipeline takes an input image and generates a video using a dual-model diffusion transformer. The initial implementation produced severely broken output — first frame showed the image, subsequent frames degraded to noise, checkerboard artifacts, or flat grey.
Through a systematic component-by-component comparison against the reference PyTorch implementation, **five bugs** were found and fixed. The approach was to verify each component in isolation numerically, then narrow down failures to the subsystem level.
### Timeline of Symptoms
| Stage | Symptom | Root Cause |
|-------|---------|------------|
| Initial | Grey/blurry frames after frame 1 | Non-chunked VAE encoding (Bug 5) |
| After chunked encoding fix | First frame OK, rest degrades to noise | Text embedding cross-contamination (Bug 1) + RoPE frequencies (Bug 3) |
| After text + RoPE fix | Severe 8px checkerboard on frames 4+ | VAE encoder temporal downsample order (Bug 4) |
| After VAE fix | Image in frames 0-3, grey frames 4+ | CFG effectiveness issue (open investigation) |
---
## Architecture Summary
```
I2V-14B Pipeline:
Input Image → VAE Encoder → [16, T_lat, H_lat, W_lat]
Mask Construction → [4, T_lat, H_lat, W_lat]
y = concat(mask, encoded_video) → [20, T_lat, H_lat, W_lat]
Noise [16, T_lat, H_lat, W_lat] + y → [36, T_lat, H_lat, W_lat]
Dual DiT (40 layers, 5120 dim) × 40 denoising steps
Denoised Latent [16, T_lat, H_lat, W_lat]
VAE Decoder → Video [3, F, H, W]
```
**Key parameters:**
- `in_dim=36` (16 noise + 4 mask + 16 image latents), `out_dim=16`
- Dual model: HIGH noise (t ≥ 900) and LOW noise (t < 900)
- 40 steps, shift=5.0, guide_scale=(3.5, 3.5)
- Uses Wan2.1 VAE (z_dim=16, stride 4×8×8)
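The channel bookkeeping above can be checked with a small shape sketch. This is an illustration only: the helper names (`latent_shape`, `build_model_input`, `pick_expert`) are hypothetical, the latent-length formula `(F - 1) // 4 + 1` is an assumption about the causal VAE, and the 81-frame 480×832 example size is arbitrary.

```python
import numpy as np


def latent_shape(num_frames, height, width, stride=(4, 8, 8)):
    """Latent (T, H, W) for a causal VAE with stride 4x8x8 (assumed formula)."""
    st, sh, sw = stride
    return (num_frames - 1) // st + 1, height // sh, width // sw


def build_model_input(noise, mask, image_latents):
    """Channel concat: y = [mask(4), image(16)], then x = [noise(16), y(20)]."""
    y = np.concatenate([mask, image_latents], axis=0)
    return np.concatenate([noise, y], axis=0)


def pick_expert(timestep, boundary=900):
    """Dual-model routing: HIGH-noise expert for t >= boundary, LOW otherwise."""
    return "high" if timestep >= boundary else "low"


t_lat, h_lat, w_lat = latent_shape(num_frames=81, height=480, width=832)
noise = np.zeros((16, t_lat, h_lat, w_lat), dtype=np.float32)
mask = np.zeros((4, t_lat, h_lat, w_lat), dtype=np.float32)
image_latents = np.zeros((16, t_lat, h_lat, w_lat), dtype=np.float32)
x = build_model_input(noise, mask, image_latents)
print(x.shape)  # channel axis should equal in_dim = 36
```

A mismatch between the channel axis and `in_dim=36` at this point would indicate a mask or conditioning construction problem before the transformer is ever run.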
---
## Diagnostic Methodology
### 1. Component-Level Numerical Verification
Each component was tested in isolation against the reference PyTorch implementation:
1. **Load identical inputs** (same random seed, same image, same prompt)
2. **Run through reference** (on CPU where possible) and save intermediate tensors as `.npy`
3. **Run through MLX** with the same inputs
4. **Compare outputs** with `np.abs(ours - ref).max()` and relative difference metrics
Components tested this way:
- RoPE frequency parameters and rotation output
- Time embedding (sinusoidal → MLP → projection)
- Patchify (reshape+Linear vs Conv3d)
- Unpatchify (transpose-based vs einsum)
- Scheduler (UniPC) timesteps and step formulas
- VAE encoder output (frame-by-frame comparison)
- Text embeddings (per-model MLP output)
- Cross-attention K/V cache shapes
- Mask construction values
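The comparison in step 4 can be sketched as below. The helper name `compare_tensors` and the particular relative-difference definition (mean absolute difference divided by the mean magnitude of the reference) are assumptions for illustration; the report does not state which relative metric was used.

```python
import numpy as np


def compare_tensors(ours, ref, eps=1e-8):
    """Return max absolute difference and a mean relative difference."""
    ours = np.asarray(ours, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    abs_diff = np.abs(ours - ref)
    return {
        "max_abs": float(abs_diff.max()),
        "mean_rel": float(abs_diff.mean() / (np.abs(ref).mean() + eps)),
    }


# In practice both arrays would come from saved intermediates, for example
# np.load("ref_rope.npy") and np.load("mlx_rope.npy") (hypothetical file names).
report = compare_tensors([1.0, 2.0], [1.0, 2.5])
print(report)
```

Casting both sides to float64 before subtracting keeps the metric itself from adding rounding noise when the tensors were produced in bfloat16 or float16.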
### 2. Artifact Analysis
When visual artifacts appeared, quantitative metrics were used to characterize them:
- **Checkerboard metric**: Difference between even-indexed and odd-indexed pixels at patch boundaries. Values > 20 indicate visible checkerboard.
- **FFT frequency analysis**: Power at the 8px spatial frequency (matches VAE stride). 3× normal power confirmed VAE-stride-aligned artifacts.
- **Per-frame statistics**: Mean, std, min, max for each decoded video frame to track temporal degradation.
- **Frame difference**: `mean(|frame[i] - frame[i-1]|)` to measure motion vs static content.
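Three of the metrics above can be sketched in NumPy as follows. This is a minimal illustration, not the scripts used during the investigation: the function names are hypothetical, the video layout `(F, H, W, C)` is assumed, and the spectral measure only reads the exact FFT bin for the given period, so it assumes the frame size is divisible by that period. The checkerboard metric is omitted because its exact definition is not given here.

```python
import numpy as np


def frame_stats(video):
    """Per-frame [mean, std, min, max] for a video shaped (F, H, W, C)."""
    video = np.asarray(video, dtype=np.float64)
    flat = video.reshape(video.shape[0], -1)
    return np.stack(
        [flat.mean(axis=1), flat.std(axis=1), flat.min(axis=1), flat.max(axis=1)],
        axis=1,
    )


def frame_difference(video):
    """mean(|frame[i] - frame[i-1]|) for each consecutive pair of frames."""
    video = np.asarray(video, dtype=np.float64)
    return np.abs(video[1:] - video[:-1]).mean(axis=(1, 2, 3))


def power_at_period(gray, period=8):
    """Average spectral power at the vertical and horizontal bins for `period` px."""
    gray = np.asarray(gray, dtype=np.float64)
    h, w = gray.shape
    spectrum = np.abs(np.fft.fft2(gray - gray.mean())) ** 2
    return 0.5 * (spectrum[h // period, 0] + spectrum[0, w // period])


# Synthetic check: stripes with an 8-pixel period versus plain noise.
cols = np.arange(64)
stripes = np.tile(np.sin(2 * np.pi * cols / 8), (64, 1))
noise = np.random.default_rng(0).normal(size=(64, 64))
print(power_at_period(stripes), power_at_period(noise))
```

Comparing `power_at_period` for a suspect frame against a known-good frame of the same size gives a ratio comparable to the "3× normal power" figure quoted above.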
### 3. Isolation Testing
- **VAE round-trip test**: Encode image+zeros → decode. If clean, VAE decoder is not the source.
- **Single-step model output**: Run one diffusion step and compare cond vs uncond predictions to check CFG effectiveness.
- **Patchify/unpatchify synthetic test**: Pass structured gradient through unpatchify to verify spatial ordering.
- **Resolution sweeps**: Test at 480×272, 640×384, 1280×720 to check resolution dependence.
- **Step count sweeps**: Test at 5, 20, 40 steps to distinguish convergence issues from model bugs.
### 4. Weight Comparison
Direct comparison of converted MLX weights against original PyTorch weights:
```python
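import numpy as np
import mlx.core as mx
from safetensors.torch import load_file

# Load both weight sets (torch.load cannot read .safetensors files)
pt_weights = load_file("model.safetensors")
mlx_weights = mx.load("model.safetensors")
# Compare each key in float32 (NumPy has no bfloat16 dtype)
for key in pt_weights:
    ref = pt_weights[key].float().numpy()
    ours = np.array(mlx_weights[key].astype(mx.float32))
    diff = np.abs(ref - ours).max()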
```
Expected: max diff ≈ 0.001 (bfloat16 rounding). Actual: confirmed for all keys.
---
## Bug 1: Text Embedding Cross-Contamination
**Symptom:** Model ignores text prompt, generated frames lack semantic content.
**Root Cause:** For the dual-model architecture (high-noise and low-noise experts), text embeddings were computed using only `low_noise_model.embed_text()` and reused for both models' cross-attention K/V caches. The two models have **different** text embedding MLP weights — 42% relative mean difference in output.
**How Found:** Compared `text_embedding_0.weight` and `text_embedding_1.weight` between `high_noise_model.safetensors` and `low_noise_model.safetensors`. Found 17.9% and 26.3% relative differences in the weight matrices.
**Fix:** Compute separate text embeddings per model:
```python
# Before (broken):
context_emb = low_noise_model.embed_text([context, context_null])
cross_kv = low_noise_model.prepare_cross_kv(context_emb) # used for BOTH models
# After (correct):
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
cross_kv_low = low_noise_model.prepare_cross_kv(context_emb_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_emb_high)
```
**File:** `mlx_video/generate_wan.py` (lines 333-349)
**Commit:** `a85b1c21`
---
## Bug 2: VAE Encoder Weights Excluded from Conversion
**Symptom:** VAE encoder produces constant output regardless of input image (all-zero weights after conversion).
**Root Cause:** The conversion script only included encoder weights for `model_type == "ti2v"` (TI2V-5B), not for `"i2v"` (I2V-14B). Since `load_vae_encoder()` uses `strict=False`, missing encoder weights were silently ignored, resulting in random initialization.
**How Found:** Traced through `convert_wan.py` and found `include_encoder = config.model_type == "ti2v"`. Cross-referenced with the fact that I2V-14B also requires a VAE encoder (for image conditioning).
**Fix:**
```python
# Before:
include_encoder = config.model_type == "ti2v"
# After:
include_encoder = config.model_type in ("ti2v", "i2v")
```
**Note:** The user's specific model happened to be manually converted with encoder weights already present, so this fix was preventive for future conversions.
**File:** `mlx_video/convert_wan.py` (line 424)
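Because `strict=False` hides missing keys, a check along the following lines before loading would have surfaced the problem. This is a hedged sketch: `find_missing_keys` is a hypothetical helper, and the key names in the example are made up for illustration rather than taken from the real checkpoint.

```python
def find_missing_keys(expected_keys, loaded_keys, prefix="encoder."):
    """Return expected parameter names under `prefix` that the checkpoint lacks."""
    loaded = set(loaded_keys)
    return sorted(
        key for key in expected_keys if key.startswith(prefix) and key not in loaded
    )


# Hypothetical key names. In MLX the expected names could be gathered with
# mlx.utils.tree_flatten(model.parameters()) and the loaded names from mx.load(...).
expected = ["encoder.conv1.weight", "encoder.conv1.bias", "decoder.conv1.weight"]
loaded = ["decoder.conv1.weight"]

missing = find_missing_keys(expected, loaded)
if missing:
    print(f"{len(missing)} encoder weights missing from checkpoint: {missing}")
```

Raising an error instead of printing would turn a silent random initialization into an immediate failure at load time.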
---
## Bug 3: RoPE Frequency Computation (original)
**Symptom:** Progressive 2px checkerboard artifacts on generated frames, increasing with temporal distance from the conditioned frame.
**Root Cause (original):** Our original code called `rope_params` three times but applied them incorrectly (per-axis in the model init, then rope_apply did NOT split). This was initially "fixed" by switching to a single `rope_params(1024, head_dim=128)` call, which reduced checkerboard but introduced Bug 6 (see below).
**File:** `mlx_video/models/wan/model.py`
**Commit:** `3da4a637`
---
## Bug 6: RoPE Frequency Distribution (Bug 3 Fix Was Wrong)
**Symptom:** I2V generates the input image in frames 0-3, a colorful checkerboard on frame 4, then grey frames. CFG cond/uncond predictions are nearly identical. The model cannot produce coherent motion.
**Root Cause:** The Bug 3 "fix" replaced three separate `rope_params` calls with a single `rope_params(1024, 128)`. But the reference (`wan/modules/model.py` lines 400-405) actually uses **three separate calls with different dimension normalizations**, concatenated:
```python
# Reference (CORRECT):
d = dim // num_heads  # 128
self.freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),  # rope_params(1024, 44)
    rope_params(1024, 2 * (d // 6)),      # rope_params(1024, 42)
    rope_params(1024, 2 * (d // 6))       # rope_params(1024, 42)
], dim=1)
```
Each axis gets its own full frequency range [θ^0, θ^(-~0.95)]. The single-call approach gave:
- Temporal: low frequencies only [1.0 → 0.049]
- Height: medium frequencies only [0.042 → 0.002] (should start at 1.0!)
- Width: high frequencies only [0.002 → 0.0001] (should start at 1.0!)
The height/width position encoding was essentially destroyed — nearby spatial positions were indistinguishable (max diff 0.958 for height, 0.998 for width vs reference).
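The frequency ranges quoted above can be reproduced with a small standalone sketch. It assumes the usual RoPE inverse-frequency formula `theta^(-i/dim)` for even `i` with `theta = 10000`; `inv_freqs` is a hypothetical helper written for this illustration, not the project's `rope_params`.

```python
# Sketch only: compares per-axis inverse frequencies for the two approaches.
# Assumes theta = 10000 and the standard theta^(-i/dim) schedule for even i.
def inv_freqs(dim, theta=10000.0):
    return [theta ** (-(i / dim)) for i in range(0, dim, 2)]

d = 128
t_dim, h_dim, w_dim = d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)  # 44, 42, 42

# Three separate calls: every axis restarts at frequency 1.0.
three_call = [inv_freqs(t_dim), inv_freqs(h_dim), inv_freqs(w_dim)]

# Single call split across axes: later axes only see the tail of the spectrum.
single = inv_freqs(d)
n_t, n_h = t_dim // 2, h_dim // 2
single_call = [single[:n_t], single[n_t:n_t + n_h], single[n_t + n_h:]]

for name, a, b in zip(("temporal", "height", "width"), three_call, single_call):
    print(f"{name:8s} three-call first={a[0]:.4f}  single-call first={b[0]:.4f}")
```

With the single call, the height and width slices begin far below 1.0, which is the failure mode described in this section.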
**How Found:** Direct line-by-line comparison of `WanModel.__init__` freq construction between reference `wan/modules/model.py` and our `models/wan/model.py`. Numerical verification confirmed the three-call approach gives each axis a full [0, ~1) exponent range, while the single-call monotonically assigns low→high across axes.
**Fix:**
```python
d = dim // config.num_heads
self.freqs = mx.concatenate([
    rope_params(1024, d - 4 * (d // 6)),
    rope_params(1024, 2 * (d // 6)),
    rope_params(1024, 2 * (d // 6)),
], axis=1)
```
**Verification:** Max diff vs reference cos/sin: 0.00000000 (exact float32 match).
**Impact:** Affects ALL Wan models (T2V, I2V, TI2V). Resolves the "Open Investigation: CFG Effectiveness" issue — the model could not produce meaningful cond/uncond differences because it couldn't encode spatial positions.
**File:** `mlx_video/models/wan/model.py` (line 155)
---
## Bug 4: VAE Encoder Temporal Downsample Order
**Symptom:** Massive checkerboard artifacts aligned to VAE spatial stride (8px period). VAE encoder output for frames 1-4 showed decreasing std (0.37→1.19) while reference showed stable std (0.95→1.34).
**Root Cause:** The VAE encoder has 3 downsampling stages. Two perform spatial+temporal downsampling (`downsample3d`) and one performs spatial-only (`downsample2d`). The order matters:
```
Reference: [False, True, True] → stage 0: 2d, stage 1: 3d, stage 2: 3d
Ours: [True, True, False] → stage 0: 3d, stage 1: 3d, stage 2: 2d ← WRONG
```
This caused temporal downsampling to happen at the wrong resolution stages (96-dim instead of 384-dim), corrupting temporal feature propagation.
**How Found:** Installed `einops` in the reference environment and ran the reference PyTorch VAE encoder on CPU. Compared frame-by-frame latent output:
- Frame 0 matched exactly (diff=0.0000) — spatial-only processing was correct
- Frames 1-4 had massive differences, which proved temporal processing was broken
Then traced through the reference `_video_vae()` function and found it sets `temperal_downsample=[False, True, True]`, while our `Encoder3d` class used the wrong default `[True, True, False]`.
**Fix:**
```python
# In Encoder3d.__init__, change default:
temporal_downsample = [False, True, True] # was [True, True, False]
```
**Impact:** Encoder output now matches reference within float32 precision (max_diff=2.2e-5). Checkerboard metric dropped from 60-80 to 0.1-7.7.
**File:** `mlx_video/models/wan/vae.py` (line 370)
**Commit:** `3da4a637`
---
## Bug 5: Non-Chunked VAE Encoding
**Symptom:** First 4-5 frames grey, then a blurred version of the image appears.
**Root Cause:** The reference VAE encoder uses **chunked encoding** with temporal caching (`feat_cache`):
1. Encode first frame alone (1 frame)
2. Encode remaining frames in chunks of 4, with cached temporal features propagating across chunks
3. Each `CausalConv3d` caches last 2 temporal frames from its output, prepending them to the next chunk's input
Our original implementation encoded all frames at once with zero-padded causal convolutions. The temporal feature propagation is fundamentally different because:
- Chunked: real features from previous chunks serve as causal context
- Non-chunked: zeros serve as causal context for the start
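The chunk schedule described above (one frame first, then groups of four) can be sketched as follows. `chunk_ranges` is a hypothetical helper for illustration; it only computes frame ranges and does not model the `feat_cache` propagation itself.

```python
def chunk_ranges(num_frames, chunk=4):
    """Return (start, end) frame ranges: first frame alone, then groups of `chunk`."""
    ranges = [(0, 1)]
    start = 1
    while start < num_frames:
        end = min(start + chunk, num_frames)
        ranges.append((start, end))
        start = end
    return ranges

# 17 input frames -> 1 + 4 chunks, matching the 4n + 1 frame-count pattern.
print(chunk_ranges(17))
```

Each range after the first would be encoded with the cached temporal features of the previous chunk as causal context, rather than zeros.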
**How Found:** Studied the reference `CausalConv3d` caching mechanism (`feat_cache`, `feat_idx`) and traced the temporal dimension through all encoding stages. Confirmed that non-chunked encoding produces different output by comparing tensor shapes and values.
**Fix:** Implemented full chunked encoding with temporal caching:
- Added `cache_x` parameter to `CausalConv3d.__call__`
- Added `feat_cache`/`feat_idx` propagation to `ResidualBlock`, `Resample`, `Encoder3d`
- Rewrote `WanVAE.encode()` with chunked loop (1-frame first chunk, then 4-frame chunks)
- 24 cache slots across the encoder (1 conv1 + 18 downsamples + 4 middle + 1 head)
**File:** `mlx_video/models/wan/vae.py` (multiple methods)
**Commit:** `b6a94c4c`
---
## Verified Correct Components
These components were numerically verified against the reference and are **not** sources of bugs:
| Component | Method | Max Diff | Notes |
|-----------|--------|----------|-------|
| Weight conversion | Direct tensor comparison | ~0.001 | bfloat16 rounding only |
| RoPE rotation | Standalone comparison (float32 vs float64) | 1.3e-5 | Complex vs real multiplication equivalent |
| Time embedding | Full MLP comparison (sinusoidal→embed→project) | 7e-4 | 0.03% relative |
| Patchify | Conv3d vs reshape+Linear | 3.5e-3 | 0.16% relative |
| Unpatchify | einsum vs transpose(6,0,3,1,4,2,5) | exact | Identical operation |
| Scheduler (UniPC) | Formula-level audit + timestep comparison | exact | Predictor, corrector, lambda, rhos all match |
| Mask construction | Value comparison | exact | [4, T_lat, H_lat, W_lat], first temporal=1 |
| CFG formula | Code audit | — | `uncond + gs * (cond - uncond)` correct order |
| VAE decoder | Round-trip test (encode→decode) | clean | No checkerboard in round-trip output |
| Cross-attention K/V | Shape and value audit | — | Batch dimension preserved correctly |
---
## Performance Optimizations
Applied alongside bug fixes to improve inference speed:
### Pre-Computation (Before Diffusion Loop)
- **Cross-attention K/V caching**: Precompute K/V projections for all 40 blocks once
- **RoPE cos/sin precomputation**: Build frequency tensors once instead of per-step broadcast/concat
- **Attention mask precomputation**: Build padding mask once, pass via kwargs
- **Inverse frequency caching**: Store sinusoidal `inv_freq` in `__init__` instead of recomputing
- **Timestep list conversion**: `sched.timesteps.tolist()` before loop to avoid `.item()` sync
### Per-Step Optimizations
- **Single patchify + broadcast for CFG B=2**: Detect identical batch inputs, patchify once and broadcast instead of duplicating the Linear projection
- **Vectorized RoPE**: When all batch elements share the same grid size, apply rotation to the full batch tensor instead of looping per element
- **Redundant type cast removal**: MLX type promotion handles `bfloat16 * float32 → float32` automatically — removed 240 unnecessary graph nodes per step (6 casts × 40 blocks)
- **Euler scheduler sync fix**: Pre-store sigmas as Python floats to avoid `.item()` evaluation sync
---
## Resolved: CFG Effectiveness (was Open Investigation)
**Symptom:** Generated video shows the input image in frames 0-3 (latent frame 0), then grey/flat frames for the rest. Cond and uncond predictions were nearly identical.
**Resolution:** This was caused by Bug 6 (incorrect RoPE frequency distribution). The single `rope_params(1024, 128)` call gave height frequencies starting at 0.042 and width at 0.002 (instead of 1.0 for both), making the model unable to encode spatial positions. This caused the transformer to produce nearly identical outputs regardless of text conditioning, explaining the tiny cond/uncond differences.
---
## Reference Implementation
The reference PyTorch implementation is at `/Users/daniel/Projects/Wan2.2/`:
| File | Contents |
|------|----------|
| `wan/image2video.py` | I2V pipeline (y construction, mask, diffusion loop) |
| `wan/modules/model.py` | DiT model (forward pass, RoPE, patchify) |
| `wan/modules/vae2_1.py` | VAE encoder/decoder with chunked encoding |
| `wan/utils/fm_solvers_unipc.py` | UniPC scheduler |
| `wan/configs/wan_i2v_A14B.py` | Model configuration |
Key structural differences between reference and our implementation:
- Reference runs **separate B=1 forward passes** for cond/uncond; we batch as B=2
- Reference uses `torch.amp.autocast('cuda', dtype=bfloat16)` with explicit float32 blocks; we cast via weight dtype
- Reference uses `Conv3d` for patchify; we use equivalent `reshape + Linear`
- Reference casts timesteps to `int64`; we keep as float (diff < 1.0)
---
## Useful Diagnostic Commands
### Run I2V-14B generation
```bash
python -m mlx_video.generate_wan \
  --prompt "A woman smiles at camera" \
  --image start.png \
  --model-dir /Volumes/SSD/Wan-AI/Wan2.2-I2V-A14B-MLX \
  --num-frames 17 --steps 40 \
  --height 384 --width 640 \
  --output output_i2v.mp4
```
### Check VAE encoder output
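```python
import mlx.core as mx, numpy as np
from mlx_video.models.wan.vae import WanVAE
# Assumes `vae` is an already-loaded WanVAE and `video_tensor` is the
# preprocessed input video; loading and preprocessing are not shown here.
latents = vae.encode(video_tensor)  # [1, 16, T_lat, H_lat, W_lat]
for t in range(latents.shape[2]):
    # Cast to float32 first so the NumPy conversion also works for bfloat16 latents
    frame = np.array(latents[0, :, t].astype(mx.float32))
    print(f"Frame {t}: mean={frame.mean():.4f} std={frame.std():.4f}")
```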
### Analyze video frame quality
```python
import cv2, numpy as np
cap = cv2.VideoCapture("output.mp4")
while True:
    ret, frame = cap.read()
    if not ret:
        break
    # Checkerboard metric: high values indicate patch-boundary artifacts
    checker = np.abs(frame[::2, ::2].astype(float) - frame[1::2, 1::2].astype(float)).mean()
    print(f"std={frame.std():.1f} checker={checker:.1f}")
cap.release()  # free the video handle once all frames are read
```
### Compare weights between PyTorch and MLX
```python
import torch, mlx.core as mx, numpy as np
pt = torch.load("model.pt", map_location="cpu")
mlx_w = mx.load("model.safetensors")
for key in sorted(pt.keys()):
    if key in mlx_w:
        # Cast to float32 on both sides; NumPy cannot convert bfloat16 arrays directly
        mlx_arr = np.array(mlx_w[key].astype(mx.float32))
        diff = np.abs(pt[key].float().numpy() - mlx_arr).max()
        if diff > 0.01:
            print(f"LARGE DIFF {key}: {diff:.6f}")
```


@@ -1,285 +0,0 @@
# Wan2.2 MLX Implementation Notes
> Learnings and key decisions from porting Wan2.2 (TI2V-5B / T2V-14B / I2V-14B / T2V-1.3B) to Apple MLX.
## Architecture Overview
Wan2.2 is a Diffusion Transformer (DiT) for video generation. Despite early reports, the T2V/TI2V models do **not** use Mixture-of-Experts — they are dense DiT models with a dual-model architecture for the 14B variant (separate high-noise and low-noise denoisers with a boundary timestep).
### Key Parameters
| Model | dim | heads | layers | FFN mult | VAE z_dim | VAE stride | in_dim |
|-------|-----|-------|--------|----------|-----------|------------|--------|
| T2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 16 |
| I2V-14B | 5120 | 40 | 40 | 4×(5120×4/3) | 16 | (4, 8, 8) | 36 |
| TI2V-5B | 3072 | 24 | 32 | 4×(3072×4/3) | 48 | (4, 16, 16) | 48 |
| T2V-1.3B | 1536 | 12 | 30 | 4×(1536×4/3) | 16 | (4, 8, 8) | 16 |
### Codebase Structure (~3900 lines of Wan2.2 code)
```
mlx_video/
├── generate_wan.py          # 483L - Generation pipeline (T2V + I2V)
├── convert_wan.py           # 564L - Weight conversion from HuggingFace
└── models/wan/
    ├── config.py            # 113L - Model configs (dataclass presets)
    ├── model.py             # 320L - DiT model (time embed, patchify, unpatchify)
    ├── transformer.py       # 91L  - Attention block + FFN
    ├── attention.py         # 211L - Self-attention + cross-attention
    ├── rope.py              # 100L - 3D Rotary Position Embeddings
    ├── text_encoder.py      # 240L - T5 encoder (UMT5-XXL)
    ├── scheduler.py         # 428L - Euler, DPM++ 2M, UniPC schedulers
    ├── vae.py               # 315L - Wan2.1 VAE decoder (4×8×8)
    ├── vae22.py             # 836L - Wan2.2 VAE encoder + decoder (4×16×16)
    ├── loading.py           # 154L - Model loading utilities
    └── i2v_utils.py         # 58L  - I2V mask/preprocessing
```
---
## Critical Bugs & Fixes
### 1. MLX Underscore Attribute Gotcha
**Problem**: MLX's `nn.Module` silently ignores underscore-prefixed attributes (`_layer_0`, `_layer_1`, etc.) in `parameters()` and `load_weights()`. The Wan2.2 VAE had layers named `_layer_N`, causing **87 out of 110 weights to be silently dropped** during loading.
**Fix**: Rename all `_layer_N` attributes to `layer_N`. MLX treats underscore-prefixed attributes as "private" and excludes them from the parameter tree.
**Lesson**: Never use underscore-prefixed names for `nn.Module` sub-modules in MLX.
### 2. Patchify Channel Ordering
**Problem**: The patchify/unpatchify operations transposed channels incorrectly — producing `[C fastest]` layout instead of `[C slowest]`, causing completely garbled video output.
**Fix**: Changed reshape to produce correct `[B, T', H', W', pt*ph*pw*C]` ordering matching PyTorch's contiguous memory layout.
**Lesson**: When porting PyTorch reshape/view operations to MLX, pay close attention to memory layout — PyTorch is row-major by default, and reshape semantics differ when dimensions are reordered.
### 3. VAE AttentionBlock Reshape
**Problem**: Attention block merged batch (B) with channels (C) instead of batch with temporal (T), producing a green checker pattern in output.
**Fix**: Correct reshape from `[B*C, T, H, W]` to `[B*T, C, H, W]` for spatial attention.
### 4. RMS Norm vs L2 Norm
**Problem**: The Wan2.2 VAE uses a class named `RMS_norm` in PyTorch, but it actually computes **L2 normalization** (divide by L2 norm), not RMS normalization (divide by RMS). Using actual RMS norm caused exponential value explosion.
**Fix**: Implement as `x / ||x||₂` instead of `x / sqrt(mean(x²))`.
**Lesson**: Don't trust class names in reference code — read the actual computation.
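To make the distinction concrete, here is a minimal pure-Python sketch of the two normalizations named above. It illustrates the formulas only and is not the VAE code; both helper names are hypothetical.

```python
import math

def l2_normalize(x):
    # x / ||x||_2
    norm = math.sqrt(sum(v * v for v in x))
    return [v / norm for v in x]

def rms_normalize(x):
    # x / sqrt(mean(x^2))
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms for v in x]

x = [3.0, 4.0, 0.0, 0.0]
print(l2_normalize(x))
print(rms_normalize(x))
```

For a vector of length `n` the two results differ by a constant factor of `sqrt(n)`, so substituting one for the other rescales every activation unless that factor is accounted for elsewhere.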
### 5. Video Codec Green Output
**Problem**: OpenCV's `mp4v` codec on macOS produces green-tinted video.
**Fix**: Switch to `imageio` with `libx264` codec. Fallback chain: imageio → cv2 (avc1) → PNG frames.
---
## Precision & Dtype Flow
### The bfloat16 Autocast Pattern
The official PyTorch implementation uses `torch.autocast("cuda", dtype=torch.bfloat16)` which automatically casts matmul inputs. In MLX, we replicate this manually:
| Operation | Official (PyTorch) | MLX Implementation |
|---|---|---|
| Modulation/gates | float32 (explicit `autocast(enabled=False)`) | `x.astype(mx.float32)` before modulation |
| QKV projections | bfloat16 (outer autocast) | Cast input to `self.q.weight.dtype` |
| RoPE computation | float64 → float32 | float32 (MLX lacks float64 on GPU) |
| Q/K after RoPE | bfloat16 (`q.to(v.dtype)`) | Cast back to weight dtype after RoPE |
| FFN matmuls | bfloat16 (outer autocast) | Cast input to `self.fc1.weight.dtype` |
| Residual stream | float32 | float32 (no cast) |
**Result**: ~16% speedup (47s vs 56s for 20 steps at 480p) with no quality regression.
**Key insight**: Modulation parameters (scale, shift, gate) must stay in float32: they are small values (~0.01-0.1) that lose significant precision in bfloat16. The official code explicitly disables autocast for these computations.
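The precision loss can be estimated without any ML framework. The sketch below emulates bfloat16 rounding by keeping the top 16 bits of a float32 with round-to-nearest-even; `to_bfloat16` is a hypothetical helper, and the emulation is an assumption about how the rounding behaves for ordinary finite values, not a statement about MLX internals.

```python
import struct

def to_bfloat16(x):
    """Round a Python float to the nearest bfloat16 value (finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 bits that get discarded.
    rounding = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for v in (0.0123, 0.05, 0.1):
    r = to_bfloat16(v)
    print(f"{v} -> {r!r}  relative error {abs(r - v) / v:.2e}")
```

bfloat16 keeps 8 significant bits, so the relative rounding error can approach 2^-8 (about 0.4%), which is a meaningful perturbation for scale, shift, and gate values.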
### T5 Encoder Precision
The T5 text encoder must run in float32. Bfloat16 weights cause the attention softmax to produce degenerate distributions, which corrupts text conditioning and manifests as blurry patches in generated video. Since T5 only runs once per generation, the performance cost is negligible.
### VAE Decoder Precision
VAE weights must be float32. Bfloat16 VAE decode introduces visible quality loss in the decoded video frames.
---
## Scheduler Implementation Details
### Three Schedulers: Euler, DPM++ 2M, UniPC
All operate in the flow-matching formulation where `sigma` represents the noise level (1.0 = pure noise, 0.0 = clean).
**Euler**: Simple first-order ODE solver. Most stable, recommended for debugging.
**DPM++ 2M**: Second-order multistep solver. Uses previous step's model output for higher-order correction. Requires special handling at boundaries (return `±inf` from `_lambda()` when sigma is 0 or 1).
**UniPC** (default, matches official): Second-order predictor-corrector. The "C" (corrector) part is critical — it refines each step using the already-computed model output at **zero additional model evaluation cost**.
### UniPC Corrector: Must Be Enabled
**Discovery**: Our implementation had `use_corrector=False` by default, but the official Wan2.2 code **always** enables it (there's no flag — the corrector runs whenever `step_index > 0`).
**Impact**: Without the corrector, UniPC degrades to a simple predictor, losing its second-order accuracy advantage.
### UniPC Corrector Coefficients
The corrector coefficients (`rhos_c`) must be computed by solving a linear system, not hardcoded. For order ≥ 2, hardcoding `rhos_c[-1] = 0.5` introduces ~6-13% error in the correction term across 47+ steps. The fix uses `np.linalg.solve()` to compute exact coefficients.
### Sigma Schedule
```python
# Flow-matching sigma schedule with shift
sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
```
Default shifts: T2V-14B uses 5.0, TI2V-5B uses 3.0, T2V-1.3B uses 3.0.
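As a worked example of the schedule above, the following dependency-free sketch reproduces the same two formulas with plain lists. `shifted_sigmas` is a hypothetical name, and the `num_steps=4` values are chosen only to keep the arithmetic easy to check by hand.

```python
def shifted_sigmas(num_steps, shift):
    # Evenly spaced from 1.0 down to 1/num_steps (same endpoints as np.linspace above).
    stop = 1.0 / num_steps
    if num_steps == 1:
        base = [1.0]
    else:
        step = (stop - 1.0) / (num_steps - 1)
        base = [1.0 + i * step for i in range(num_steps)]
    # Apply the flow-matching shift.
    return [shift * s / (1 + (shift - 1) * s) for s in base]

print(shifted_sigmas(4, 1.0))  # shift 1.0 leaves the linear schedule unchanged
print(shifted_sigmas(4, 5.0))  # larger shift keeps sigma high for more of the steps
```

With `shift=5.0` the four sigmas are 1.0, 0.9375, 0.8333 (rounded), and 0.625, so more of the step budget is spent at high noise levels than with the unshifted 1.0, 0.75, 0.5, 0.25.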
---
## Image-to-Video (I2V) Pipelines
Wan2.2 supports two distinct I2V approaches:
### TI2V-5B: Per-Token Timestep Masking
I2V conditions on a reference first frame by giving first-frame latent patches a timestep of 0 (clean) while other patches get the current diffusion timestep:
```python
# mask_tokens: [1, L] — 0 for first-frame patches, 1 for rest
t_tokens = mask_tokens * current_timestep # first-frame → t=0
```
The model receives 2D timestep input `[B, L]` instead of scalar, enabling per-token noise levels.
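A minimal sketch of how such a per-token timestep vector could be built is shown here. It assumes tokens are ordered frame-major (all patches of latent frame 0 first) and a patch size of (1, 2, 2); `token_timesteps` is a hypothetical helper, not the pipeline's function.

```python
def token_timesteps(t_lat, h_lat, w_lat, current_timestep, patch=(1, 2, 2)):
    """Per-token timesteps: 0 for first-frame patches, current_timestep elsewhere."""
    tokens_per_frame = (h_lat // patch[1]) * (w_lat // patch[2])
    num_frames = t_lat // patch[0]
    # mask_tokens: 0 for the first latent frame, 1 for all later frames
    mask_tokens = [0.0] * tokens_per_frame + [1.0] * (tokens_per_frame * (num_frames - 1))
    return [m * current_timestep for m in mask_tokens]

t_tokens = token_timesteps(t_lat=3, h_lat=4, w_lat=4, current_timestep=900.0)
print(len(t_tokens), t_tokens[:4], t_tokens[4:8])
```

In this toy case there are 4 tokens per latent frame, so the first 4 entries are 0 (clean) and the remaining 8 carry the current timestep.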
#### Mask Re-application
After each scheduler step, the first-frame latent is re-injected to prevent drift:
```python
latents = (1.0 - mask) * z_img + mask * latents
```
#### VAE Encoder Temporal Downsample Order
The Wan2.2 VAE encoder has `temporal_downsample = (False, True, True)`:
- Stage 0: Spatial-only downsampling
- Stages 1-2: Spatial + temporal downsampling
This was incorrectly set to `(True, True, False)` initially, causing wrong spatial processing paths.
### I2V-14B: Channel Concatenation
The I2V-14B model uses a fundamentally different approach — channel concatenation via a `y` tensor:
1. **Encode image**: Resize to target (H, W), create video tensor with image as first frame + zeros → VAE encode through Wan2.1 encoder → `[16, T_lat, H_lat, W_lat]`
2. **Build mask**: Binary mask with 1 for first frame, 0 for rest → rearranged to `[4, T_lat, H_lat, W_lat]`
3. **Construct y**: `y = concat([mask_4ch, encoded_16ch])` → `[20, T_lat, H_lat, W_lat]`
4. **Channel concat in model**: Before patchify, `x = concat([noise_16ch, y_20ch])` → 36 channels matching `in_dim=36`
Key differences from TI2V-5B:
- Uses **Wan2.1 VAE** (z_dim=16, stride 4,8,8), not Wan2.2 VAE
- Requires the **VAE encoder** (for encoding the reference image)
- Uses **scalar timesteps** (same as T2V) — no per-token masking
- **Dual model** pipeline with boundary=0.900
- Both conditional and unconditional predictions receive the same `y` tensor
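The channel bookkeeping in the steps above can be checked with a shape-only sketch. It assumes `T_lat = (num_frames - 1) // 4 + 1` (consistent with the `4n + 1` frame rule) and spatial stride 8; `i2v_14b_shapes` is a hypothetical helper that computes shapes and touches no tensors.

```python
def i2v_14b_shapes(num_frames, height, width, z_dim=16, stride=(4, 8, 8)):
    """Shape bookkeeping for the I2V-14B conditioning tensors (no real data)."""
    t_lat = (num_frames - 1) // stride[0] + 1
    h_lat, w_lat = height // stride[1], width // stride[2]
    mask = (4, t_lat, h_lat, w_lat)            # binary first-frame mask, 4 channels
    encoded = (z_dim, t_lat, h_lat, w_lat)     # VAE-encoded image + zero frames
    y = (mask[0] + encoded[0], t_lat, h_lat, w_lat)
    model_input = (z_dim + y[0], t_lat, h_lat, w_lat)  # noise channels + y channels
    return {"mask": mask, "encoded": encoded, "y": y, "model_input": model_input}

for name, shape in i2v_14b_shapes(17, 384, 640).items():
    print(f"{name:12s} {shape}")
```

For 17 frames at 384×640 this gives a 20-channel `y` and a 36-channel model input, matching `in_dim=36`.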
---
## Dimension Constraints
### Patchify Alignment
Video dimensions must be divisible by `patch_size × vae_stride`:
- **TI2V-5B**: patch=(1,2,2), stride=(4,16,16) → alignment = **32** pixels
- **T2V-14B**: patch=(1,2,2), stride=(4,8,8) → alignment = **16** pixels
Example: 720p (1280×720) → 720 % 32 ≠ 0, auto-aligns to **704**.
### Frame Count
Frames must satisfy `num_frames = 4n + 1` (e.g., 5, 9, 13, ..., 81) due to temporal VAE stride of 4.
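Both constraints are easy to check before launching a long generation. The helpers below are hypothetical names for illustration; rounding down matches the 720 → 704 example above, though the pipeline's own alignment logic may differ in detail.

```python
def spatial_alignment(patch, vae_stride):
    # Pixels per patch along height: patch size times VAE spatial stride.
    return patch[1] * vae_stride[1]

def align_down(value, alignment):
    return (value // alignment) * alignment

def is_valid_frame_count(num_frames):
    # Temporal VAE stride of 4 requires num_frames = 4n + 1.
    return num_frames >= 1 and (num_frames - 1) % 4 == 0

ti2v = spatial_alignment((1, 2, 2), (4, 16, 16))   # 32
t2v = spatial_alignment((1, 2, 2), (4, 8, 8))      # 16
print(ti2v, t2v, align_down(720, ti2v), align_down(1280, ti2v))
print([n for n in range(1, 20) if is_valid_frame_count(n)])
```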
---
## Performance Optimizations
### Batched CFG
Instead of two separate forward passes for conditional and unconditional predictions, batch them into a single B=2 forward pass:
```python
preds = model([latents, latents], t=t_batch, context=context_cfg, ...)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
```
**Result**: ~40% speedup by amortizing attention overhead.
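After the batched pass, the two predictions are combined with the classifier-free guidance formula `uncond + gs * (cond - uncond)`. A minimal sketch on plain lists, with `cfg_combine` as a hypothetical helper name:

```python
def cfg_combine(cond, uncond, guidance_scale):
    """Classifier-free guidance: push the prediction away from the unconditional one."""
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]

cond, uncond = [1.0, 2.0], [0.0, 1.0]
print(cfg_combine(cond, uncond, 1.0))  # scale 1.0 returns the conditional prediction
print(cfg_combine(cond, uncond, 3.0))
```

When the conditional and unconditional predictions are nearly identical, the guidance term vanishes regardless of the scale, which is why a broken conditioning path can look like ineffective CFG.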
### Precomputed Text Embeddings & Cross-Attention KV Cache
Text embeddings and cross-attention K/V projections are constant across all diffusion steps. Computing them once and passing as caches eliminates redundant computation.
### Memory Management in Diffusion Loop
```python
# Release temporaries before eval to free memory for graph execution
del noise_pred_cond, noise_pred_uncond, noise_pred, preds
mx.eval(latents)
```
MLX's lazy evaluation means `mx.eval()` triggers the full computation graph. Deleting intermediate arrays before eval allows MLX to reuse their memory during execution.
---
## Weight Conversion
### Key Mapping Patterns
The PyTorch → MLX conversion (`convert_wan.py`) handles several systematic transforms:
1. **Conv3d weight transposition**: PyTorch `(out, in, D, H, W)` → MLX `(out, D, H, W, in)`
2. **Linear weights (no transposition)**: PyTorch `(out, in)` → MLX `(out, in)`; `nn.Linear` uses the same convention in both frameworks, so these weights are copied as-is
3. **Nested module paths**: `blocks.0.self_attn.q.weight` → same paths, MLX loads by dotted key
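The Conv3d layout change in item 1 amounts to moving the input-channel axis to the end. The shape-only sketch below shows the axis permutation; `permute_shape` and the example kernel shape are illustrative assumptions, and the same axis order would be passed to an array transpose in a real conversion.

```python
# Axis order that maps PyTorch (out, in, D, H, W) to MLX (out, D, H, W, in).
PT_TO_MLX_CONV3D = (0, 2, 3, 4, 1)

def permute_shape(shape, axes):
    """Return the shape an array would have after transposing with `axes`."""
    return tuple(shape[a] for a in axes)

pt_shape = (5120, 16, 1, 2, 2)  # hypothetical kernel: out=5120, in=16, D=1, H=2, W=2
print(permute_shape(pt_shape, PT_TO_MLX_CONV3D))
```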
### Dual-Model Splitting
The T2V-14B uses dual models (high-noise and low-noise). The conversion script splits a single checkpoint into separate files or handles pre-split checkpoints from HuggingFace.
---
## Testing Strategy
332 tests across 10 files, all running in ~5 seconds:
| File | Focus |
|------|-------|
| test_wan_config.py | Config presets, field validation |
| test_wan_attention.py | Self/cross attention, RMSNorm, bf16 autocast |
| test_wan_transformer.py | FFN, attention block, float32 modulation |
| test_wan_model.py | Full DiT forward pass, per-token timesteps |
| test_wan_t5.py | T5 encoder layers and full encoding |
| test_wan_vae.py | VAE 2.1 decoder, VAE 2.2 encoder + decoder |
| test_wan_scheduler.py | All 3 schedulers, cross-scheduler coherence |
| test_wan_convert.py | Weight sanitization and conversion |
| test_wan_generate.py | End-to-end pipeline, I2V masks, dimension alignment |
| test_wan_i2v.py | I2V-14B config, y parameter, VAE encoder, mask construction |
Tests use a tiny config (`dim=64, heads=2, layers=2`) for fast execution. Cross-scheduler coherence tests verify that all three schedulers produce similar outputs from the same noise.
---
## Known Issues
### I2V Quality Degradation
Frames 2-13 gradually degrade, and frame 14 often has a "flash" artifact. All implementation details have been verified against the official PyTorch code with no discrepancies found. Possible causes:
- Subtle numerical differences from float32 vs float64 RoPE (MLX lacks float64 on GPU)
- MLX-specific attention precision behavior
- Mitigation rather than cause: better prompts and 720p resolution (the model's native resolution) help reduce artifacts
### Chinese Negative Prompt
The official Wan2.2 uses a Chinese negative prompt that prevents oversaturation and comic-style artifacts. Correct tokenization requires `ftfy.fix_text()` to normalize fullwidth characters and double HTML unescaping. Without proper text cleaning, the negative prompt tokens don't match the training distribution, causing blurry patches.
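The double HTML unescaping step can be illustrated with the standard library alone. This sketch deliberately omits the `ftfy.fix_text()` step, since `ftfy` is a third-party package; `double_unescape` is a hypothetical helper name.

```python
import html

def double_unescape(text):
    """Undo two layers of HTML entity escaping (e.g. '&amp;amp;' -> '&')."""
    return html.unescape(html.unescape(text))

print(double_unescape("bright &amp;amp; vivid"))
print(double_unescape("plain text stays the same"))
```

A single pass would leave `&amp;` in the first string, so the tokenizer would see different characters than the ones used during training.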


@@ -70,7 +70,7 @@ The conversion script auto-detects the model version from the directory structur
#### Wan2.1 T2V 1.3B
```bash
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.1-T2V-1.3B \
--output-dir ./Wan2.1-T2V-1.3B-MLX
```
@@ -78,7 +78,7 @@ python -m mlx_video.convert_wan \
#### Wan2.1 T2V 14B
```bash
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.1-T2V-14B \
--output-dir ./Wan2.1-T2V-14B-MLX
```
@@ -86,7 +86,7 @@ python -m mlx_video.convert_wan \
#### Wan2.2 T2V 14B
```bash
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.2-T2V-A14B \
--output-dir ./Wan2.2-T2V-A14B-MLX
```
@@ -94,7 +94,7 @@ python -m mlx_video.convert_wan \
#### Wan2.2 I2V 14B
```bash
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.2-I2V-A14B \
--output-dir ./Wan2.2-I2V-A14B-MLX
```
@@ -104,7 +104,7 @@ The I2V model is auto-detected from `config.json`; the output will include a `va
#### Wan2.2 TI2V 5B
```bash
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.2-TI2V-5B \
--output-dir ./Wan2.2-TI2V-5B-MLX
```
@@ -144,7 +144,7 @@ wan_mlx/
#### Wan2.1 T2V 1.3B
```bash
python -m mlx_video.generate_wan \
python -m mlx_video.wan2.generate \
--model-dir ./Wan2.1-T2V-1.3B-MLX \
--prompt "A cat playing piano in a cozy living room, cinematic lighting" \
--width 832 --height 480 --num-frames 81 \
@@ -156,7 +156,7 @@ python -m mlx_video.generate_wan \
#### Wan2.1 T2V 14B
```bash
python -m mlx_video.generate_wan \
python -m mlx_video.wan2.generate \
--model-dir ./Wan2.1-T2V-14B-MLX \
--prompt "A woman walks through a misty forest at dawn, slow motion, cinematic" \
--width 1280 --height 704 --num-frames 81 \
@@ -172,7 +172,7 @@ python -m mlx_video.generate_wan \
Wan2.2 uses a dual-model pipeline (separate high-noise and low-noise transformers) and takes guidance as a `high,low` pair:
```bash
python -m mlx_video.generate_wan \
python -m mlx_video.wan2.generate \
--model-dir ./Wan2.2-T2V-A14B-MLX \
--prompt "Two astronauts playing chess on the surface of the moon, dramatic lighting, 8K" \
--negative-prompt "low quality, blurry, distorted" \
@@ -189,7 +189,7 @@ python -m mlx_video.generate_wan \
Image-to-video: animates a starting image guided by a text prompt. Pass the image with `--image`:
```bash
python -m mlx_video.generate_wan \
python -m mlx_video.wan2.generate \
--model-dir ./Wan2.2-I2V-A14B-MLX \
--image ./my_photo.png \
--prompt "The person slowly turns their head and smiles, cinematic, natural lighting" \
@@ -207,7 +207,7 @@ python -m mlx_video.generate_wan \
Text+image-to-video: a single-model variant with a larger VAE (`z_dim=48`). Resolution must be divisible by **32** (not 16 as with other models):
```bash
python -m mlx_video.generate_wan \
python -m mlx_video.wan2.generate \
--model-dir ./Wan2.2-TI2V-5B-MLX \
--image ./my_photo.png \
--prompt "The subject waves hello, warm sunlight, film grain" \
@@ -251,27 +251,27 @@ Quantize the transformer weights to reduce memory usage by ~3.4×. Quantization
```bash
# Convert with 4-bit quantization (works for any variant)
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.1-T2V-1.3B \
--output-dir ./Wan2.1-T2V-1.3B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.1-T2V-14B \
--output-dir ./Wan2.1-T2V-14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.2-T2V-A14B \
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.2-I2V-A14B \
--output-dir ./Wan2.2-I2V-A14B-MLX-Q4 \
--quantize --bits 4 --group-size 64
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.2-TI2V-5B \
--output-dir ./Wan2.2-TI2V-5B-MLX-Q4 \
--quantize --bits 4 --group-size 64
@@ -280,7 +280,7 @@ python -m mlx_video.convert_wan \
You can also quantize an already-converted MLX model without re-converting from PyTorch:
```bash
python -m mlx_video.convert_wan \
python -m mlx_video.wan2.convert \
--checkpoint-dir ./Wan2.2-T2V-A14B-MLX \
--output-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--quantize-only --bits 4
@@ -289,7 +289,7 @@ python -m mlx_video.convert_wan \
Quantized models are used exactly the same way — the quantization is auto-detected from `config.json`:
```bash
python -m mlx_video.generate_wan \
python -m mlx_video.wan2.generate \
--model-dir ./Wan2.2-T2V-A14B-MLX-Q4 \
--prompt "A cat playing piano"
```
@@ -330,7 +330,7 @@ LoRA's can be used with the `--lora-high` and `--lora-low` command line switches
For example, to use the distilled [Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) LoRA, run the following command. Lightning speeds up generation by using only 4 steps and a CFG scale of 1.
```bash
python -m mlx_video.generate_wan \
python -m mlx_video.wan2.generate \
--model-dir /Volumes/SSD/Wan-AI/Wan2.2-T2V-A14B-MLX \
--width 480 \
--height 704 \


@@ -0,0 +1,2 @@
from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan_2.wan_2 import WanModel


@@ -98,8 +98,12 @@ class WanSelfAttention(nn.Module):
v = self.v(x_w).reshape(b, s, n, d)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
q = rope_apply(
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
k = rope_apply(
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
@@ -120,9 +124,7 @@ class WanSelfAttention(nn.Module):
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
return self.o(out)
@@ -213,9 +215,7 @@ class WanCrossAttention(nn.Module):
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
return self.o(out)


@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Tuple, Union
from mlx_video.models.ltx.config import BaseModelConfig
from mlx_video.models.ltx_2.config import BaseModelConfig
@dataclass


@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.utils
import numpy as np
logger = logging.getLogger(__name__)
@@ -57,7 +56,9 @@ def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
return weights
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
def sanitize_wan_transformer_weights(
weights: Dict[str, mx.array]
) -> Dict[str, mx.array]:
"""Convert Wan2.2 transformer weight keys to MLX model structure.
Wan2.2 keys follow the pattern:
@@ -246,8 +247,8 @@ def _load_lora_configs(
Shared between weight-merging and runtime-wrapping paths.
"""
from mlx_video.models.wan_2.generate import Colors
from mlx_video.lora import LoRAConfig, load_multiple_loras
from mlx_video.generate_wan import Colors
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
@@ -264,7 +265,9 @@ def _load_lora_configs(
module_to_loras = load_multiple_loras(configs)
if not module_to_loras:
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
print(
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
)
return module_to_loras
@@ -279,8 +282,8 @@ def load_and_apply_loras(
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
"""
from mlx_video.models.wan_2.generate import Colors
from mlx_video.lora import apply_loras_to_weights
from mlx_video.generate_wan import Colors
if not lora_configs:
return model_weights
@@ -289,12 +292,17 @@ def load_and_apply_loras(
if not module_to_loras:
return model_weights
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
print(
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
)
if verbose:
print(f" Model has {len(model_weights)} weight keys")
modified_weights = apply_loras_to_weights(
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
model_weights,
module_to_loras,
verbose=verbose,
quantization_bits=quantization_bits,
)
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
@@ -403,7 +411,7 @@ def convert_wan_checkpoint(
print(" Warning: No transformer weights found!")
# Save config — detect model size from source config.json or transformer weights
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
def _detect_config():
"""Detect config from source config.json or transformer weight shapes."""
@@ -435,8 +443,10 @@ def convert_wan_checkpoint(
src_model_type = src_config.get("model_type", "t2v")
src_text_len = src_config.get("text_len", 512)
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}")
print(
f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}"
)
# Use preset for known TI2V 5B configuration
if src_model_type == "ti2v" and src_dim == 3072:
@@ -512,9 +522,12 @@ def convert_wan_checkpoint(
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
weights = sanitize_wan22_vae_weights(
weights, include_encoder=include_encoder
)
else:
weights = sanitize_wan_vae_weights(weights)
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
@@ -527,7 +540,9 @@ def convert_wan_checkpoint(
# Quantize transformer weights if requested
if quantize:
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
print(
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
)
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
print(f"\nConversion complete! Output: {output_dir}")
@@ -543,9 +558,16 @@ def _quantize_predicate(path: str, module) -> bool:
return False
# Quantize attention Q/K/V/O and FFN fc1/fc2
quantize_patterns = (
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
".ffn.fc1", ".ffn.fc2",
".self_attn.q",
".self_attn.k",
".self_attn.v",
".self_attn.o",
".cross_attn.q",
".cross_attn.k",
".cross_attn.v",
".cross_attn.o",
".ffn.fc1",
".ffn.fc2",
)
return any(path.endswith(p) for p in quantize_patterns)
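The suffix test at the end of `_quantize_predicate` can be tried in isolation. The sketch below is a simplified stand-in, not the function from this file: `should_quantize` and the example module paths are hypothetical, and the earlier guard that returns `False` is left out because its condition is not visible in this hunk. It relies on `str.endswith` accepting a tuple of suffixes.

```python
# Hypothetical stand-in for the suffix test; names are illustrative only.
QUANTIZE_SUFFIXES = (
    ".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
    ".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
    ".ffn.fc1", ".ffn.fc2",
)


def should_quantize(path: str) -> bool:
    """Return True when the module path ends with one of the listed suffixes."""
    return path.endswith(QUANTIZE_SUFFIXES)


print(should_quantize("blocks.0.self_attn.q"))       # attention projection
print(should_quantize("blocks.0.self_attn.norm_q"))  # not in the suffix list
```

Matching on the end of the path keeps look-alike names, such as a normalization layer inside the attention block, out of the quantized set.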
@@ -572,7 +594,7 @@ def _quantize_saved_model(
import mlx.nn as nn
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
if source_dir is None:
source_dir = output_dir
@@ -682,16 +704,22 @@ def quantize_mlx_model(
).exists()
# Build model config
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
config_dict = {
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
}
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**config_dict)
# Copy non-transformer files to output dir (skip large model weights)
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
transformer_files = {
"low_noise_model.safetensors",
"high_noise_model.safetensors",
"model.safetensors",
}
if dst.resolve() != src.resolve():
dst.mkdir(parents=True, exist_ok=True)
for f in src.iterdir():
@@ -763,11 +791,18 @@ if __name__ == "__main__":
if args.quantize_only:
quantize_mlx_model(
args.checkpoint_dir, args.output_dir,
bits=args.bits, group_size=args.group_size,
args.checkpoint_dir,
args.output_dir,
bits=args.bits,
group_size=args.group_size,
)
else:
convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
args.checkpoint_dir,
args.output_dir,
args.dtype,
args.model_version,
quantize=args.quantize,
bits=args.bits,
group_size=args.group_size,
)
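The config handling in `quantize_mlx_model` (keep only keys that are dataclass fields, then turn JSON lists back into tuples) can be sketched with the standard library alone. `DemoConfig` and `config_from_json_dict` are hypothetical names used for illustration; the real `WanModelConfig` has more fields than are shown here.

```python
from dataclasses import dataclass, fields
from typing import Tuple


@dataclass
class DemoConfig:  # hypothetical stand-in for WanModelConfig
    dim: int = 2048
    patch_size: Tuple[int, int, int] = (1, 2, 2)
    vae_stride: Tuple[int, int, int] = (4, 8, 8)


def config_from_json_dict(raw: dict) -> DemoConfig:
    """Build a config from a JSON-style dict, ignoring unknown keys."""
    known = {f.name for f in fields(DemoConfig)}
    kept = {k: v for k, v in raw.items() if k in known}
    # JSON has no tuple type, so tuple-valued fields arrive as lists.
    for key in ("patch_size", "vae_stride"):
        if isinstance(kept.get(key), list):
            kept[key] = tuple(kept[key])
    return DemoConfig(**kept)


cfg = config_from_json_dict(
    {"dim": 5120, "patch_size": [1, 2, 2], "quantization": {"bits": 4}}
)
print(cfg)
```

Filtering first means extra keys written by other tools (here a `quantization` entry) do not raise a `TypeError` in the dataclass constructor.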

View File

@@ -4,25 +4,23 @@ import argparse
import gc
import math
import random
import sys
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan.loading import (
_clean_text,
from mlx_video.models.wan_2.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan_2.utils import (
encode_text,
load_t5_encoder,
load_vae_decoder,
load_vae_encoder,
load_wan_model,
)
from mlx_video.models.wan.postprocess import save_video
from mlx_video.models.wan_2.postprocess import save_video
class Colors:
"""ANSI color codes for terminal output."""
@@ -37,6 +35,7 @@ class Colors:
DIM = "\033[2m"
RESET = "\033[0m"
# Backward-compat alias (tests and external code may use the old name)
_build_i2v_mask = build_i2v_mask
@@ -122,8 +121,8 @@ def generate_video(
"""
import json
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -143,10 +142,13 @@ def generate_video(
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**{
k: v for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
})
config = WanModelConfig(
**{
k: v
for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
}
)
else:
# Auto-detect: dual model files → 2.2, single model → 2.1
if (model_dir / "low_noise_model.safetensors").exists():
@@ -182,7 +184,9 @@ def generate_video(
if "patch_embedding_proj.weight" in k:
actual_dim = v.shape[0]
if actual_dim != config.dim:
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
print(
f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}"
)
if actual_dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
@@ -192,13 +196,20 @@ def generate_video(
# Auto-correct Wan2.2 VAE params from stale configs
if config.in_dim == 48 and config.vae_z_dim != 48:
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
config = WanModelConfig(**{
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
})
print(
f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}"
)
config = WanModelConfig(
**{
**{
f.name: getattr(config, f.name)
for f in config.__dataclass_fields__.values()
},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
}
)
# Apply defaults from config if not overridden
if steps is None:
@@ -227,7 +238,9 @@ def generate_video(
gen_frames = num_frames
if trim_first_frames > 0:
gen_frames = num_frames + trim_first_frames * 4
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
print(
f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}"
)
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
@@ -247,10 +260,16 @@ def generate_video(
if is_i2v:
print(f" Image: {image}")
if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
neg_display = (
neg_prompt_resolved[:60] + "..."
if len(neg_prompt_resolved) > 60
else neg_prompt_resolved
)
print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
print(
f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}"
)
if cfg_disabled:
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
print(f"{Colors.RESET}")
@@ -275,12 +294,16 @@ def generate_video(
height = align_h
if width == 0:
width = align_w
print(f"{Colors.DIM} Aligned {old_w}x{old_h} -> {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
print(
f"{Colors.DIM} Aligned {old_w}x{old_h} -> {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}"
)
# Enforce max_area constraint (model-specific resolution limit)
if config.max_area > 0 and height * width > config.max_area:
old_h, old_w = height, width
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
width, height = _best_output_size(
width, height, align_w, align_h, config.max_area
)
print(
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
@@ -309,6 +332,7 @@ def generate_video(
# Load tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
# Encode prompts
@@ -318,12 +342,15 @@ def generate_video(
context_null = None
mx.eval(context)
else:
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
context_null = encode_text(
t5_encoder, tokenizer, neg_prompt_resolved, config.text_len
)
mx.eval(context, context_null)
# Free T5 from memory
del t5_encoder
gc.collect(); mx.clear_cache()
gc.collect()
mx.clear_cache()
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
# I2V: encode image to latent space
@@ -346,18 +373,25 @@ def generate_video(
img = Image.open(image).convert("RGB")
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
img = img.resize(
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
)
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
img_arr = mx.array(
np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0
) # [H, W, 3]
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
video = mx.concatenate([
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
], axis=1)
video = mx.concatenate(
[
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
],
axis=1,
)
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
vae_enc = load_vae_encoder(vae_path, config)
@@ -367,12 +401,17 @@ def generate_video(
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
msk = mx.concatenate([
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
], axis=1)
msk = mx.concatenate(
[
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
],
axis=1,
)
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
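The mask rearrangement above is easier to follow with the spatial axes dropped. The sketch below reproduces only the temporal bookkeeping in plain Python lists; `latent_mask_layout` is a hypothetical helper, and it assumes `num_frames` has the form 4n+1, as the CLI help requires.

```python
def latent_mask_layout(num_frames: int):
    """Return the mask as [4][T_lat], ignoring the H_lat and W_lat axes."""
    per_frame = [1] + [0] * (num_frames - 1)      # 1 only for the first frame
    expanded = per_frame[:1] * 4 + per_frame[1:]  # repeat the first frame 4x
    assert len(expanded) % 4 == 0, "num_frames must be 4n+1"
    t_lat = len(expanded) // 4
    # reshape to [T_lat, 4]: index t * 4 + c maps to (t, c)
    groups = [expanded[t * 4:(t + 1) * 4] for t in range(t_lat)]
    # transpose to [4, T_lat]
    return [[groups[t][c] for t in range(t_lat)] for c in range(4)]


mask = latent_mask_layout(81)
print(len(mask), len(mask[0]))
```

With 81 frames the expanded mask has 84 entries, which fold into 21 latent time steps of 4 channels each; only the first latent step is marked as conditioned.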
@@ -395,13 +434,16 @@ def generate_video(
del vae_enc, img_tensor
gc.collect(); mx.clear_cache()
gc.collect()
mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
# Load transformer models
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
if quantization:
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
print(
f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}"
)
t2 = time.time()
# Merge per-model LoRAs with shared LoRAs
@@ -412,10 +454,16 @@ def generate_video(
if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
low_noise_model = load_wan_model(
low_noise_path, config, quantization, loras=_loras_low
)
high_noise_model = load_wan_model(
high_noise_path, config, quantization, loras=_loras_high
)
else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
single_model = load_wan_model(
model_dir / "model.safetensors", config, quantization, loras=_loras_single
)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step)
@@ -437,8 +485,12 @@ def generate_video(
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
context_cfg_low = mx.concatenate(
[context_emb_low[0:1], context_emb_low[1:2]], axis=0
)
context_cfg_high = mx.concatenate(
[context_emb_high[0:1], context_emb_high[1:2]], axis=0
)
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb)
@@ -534,7 +586,7 @@ def generate_video(
rcs = rope_cos_sin
# Use compiled forward when available (faster after first trace)
_call = getattr(model, '_compiled', model)
_call = getattr(model, "_compiled", model)
if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
@@ -552,7 +604,9 @@ def generate_video(
y_arg = [y_i2v] if is_i2v_channel_concat else None
if is_dual:
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
ctx = (
context_cond_high if timestep_val >= boundary else context_cond_low
)
else:
ctx = context_cond
preds = _call(
@@ -571,7 +625,11 @@ def generate_video(
if is_dual:
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
else:
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
gs = (
guide_scale
if isinstance(guide_scale, (int, float))
else guide_scale[0]
)
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
@@ -586,8 +644,10 @@ def generate_video(
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
ctx = (
context_cfg
if not is_dual
else (context_cfg_high if timestep_val >= boundary else context_cfg_low)
)
preds = _call(
[latents, latents],
@@ -618,16 +678,24 @@ def generate_video(
if debug_latents:
lat_np = np.array(latents) # [C, T, H, W]
n_t = lat_np.shape[1]
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
print(
f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}"
)
print(
f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}"
)
for t_pos in range(min(n_t, 8)):
frame = lat_np[:, t_pos, :, :]
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
print(
f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}"
)
if n_t > 8:
interior = lat_np[:, 4:, :, :]
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
print(
f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}"
)
print()
# Free transformer models and text embeddings
@@ -646,7 +714,8 @@ def generate_video(
del model, kv, context
if context_null is not None:
del context_null
gc.collect(); mx.clear_cache()
gc.collect()
mx.clear_cache()
# Load VAE and decode
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
@@ -660,7 +729,7 @@ def generate_video(
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
# This gives the first real frame a full temporal receptive field of real data.
# Select tiling configuration
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig
if tiling == "none":
tiling_config = None
@@ -677,16 +746,28 @@ def generate_video(
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
print(
f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}"
)
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
spatial_info = (
f"{tiling_config.spatial_config.tile_size_in_pixels}px"
if tiling_config.spatial_config
else "none"
)
temporal_info = (
f"{tiling_config.temporal_config.tile_size_in_frames}f"
if tiling_config.temporal_config
else "none"
)
print(
f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}"
)
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
from mlx_video.models.wan_2.vae22 import denormalize_latents
# latents: [C, T, H, W] → [1, T, H, W, C] (channels-last for Wan2.2 VAE)
z = latents.transpose(1, 2, 3, 0)[None]
@@ -718,7 +799,9 @@ def generate_video(
if trim_first_frames > 0:
trim_pixels = trim_first_frames * 4
video = video[trim_pixels:]
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
print(
f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}"
)
save_video(video, output_path, fps=config.sample_fps)
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
@@ -727,58 +810,124 @@ def generate_video(
def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--image", type=str, default=None,
help="Path to input image for I2V (omit for T2V mode)")
parser.add_argument("--negative-prompt", type=str, default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true",
help="Disable negative prompt (use empty string instead of config default)")
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
parser.add_argument(
"--scheduler", type=str, default="unipc",
"--model-dir",
type=str,
required=True,
help="Path to converted MLX model directory",
)
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument(
"--image",
type=str,
default=None,
help="Path to input image for I2V (omit for T2V mode)",
)
parser.add_argument(
"--negative-prompt",
type=str,
default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)",
)
parser.add_argument(
"--no-negative-prompt",
action="store_true",
help="Disable negative prompt (use empty string instead of config default)",
)
parser.add_argument(
"--width", type=int, default=1280, help="Video width (default: 1280)"
)
parser.add_argument(
"--height",
type=int,
default=704,
help="Video height (default: 704; 720p models use 704)",
)
parser.add_argument(
"--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)"
)
parser.add_argument(
"--steps",
type=int,
default=None,
help="Number of diffusion steps (default: from config)",
)
parser.add_argument(
"--guide-scale",
type=str,
default=None,
help="Guidance scale: single float or low,high pair",
)
parser.add_argument(
"--shift",
type=float,
default=None,
help="Noise schedule shift (default: from config)",
)
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument(
"--output-path", type=str, default="output.mp4", help="Output video path"
)
parser.add_argument(
"--scheduler",
type=str,
default="unipc",
choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
)
parser.add_argument(
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
"--lora",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
)
parser.add_argument(
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
"--lora-high",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
"--lora-low",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--tiling",
type=str,
default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
choices=[
"auto",
"none",
"default",
"aggressive",
"conservative",
"spatial",
"temporal",
],
help="VAE tiling mode to reduce memory during decoding (default: auto)",
)
parser.add_argument(
"--no-compile", action="store_true",
"--no-compile",
action="store_true",
help="Disable mx.compile on models (for debugging)",
)
parser.add_argument(
"--trim-first-frames", type=int, default=0, metavar="N",
"--trim-first-frames",
type=int,
default=0,
metavar="N",
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
"Default: 0 (disabled)",
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
"Default: 0 (disabled)",
)
parser.add_argument(
"--debug-latents", action="store_true",
"--debug-latents",
action="store_true",
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
)
args = parser.parse_args()
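The repeatable `--lora PATH STRENGTH` flag combines `nargs=2` with `action="append"`, so each occurrence contributes one two-item list. A minimal sketch of how such values could be collected follows; the conversion to `(path, float)` pairs is an assumption for illustration, not necessarily how `main()` handles them.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lora",
    nargs=2,
    action="append",
    metavar=("PATH", "STRENGTH"),
)

# An explicit argv list keeps the example independent of the real command line.
args = parser.parse_args(
    ["--lora", "a.safetensors", "0.8", "--lora", "b.safetensors", "1.0"]
)

# args.lora is None when the flag is absent, hence the `or []`.
loras = [(path, float(strength)) for path, strength in (args.lora or [])]
print(loras)
```

Both values arrive as strings, so the strength has to be converted explicitly.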

View File

@@ -21,7 +21,9 @@ def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
# Resize so that the image covers the target size (LANCZOS)
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
img = img.resize(
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
)
# Center crop
x1 = (img.width - width) // 2
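The cover-resize and center-crop arithmetic in `preprocess_image` can be checked without loading an image. The helper below is a hypothetical sketch that only computes the resized dimensions and the crop box; it assumes the vertical offset is computed the same way as `x1`.

```python
def cover_resize_and_crop_box(src_w: int, src_h: int, width: int, height: int):
    """Return the resized size and the (left, top, right, bottom) crop box."""
    # Scale by the larger ratio so the resized image covers the target on both axes.
    scale = max(width / src_w, height / src_h)
    new_w, new_h = round(src_w * scale), round(src_h * scale)
    x1 = (new_w - width) // 2
    y1 = (new_h - height) // 2  # assumed to mirror the x1 computation
    return (new_w, new_h), (x1, y1, x1 + width, y1 + height)


size, box = cover_resize_and_crop_box(1920, 1080, 1280, 704)
print(size, box)
```

For a 1920x1080 source and the default 1280x704 target, the width ratio wins, the image is resized to 1280x720, and 8 rows are trimmed from the top and bottom.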

View File

@@ -1,6 +1,8 @@
import numpy as np
from pathlib import Path
import numpy as np
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""Save video frames to MP4.
@@ -11,6 +13,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""
try:
import imageio
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
for frame in frames:
writer.append_data(frame)
@@ -18,6 +21,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
except ImportError:
try:
import cv2
h, w = frames.shape[1], frames.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"avc1")
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
@@ -27,9 +31,11 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
except (ImportError, Exception):
# Last resort: save as individual PNGs
from PIL import Image
out_dir = Path(output_path).parent / Path(output_path).stem
out_dir.mkdir(parents=True, exist_ok=True)
for i, frame in enumerate(frames):
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
print(
f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)"
)

View File

@@ -1,4 +1,3 @@
import math
import mlx.core as mx
import numpy as np
@@ -11,13 +10,16 @@ def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
Complex frequency tensor of shape [max_seq_len, dim // 2].
"""
assert dim % 2 == 0
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
1.0
/ np.power(
theta,
np.arange(0, dim, 2, dtype=np.float64) / dim,
)
)[None, :]
freqs = (
np.arange(max_seq_len, dtype=np.float64)[:, None]
* (
1.0
/ np.power(
theta,
np.arange(0, dim, 2, dtype=np.float64) / dim,
)
)[None, :]
)
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
cos_freqs = np.cos(freqs).astype(np.float32)
sin_freqs = np.sin(freqs).astype(np.float32)
@@ -46,9 +48,9 @@ def rope_apply(
# Check if all batch elements have the same grid (common for CFG B=2)
f0, h0, w0 = grid_sizes[0]
seq_len = f0 * h0 * w0
all_same_grid = all(
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
) if b > 1 else True
all_same_grid = (
all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True
)
if all_same_grid:
# Vectorized path: apply RoPE to all batch elements at once
@@ -57,7 +59,9 @@ def rope_apply(
x_imag = x_seq[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(
b, seq_len, n, d
)
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
return x_rotated
@@ -102,17 +106,11 @@ def rope_apply(
# Build per-position frequencies by expanding along grid dims
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
ft = mx.broadcast_to(
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
)
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
fh = mx.broadcast_to(
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
)
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
fw = mx.broadcast_to(
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
)
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
# Concatenate: [f*h*w, half_d, 2]
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
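The frequency table from `rope_params` and the rotation used in the vectorized path of `rope_apply` reduce to a few lines of scalar arithmetic. The sketch below uses plain Python in place of NumPy and MLX; `rope_angles` and `rotate_pair` are hypothetical names, and the per-axis split into temporal, height, and width frequencies is not modelled.

```python
import math


def rope_angles(max_seq_len: int, dim: int, theta: float = 10000.0):
    """Angle table of shape [max_seq_len][dim // 2]: position * inverse frequency."""
    assert dim % 2 == 0
    inv_freq = [1.0 / theta ** (i / dim) for i in range(0, dim, 2)]
    return [[pos * f for f in inv_freq] for pos in range(max_seq_len)]


def rotate_pair(x_real: float, x_imag: float, angle: float):
    """Rotate one (real, imag) feature pair by `angle`, as in the vectorized path."""
    c, s = math.cos(angle), math.sin(angle)
    return x_real * c - x_imag * s, x_real * s + x_imag * c


angles = rope_angles(max_seq_len=4, dim=8)
rotated = rotate_pair(1.0, 0.0, angles[1][0])
print(rotated)
```

Since each pair is only rotated, its norm is unchanged; position is encoded purely in the angle.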

View File

@@ -7,9 +7,8 @@ for the same quality as Euler.
import math
import numpy as np
import mlx.core as mx
import numpy as np
def _compute_sigmas(
@@ -25,9 +24,7 @@ def _compute_sigmas(
Returns num_steps+1 values (the last being 0.0 for the terminal state).
"""
# sigma bounds from unshifted training schedule (constructor uses shift=1)
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
::-1
]
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1]
sigmas_unshifted = 1.0 - alphas
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
sigma_min = float(sigmas_unshifted[-1]) # 0.0
@@ -65,7 +62,10 @@ class FlowMatchEulerScheduler:
sample: mx.array,
) -> mx.array:
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
dt = (
self._sigmas_float[self._step_index + 1]
- self._sigmas_float[self._step_index]
)
x_next = sample + dt * model_output
self._step_index += 1
return x_next
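The Euler update in `FlowMatchEulerScheduler.step` is `x_next = x + (sigma_next - sigma_cur) * v`. The sketch below applies it to scalars on a toy straight-line path. The velocity convention (`noise - data`) and the evenly spaced sigma list are assumptions made for illustration; the real schedule comes from `_compute_sigmas` and the velocity from the transformer.

```python
def euler_step(sample: float, velocity: float, sigma_cur: float, sigma_next: float) -> float:
    """One Euler step along the flow: x_next = x + (sigma_next - sigma_cur) * v."""
    return sample + (sigma_next - sigma_cur) * velocity


data, noise = -1.0, 2.0
sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]  # last value is the terminal state

x = noise  # start from pure noise at sigma = 1
for sigma_cur, sigma_next in zip(sigmas, sigmas[1:]):
    # For x_sigma = (1 - sigma) * data + sigma * noise the velocity is constant.
    v = noise - data
    x = euler_step(x, v, sigma_cur, sigma_next)

print(x)
```

With an exact, constant velocity the first-order step lands on the data point; a learned velocity is only approximately constant, which is where the higher-order solvers in this file help.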
@@ -139,13 +139,8 @@ class FlowDPMPP2MScheduler:
# Decide order: 1st for first step, last step (if lower_order_final
# and few steps), otherwise 2nd
use_first_order = (
self._prev_x0 is None
or (
self.lower_order_final
and i == self._num_steps - 1
and self._num_steps < 15
)
use_first_order = self._prev_x0 is None or (
self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15
)
if use_first_order or sigma_next == 0.0:

View File

@@ -49,20 +49,19 @@ class T5RelativeEmbedding(nn.Module):
is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = (
max_exact
+ (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = max_exact + (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
rel_pos_large = mx.minimum(
rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
)
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
rel_buckets = rel_buckets + mx.where(
is_small, rel_pos.astype(mx.int32), rel_pos_large
)
return rel_buckets
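The bucket computation above keeps small offsets exact and places larger ones on a logarithmic scale, capped at the last bucket. A scalar sketch of the part visible in this hunk follows. `relative_bucket` is a hypothetical helper that takes `max_exact` as an explicit parameter, and it omits the offset that `rel_buckets` already holds before this step.

```python
import math


def relative_bucket(rel_pos: int, num_buckets: int, max_exact: int, max_dist: int) -> int:
    """Bucket index for a non-negative relative position."""
    if rel_pos < max_exact:
        return rel_pos  # small offsets get their own bucket
    scaled = (
        math.log(rel_pos / max_exact)
        / math.log(max_dist / max_exact)
        * (num_buckets - max_exact)
    )
    # int() truncates like the int32 cast; min() mirrors mx.minimum.
    return min(max_exact + int(scaled), num_buckets - 1)


buckets = [relative_bucket(p, num_buckets=16, max_exact=8, max_dist=128) for p in range(300)]
print(buckets[:12], buckets[-1])
```

Everything at or beyond `max_dist` shares the final bucket, so the attention bias table stays small regardless of sequence length.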
def __call__(self, lq: int, lk: int) -> mx.array:
@@ -115,7 +114,7 @@ class T5Attention(nn.Module):
v = v.transpose(0, 2, 1, 3)
# QK^T (no scaling) — compute in float32 for precision
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
# Add position bias
if pos_bias is not None:

View File

@@ -6,7 +6,7 @@ for non-causal temporal decoders (e.g. Wan2.1 where T latent frames → T*scale
output frames rather than LTX's 1+(T-1)*scale mapping).
# TODO: This function can be refactored to consolidate with
# mlx_video.models.ltx.video_vae.tiling.decode_with_tiling once the
# mlx_video.models.ltx_2.video_vae.tiling.decode_with_tiling once the
# causal_temporal generalisation is accepted upstream.
"""
@@ -14,7 +14,7 @@ from typing import Callable, Optional
import mlx.core as mx
from mlx_video.models.ltx.video_vae.tiling import (
from mlx_video.models.ltx_2.video_vae.tiling import (
SpatialTilingConfig,
TemporalTilingConfig,
TilingConfig,
@@ -75,7 +75,11 @@ def decode_with_tiling(
b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
out_f = (
(1 + (f_latent - 1) * temporal_scale)
if causal_temporal
else (f_latent * temporal_scale)
)
out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale
@@ -98,9 +102,13 @@ def decode_with_tiling(
# Compute intervals for each dimension
if causal_temporal:
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_temporal(
temporal_tile_size, temporal_overlap, f_latent
)
else:
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_spatial(
temporal_tile_size, temporal_overlap, f_latent
)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -124,9 +132,13 @@ def decode_with_tiling(
# Map temporal coordinates
if causal_temporal:
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_temporal_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
else:
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_spatial_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -135,7 +147,9 @@ def decode_with_tiling(
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
out_h_slice, h_mask = map_spatial_slice(
h_start, h_end, h_left, h_right, spatial_scale
)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
@@ -144,13 +158,23 @@ def decode_with_tiling(
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
out_w_slice, w_mask = map_spatial_slice(
w_start, w_end, w_left, w_right, spatial_scale
)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
tile_latents = latents[
:, :, t_start:t_end, h_start:h_end, w_start:w_end
]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
tile_output = decoder_fn(
tile_latents,
causal=causal,
timestep=timestep,
debug=False,
chunked_conv=chunked_conv,
)
mx.eval(tile_output)
# Clear tile_latents reference
@@ -173,13 +197,15 @@ def decode_with_tiling(
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
t_mask_slice.reshape(1, 1, -1, 1, 1)
* h_mask_slice.reshape(1, 1, 1, -1, 1)
* w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
tile_output_slice = tile_output[
:, :, :actual_t, :actual_h, :actual_w
].astype(mx.float32)
# Clear full tile_output
del tile_output
@@ -196,11 +222,37 @@ def decode_with_tiling(
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ blend_mask
)
# Force evaluation to free memory
@@ -232,12 +284,14 @@ def decode_with_tiling(
if next_tile_start_latent == 0:
next_tile_start_out = 0
elif causal_temporal:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
next_tile_start_out = (
1 + (next_tile_start_latent - 1) * temporal_scale
)
else:
next_tile_start_out = next_tile_start_latent * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
if not hasattr(decode_with_tiling, "_emitted_frames"):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
@@ -245,7 +299,10 @@ def decode_with_tiling(
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = (
output[:, :, emitted:next_tile_start_out, :, :]
/ finalized_weights
)
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
@@ -262,7 +319,7 @@ def decode_with_tiling(
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
@@ -270,7 +327,7 @@ def decode_with_tiling(
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
if hasattr(decode_with_tiling, "_emitted_frames"):
del decode_with_tiling._emitted_frames
# Clean up weights
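
Reviewer note: two ideas in `decode_with_tiling` are easy to check in isolation. The first is the latent-to-output frame mapping, which is `1 + (T - 1) * scale` for causal decoders and `T * scale` otherwise. The second is the blend step, which accumulates `tile * mask` and `mask` separately and divides at the end. The following is a minimal 1D sketch in plain Python; `output_frames` and `blend_tiles_1d` are hypothetical helper names, not functions from the repository.

```python
def output_frames(f_latent: int, temporal_scale: int, causal_temporal: bool) -> int:
    """Number of decoded frames for f_latent latent frames."""
    if causal_temporal:
        # The first latent frame decodes to a single output frame.
        return 1 + (f_latent - 1) * temporal_scale
    return f_latent * temporal_scale


def blend_tiles_1d(tiles, length):
    """Blend overlapping 1D tiles given as (start, values, mask) triples."""
    output = [0.0] * length
    weights = [0.0] * length
    for start, values, mask in tiles:
        for i, (v, m) in enumerate(zip(values, mask)):
            output[start + i] += v * m  # weighted tile contribution
            weights[start + i] += m     # total blend weight at this position
    # Normalise by the accumulated weight, guarding against division by zero.
    return [o / max(w, 1e-8) for o, w in zip(output, weights)]


# Two tiles overlapping on positions 2 and 3, with linear ramps in the overlap.
blended = blend_tiles_1d(
    [
        (0, [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.75, 0.25]),
        (2, [3.0, 3.0, 3.0, 3.0], [0.25, 0.75, 1.0, 1.0]),
    ],
    length=6,
)
```

Because the ramps sum to one inside the overlap, the blended values move smoothly from the first tile's value to the second's.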


@@ -25,9 +25,7 @@ class WanAttentionBlock(nn.Module):
# Cross-attention (with optional norm on context)
self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else None
WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
)
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
@@ -36,7 +34,9 @@ class WanAttentionBlock(nn.Module):
self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__(
self,
@@ -67,7 +67,14 @@ class WanAttentionBlock(nn.Module):
# Self-attention with modulation (hidden state stays in w_dtype)
x_mod = self.norm1(x) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
y = self.self_attn(
x_mod,
seq_lens,
grid_sizes,
freqs,
rope_cos_sin=rope_cos_sin,
attn_mask=attn_mask,
)
x = x + y * e2
# Cross-attention (no modulation, just norm)


@@ -6,7 +6,12 @@ import mlx.core as mx
import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
def load_wan_model(
model_path: Path,
config,
quantization: dict | None = None,
loras: list | None = None,
):
"""Load and initialize WanModel, with optional quantization and LoRA support.
Args:
@@ -16,12 +21,12 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None, l
If provided, creates QuantizedLinear stubs before loading.
loras: Optional list of (lora_path, strength) tuples to apply.
"""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config)
if quantization:
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
nn.quantize(
model,
@@ -37,7 +42,7 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None, l
if quantization:
# Dequantize LoRA-targeted layers, merge delta, replace with bf16 Linear.
# Non-LoRA layers stay 4-bit. Zero per-step overhead.
from mlx_video.convert_wan import _load_lora_configs
from mlx_video.models.wan_2.convert import _load_lora_configs
from mlx_video.lora import apply_loras_to_model
model.load_weights(list(weights.items()), strict=False)
@@ -48,7 +53,7 @@ def load_wan_model(model_path: Path, config, quantization: dict | None = None, l
return model
else:
# Weight merging: fold LoRA into bf16 weights before loading
from mlx_video.convert_wan import load_and_apply_loras
from mlx_video.models.wan_2.convert import load_and_apply_loras
weights = load_and_apply_loras(dict(weights), loras)
@@ -64,7 +69,7 @@ def load_t5_encoder(model_path: Path, config):
only runs once per generation, so performance impact is negligible.
This matches the official implementation, which computes softmax in float32 explicitly.
"""
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=config.t5_vocab_size,
@@ -92,10 +97,12 @@ def load_vae_decoder(model_path: Path, config=None):
is_wan22 = config is not None and config.vae_z_dim == 48
if is_wan22:
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48)
else:
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path))
@@ -113,11 +120,11 @@ def load_vae_encoder(model_path: Path, config=None):
For Wan2.1/I2V-14B (vae_z_dim=16), uses WanVAE with encoder=True.
"""
if config is not None and config.vae_z_dim == 16:
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
else:
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
vae = Wan22VAEEncoder(z_dim=config.vae_z_dim if config else 48)
@@ -140,6 +147,7 @@ def _clean_text(text: str) -> str:
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass
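
Reviewer note: the non-quantized branch of `load_wan_model` folds LoRA deltas into the base weights before loading, so inference pays no per-step cost. The sketch below shows the general idea on nested Python lists, assuming the common low-rank form `W' = W + strength * (B @ A)`. The function name, the argument layout, and the absence of any alpha/rank scaling are assumptions; the actual `load_and_apply_loras` may differ.

```python
def merge_lora(weight, lora_a, lora_b, strength=1.0):
    """Return W + strength * (B @ A) for list-of-lists matrices.

    weight: [out, in], lora_a: [rank, in], lora_b: [out, rank].
    """
    rank = len(lora_a)
    return [
        [
            # Base weight plus the scaled low-rank update for entry (o, i).
            w + strength * sum(lora_b[o][r] * lora_a[r][i] for r in range(rank))
            for i, w in enumerate(row)
        ]
        for o, row in enumerate(weight)
    ]


# Rank-1 update applied to a 2x2 identity weight.
merged = merge_lora(
    weight=[[1.0, 0.0], [0.0, 1.0]],
    lora_a=[[1.0, 2.0]],
    lora_b=[[0.5], [1.0]],
    strength=2.0,
)
```

After merging, the layer is an ordinary dense matrix again, which is why the merged path needs no extra work at generation time.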


@@ -6,19 +6,45 @@ so weights load directly without key sanitization.
import mlx.core as mx
import mlx.nn as nn
import numpy as np
CACHE_T = 2
# Per-channel normalization statistics for z_dim=16
VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
]
VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
]
@@ -50,7 +76,9 @@ class CausalConv3d(nn.Module):
self._pad_w = padding[2]
# MLX Conv3d: weight shape [O, D, H, W, I]
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
self.weight = mx.zeros(
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
)
self.bias = mx.zeros((out_channels,))
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
@@ -67,8 +95,16 @@ class CausalConv3d(nn.Module):
x = mx.concatenate([pad_t, x], axis=2)
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
x = mx.pad(
x,
[
(0, 0),
(0, 0),
(0, 0),
(self._pad_h, self._pad_h),
(self._pad_w, self._pad_w),
],
)
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
out = self._conv3d(x)
@@ -118,7 +154,11 @@ class RMS_norm(nn.Module):
def __call__(self, x: mx.array) -> mx.array:
norm_dim = 1 if self.channel_first else -1
# L2 normalize along channel dim (matches F.normalize)
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
norm = mx.sqrt(
mx.clip(
mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None
)
)
return (x / norm) * self.scale * self.gamma
@@ -133,12 +173,12 @@ class ResidualBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.residual = [
RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU
RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU
None, # [5] Dropout
RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU
None, # [5] Dropout
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
]
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
@@ -226,13 +266,16 @@ class Resample(nn.Module):
# resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
if mode == "upsample3d":
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
)
else:
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
if mode == "downsample3d":
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, C, T, H, W]"""
@@ -272,8 +315,7 @@ class Resample(nn.Module):
else:
# Subsequent chunks: use cached frame as temporal context
cache_x = x[:, :, -1:]
x = self.time_conv(
x, cache_x=feat_cache[idx][:, :, -1:])
x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
@@ -328,8 +370,8 @@ class Decoder3d(nn.Module):
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU
RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
]
@@ -405,8 +447,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -431,8 +472,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.head[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -549,7 +589,7 @@ class WanVAE(nn.Module):
Returns:
Video [B, 3, T_out, H_out, W_out] clamped to [-1, 1]
"""
from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling
from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()
@@ -583,7 +623,7 @@ class WanVAE(nn.Module):
decoder_fn=tile_decode,
latents=z_denorm,
tiling_config=tiling_config,
spatial_scale=8, # three 2× spatial upsamples = 8×
temporal_scale=4, # two 2× temporal upsamples = 4×
spatial_scale=8,  # three 2× spatial upsamples = 8×
temporal_scale=4,  # two 2× temporal upsamples = 4×
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
)


@@ -8,7 +8,6 @@ conversion (channels-first → channels-last) is needed.
"""
import logging
import math
import mlx.core as mx
import mlx.nn as nn
@@ -19,23 +18,111 @@ logger = logging.getLogger(__name__)
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
VAE22_MEAN = mx.array([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
])
VAE22_MEAN = mx.array(
[
-0.2289,
-0.0052,
-0.1323,
-0.2339,
-0.2799,
0.0174,
0.1838,
0.1557,
-0.1382,
0.0542,
0.2813,
0.0891,
0.1570,
-0.0098,
0.0375,
-0.1825,
-0.2246,
-0.1207,
-0.0698,
0.5109,
0.2665,
-0.2108,
-0.2158,
0.2502,
-0.2055,
-0.0322,
0.1109,
0.1567,
-0.0729,
0.0899,
-0.2799,
-0.1230,
-0.0313,
-0.1649,
0.0117,
0.0723,
-0.2839,
-0.2083,
-0.0520,
0.3748,
0.0152,
0.1957,
0.1433,
-0.2944,
0.3573,
-0.0548,
-0.1681,
-0.0667,
]
)
VAE22_STD = mx.array([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
])
VAE22_STD = mx.array(
[
0.4765,
1.0364,
0.4514,
1.1677,
0.5313,
0.4990,
0.4818,
0.5013,
0.8158,
1.0344,
0.5894,
1.0901,
0.6885,
0.6165,
0.8454,
0.4978,
0.5759,
0.3523,
0.7135,
0.6804,
0.5833,
1.4146,
0.8986,
0.5659,
0.7069,
0.5338,
0.4889,
0.4917,
0.4069,
0.4999,
0.6866,
0.4093,
0.5709,
0.6065,
0.6415,
0.4944,
0.5726,
1.2042,
0.5458,
1.6887,
0.3971,
1.0600,
0.3943,
0.5537,
0.5444,
0.4089,
0.7468,
0.7744,
]
)
class CausalConv3d(nn.Module):
@@ -65,9 +152,9 @@ class CausalConv3d(nn.Module):
self._pad_w = padding[2]
# Weight: [O, D, H, W, I] for MLX
self.weight = mx.zeros((
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
))
self.weight = mx.zeros(
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
)
self.bias = mx.zeros((out_channels,))
def __call__(self, x, cache_x=None):
@@ -96,8 +183,16 @@ class CausalConv3d(nn.Module):
# Spatial padding
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
(self._pad_w, self._pad_w), (0, 0)])
x = mx.pad(
x,
[
(0, 0),
(0, 0),
(self._pad_h, self._pad_h),
(self._pad_w, self._pad_w),
(0, 0),
],
)
T_padded = x.shape[1]
H_padded, W_padded = x.shape[2], x.shape[3]
@@ -113,8 +208,9 @@ class CausalConv3d(nn.Module):
for d in range(kd):
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
conv_out = mx.conv_general(frame, w2d,
stride=(self.stride[1], self.stride[2]))
conv_out = mx.conv_general(
frame, w2d, stride=(self.stride[1], self.stride[2])
)
accum = conv_out if accum is None else accum + conv_out
outputs.append(accum + self.bias)
@@ -126,7 +222,7 @@ class RMS_norm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.scale = dim**0.5
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
self.gamma = mx.ones((dim,))
@@ -134,7 +230,9 @@ class RMS_norm(nn.Module):
# x: [..., C] (channels-last)
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
return (
x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
)
class ResidualBlock(nn.Module):
@@ -145,11 +243,7 @@ class ResidualBlock(nn.Module):
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
# We store as named layers matching PyTorch's indices
self.residual = ResidualBlockLayers(in_dim, out_dim)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim
else None
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
def __call__(self, x, feat_cache=None, feat_idx=None):
h = self.shortcut(x) if self.shortcut is not None else x
@@ -182,9 +276,7 @@ class ResidualBlockLayers(nn.Module):
# Save last CACHE_T frames before conv (for next chunk's context)
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
out = conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -231,7 +323,9 @@ class AttentionBlock(nn.Module):
x = self.norm(x)
# QKV via 1x1 conv2d (equivalent to linear on last dim)
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C]
qkv = (
mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias
) # [BT, H, W, 3C]
qkv = qkv.reshape(B * T, H * W, 3 * C)
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
@@ -240,8 +334,10 @@ class AttentionBlock(nn.Module):
k = k[:, None, :, :]
v = v[:, None, :, :]
scale = C ** -0.5
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C]
scale = C**-0.5
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale
) # [BT, 1, HW, C]
out = out.squeeze(1).reshape(B * T, H, W, C)
# Project output
@@ -270,16 +366,24 @@ class DupUp3D(nn.Module):
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
x = x.reshape(
B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s
)
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
x = x.reshape(
B,
T * self.factor_t,
H * self.factor_s,
W * self.factor_s,
self.out_channels,
)
if first_chunk:
x = x[:, self.factor_t - 1:, :, :, :]
x = x[:, self.factor_t - 1 :, :, :, :]
return x
@@ -348,7 +452,9 @@ class Resample(nn.Module):
self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else:
raise ValueError(f"Unsupported mode: {mode}")
@@ -369,7 +475,9 @@ class Resample(nn.Module):
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
return (
mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
)
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
# x: [B, T, H, W, C]
@@ -444,14 +552,17 @@ class Resample(nn.Module):
class Up_ResidualBlock(nn.Module):
"""Upsampling residual block with optional DupUp3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
def __init__(
self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False
):
super().__init__()
self.up_flag = up_flag
# DupUp3D shortcut (no learnable params)
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim, out_dim,
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
@@ -490,13 +601,21 @@ class Up_ResidualBlock(nn.Module):
class Down_ResidualBlock(nn.Module):
"""Downsampling residual block with AvgDown3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
def __init__(
self,
in_dim,
out_dim,
num_res_blocks,
temperal_downsample=False,
down_flag=False,
):
super().__init__()
self.down_flag = down_flag
# AvgDown3D shortcut (no learnable params, always present)
self.avg_shortcut = AvgDown3D(
in_dim, out_dim,
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
@@ -562,13 +681,15 @@ class Decoder3d(nn.Module):
self.upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
self.upsamples.append(Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks + 1,
temperal_upsample=t_up,
up_flag=(i != len(dim_mult) - 1),
))
self.upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks + 1,
temperal_upsample=t_up,
up_flag=(i != len(dim_mult) - 1),
)
)
# Output head: [RMS_norm, SiLU, CausalConv3d]
self.head = Head22(dims[-1])
@@ -612,13 +733,15 @@ class Encoder3d(nn.Module):
for i in range(len(dim_mult)):
in_d, out_d = dims[i], dims[i + 1]
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
self.downsamples.append(Down_ResidualBlock(
in_dim=in_d,
out_dim=out_d,
num_res_blocks=num_res_blocks,
temperal_downsample=t_down,
down_flag=(i < len(dim_mult) - 1),
))
self.downsamples.append(
Down_ResidualBlock(
in_dim=in_d,
out_dim=out_d,
num_res_blocks=num_res_blocks,
temperal_downsample=t_down,
down_flag=(i < len(dim_mult) - 1),
)
)
# Middle blocks (same as decoder)
out_dim = dims[-1]
@@ -658,9 +781,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -700,9 +821,7 @@ class Head22(nn.Module):
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
x = self.layer_2(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -768,7 +887,7 @@ class Wan22VAEEncoder(nn.Module):
if i == 0:
chunk = x[:, :1]
else:
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None:
out = chunk_out
@@ -778,7 +897,7 @@ class Wan22VAEEncoder(nn.Module):
# conv1 (pointwise) + split into mu, log_var
out = self.conv1(out)
mu = out[:, :, :, :, :self.z_dim]
mu = out[:, :, :, :, : self.z_dim]
# Normalize
mu = normalize_latents(mu)
@@ -847,7 +966,7 @@ class Wan22VAEDecoder(nn.Module):
Returns:
video: [B, T', H', W', 3] decoded RGB in [-1, 1]
"""
from mlx_video.models.wan.tiling import TilingConfig, decode_with_tiling
from mlx_video.models.wan_2.tiling import TilingConfig, decode_with_tiling
if tiling_config is None:
tiling_config = TilingConfig.default()
@@ -885,8 +1004,8 @@ class Wan22VAEDecoder(nn.Module):
decoder_fn=tile_decode,
latents=z_cf,
tiling_config=tiling_config,
spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
causal_temporal=True,
)
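
Reviewer note: the `RMS_norm` classes in both VAE files divide by the L2 norm and then multiply by `dim ** 0.5` and `gamma`, as the in-code comment about `F.normalize` says. A short plain-Python sketch of that computation on a single channel vector follows. The function name and the default `eps` are illustrative assumptions.

```python
import math


def l2_scaled_norm(x, gamma, eps=1e-12):
    """Compute x / max(||x||_2, eps) * sqrt(dim) * gamma for one vector."""
    dim = len(x)
    l2 = math.sqrt(sum(v * v for v in x))
    denom = max(l2, eps)    # avoids dividing by zero for an all-zero vector
    scale = math.sqrt(dim)  # corresponds to dim ** 0.5 in the diff
    return [v / denom * scale * g for v, g in zip(x, gamma)]


normed = l2_scaled_norm([3.0, 4.0], gamma=[1.0, 1.0])
```

With unit `gamma`, the result has root-mean-square 1, because `sqrt(dim) / ||x||_2` equals `1 / sqrt(mean(x ** 2))`. That explains the class name even though the implementation is written as an L2 normalisation.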


@@ -1,4 +1,5 @@
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
@@ -37,7 +38,9 @@ class Head(nn.Module):
proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
"""
@@ -111,20 +114,23 @@ class WanModel(nn.Module):
# Reference computes three rope_params with different dim normalizations
# so each axis (temporal/height/width) gets its own full frequency range.
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
self.freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Precompute sinusoidal inv_freq for time embedding.
half = config.freq_dim // 2
self._inv_freq = mx.array(
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
).astype(np.float32)
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
np.float32
)
)
def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings.
@@ -297,12 +303,19 @@ class WanModel(nn.Module):
seq_lens_list.append(p.shape[1])
x = mx.concatenate(
[
mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
(
mx.concatenate(
[
p,
mx.zeros(
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
),
],
axis=1,
)
if p.shape[1] < seq_len
else p
)
if p.shape[1] < seq_len
else p
for p in patches
],
axis=0,
@@ -315,9 +328,7 @@ class WanModel(nn.Module):
t = t[None]
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
sin_emb = mx.concatenate(
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
)
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]
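
Reviewer note: the time embedding in `WanModel` precomputes `inv_freq = 10000 ** (-arange(half) / half)` and concatenates cosines and sines of `t * inv_freq`. The sketch below mirrors that for a single scalar timestep in plain Python. The function name and the default `freq_dim` are assumptions for illustration only.

```python
import math


def timestep_embedding(t: float, freq_dim: int = 8):
    """Sinusoidal embedding [cos(t * f_0..), sin(t * f_0..)] for one timestep."""
    half = freq_dim // 2
    # Geometric frequency ladder from 1 down towards 1/10000.
    inv_freq = [10000.0 ** (-i / half) for i in range(half)]
    sinusoid = [t * f for f in inv_freq]
    # Cosines first, then sines, matching the concatenation order in the diff.
    return [math.cos(s) for s in sinusoid] + [math.sin(s) for s in sinusoid]


emb_zero = timestep_embedding(0.0)
emb_one = timestep_embedding(1.0)
```

At `t = 0` the cosine half is all ones and the sine half is all zeros, which is a convenient sanity check when porting the embedding.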


@@ -1,32 +0,0 @@
import mlx.core as mx
import mlx.nn as nn
class PixArtAlphaTextProjection(nn.Module):
def __init__(
self,
in_features: int,
hidden_size: int,
out_features: int | None = None,
bias: bool = True,
act_fn: str = "gelu_tanh",
):
super().__init__()
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
if act_fn == "gelu_tanh":
self.act = nn.GELU(approx="tanh")
elif act_fn == "silu":
self.act = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
x = self.linear1(x)
x = self.act(x)
x = self.linear2(x)
return x


@@ -1,29 +1,36 @@
import math
from typing import Optional, Tuple, Union
from functools import partial
from pathlib import Path
from typing import Optional, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from functools import partial
from pathlib import Path
from huggingface_hub import snapshot_download
from PIL import Image
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
if Path(model_repo).exists():
return Path(model_repo)
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
return Path(
snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
)
)
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
if quantization is not None:
def get_class_predicate(p, m):
# Handle custom per layer quantizations
if p in quantization:
@@ -44,23 +51,24 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
class_predicate=get_class_predicate,
)
@partial(mx.compile, shapeless=True)
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
@partial(mx.compile, shapeless=True)
def to_denoised(
noisy: mx.array,
velocity: mx.array,
sigma: mx.array | float
noisy: mx.array, velocity: mx.array, sigma: mx.array | float
) -> mx.array:
"""Convert velocity prediction to denoised output.
Given noisy input x_t and velocity prediction v, compute denoised x_0:
x_0 = x_t - sigma * v
Uses float32 for computation precision (matching PyTorch behavior),
then converts back to input dtype.
Args:
noisy: Noisy input tensor x_t
velocity: Velocity prediction v
@@ -69,16 +77,21 @@ def to_denoised(
Returns:
Denoised tensor x_0
"""
original_dtype = noisy.dtype
# Cast to float32 for precision (PyTorch uses calc_dtype=torch.float32)
noisy_f32 = noisy.astype(mx.float32)
velocity_f32 = velocity.astype(mx.float32)
if isinstance(sigma, (int, float)):
# Convert to array with matching dtype to avoid float32 promotion
sigma_arr = mx.array(sigma, dtype=velocity.dtype)
return noisy - sigma_arr * velocity
sigma_f32 = mx.array(sigma, dtype=mx.float32)
else:
# sigma is per-sample - ensure dtype matches
sigma = sigma.astype(velocity.dtype)
while sigma.ndim < velocity.ndim:
sigma = mx.expand_dims(sigma, axis=-1)
return noisy - sigma * velocity
sigma_f32 = sigma.astype(mx.float32)
while sigma_f32.ndim < velocity_f32.ndim:
sigma_f32 = mx.expand_dims(sigma_f32, axis=-1)
result = noisy_f32 - sigma_f32 * velocity_f32
return result.astype(original_dtype)
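The float32 upcast in `to_denoised` can be illustrated outside MLX. The following is a minimal NumPy sketch, not the project's implementation: `to_denoised_np` is a hypothetical name, and float16 stands in for the low-precision input dtype because NumPy has no native bfloat16.

```python
import numpy as np


def to_denoised_np(noisy, velocity, sigma):
    """Illustrative NumPy analogue: x_0 = x_t - sigma * v, computed in float32."""
    original_dtype = noisy.dtype
    noisy_f32 = noisy.astype(np.float32)
    velocity_f32 = velocity.astype(np.float32)
    sigma_f32 = np.asarray(sigma, dtype=np.float32)
    # A per-sample sigma of shape (B,) gets trailing axes so it broadcasts
    # against (B, T, C); a scalar sigma (ndim == 0) broadcasts as-is.
    while 0 < sigma_f32.ndim < velocity_f32.ndim:
        sigma_f32 = np.expand_dims(sigma_f32, axis=-1)
    result = noisy_f32 - sigma_f32 * velocity_f32
    # Cast back so callers get the dtype they passed in.
    return result.astype(original_dtype)


noisy = np.ones((2, 4, 8), dtype=np.float16)
velocity = np.full((2, 4, 8), 0.5, dtype=np.float16)
per_sample_sigma = np.array([1.0, 0.5], dtype=np.float16)
denoised = to_denoised_np(noisy, velocity, per_sample_sigma)
```

The subtraction runs in float32 and only the final result is rounded back to the input dtype, which is the behavior the new docstring describes.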
def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array:
@@ -274,7 +287,9 @@ def prepare_image_for_encoding(
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
pil_image = Image.fromarray(image_np)
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)
pil_image = pil_image.resize(
(target_width, target_height), Image.Resampling.LANCZOS
)
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
# Normalize to [-1, 1]


@@ -1 +1 @@
__version__ = "0.0.1"
__version__ = "0.0.1"


@@ -20,6 +20,8 @@ dependencies = [
"opencv-python>=4.12.0.88",
"Pillow>=10.3.0",
"mlx-vlm",
"rich>=14.2.0",
"librosa>=0.10.0",
"imageio>=2.37.2",
"imageio-ffmpeg>=0.6.0",
"ftfy",
@@ -44,8 +46,8 @@ Repository = "https://github.com/Blaizzy/mlx-video"
Issues = "https://github.com/Blaizzy/mlx-video/issues"
[project.scripts]
"mlx_video.generate" = "mlx_video.generate:main"
"mlx_video.generate_wan" = "mlx_video.generate_wan:main"
"mlx_video.ltx_2.generate" = "mlx_video.models.ltx_2.generate:main"
"mlx_video.wan_2.generate" = "mlx_video.models.wan_2.generate:main"
[tool.setuptools.packages.find]
include = ["mlx_video*"]
@@ -57,3 +59,4 @@ version = {attr = "mlx_video.version.__version__"}
dev = [
"pytest",
]


@@ -170,19 +170,33 @@ def print_report(results, ref_path, test_path):
print("AGGREGATE METRICS")
print("-" * 40)
print(f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}")
print(f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}")
print(f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}")
print(f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}")
print(f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}")
print(
f" PSNR (dB): mean={np.mean(psnr):6.2f} min={np.min(psnr):6.2f} max={np.max(psnr):6.2f}"
)
print(
f" SSIM: mean={np.mean(ssim):.4f} min={np.min(ssim):.4f} max={np.max(ssim):.4f}"
)
print(
f" Mean diff: mean={np.mean(md):6.2f} min={np.min(md):6.2f} max={np.max(md):6.2f}"
)
print(
f" Max diff: mean={np.mean(mx):6.1f} min={np.min(mx):6.1f} max={np.max(mx):6.1f}"
)
print(
f" Color dist: mean={np.mean(cd):.4f} min={np.min(cd):.4f} max={np.max(cd):.4f}"
)
print()
print("TEMPORAL COHERENCE (mean frame-to-frame diff, lower = smoother)")
print("-" * 40)
print(f" Reference: {results['ref_temporal_coherence']:.2f}")
print(f" Test: {results['test_temporal_coherence']:.2f}")
ratio = results["test_temporal_coherence"] / (results["ref_temporal_coherence"] + 1e-10)
print(f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}")
ratio = results["test_temporal_coherence"] / (
results["ref_temporal_coherence"] + 1e-10
)
print(
f" Ratio: {ratio:.2f}x {'(test is smoother)' if ratio < 1 else '(test is jerkier)' if ratio > 1.05 else '(similar)'}"
)
print()
# Identify worst frames
@@ -190,7 +204,9 @@ def print_report(results, ref_path, test_path):
print("-" * 40)
worst_idx = np.argsort(psnr)[:5]
for i in worst_idx:
print(f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}")
print(
f" Frame {i:4d}: PSNR={psnr[i]:6.2f} dB SSIM={ssim[i]:.4f} mean_diff={md[i]:.2f}"
)
print()
# Quality assessment
@@ -210,7 +226,9 @@ def print_report(results, ref_path, test_path):
grade = "Very different"
print(f" Overall: {grade} (PSNR={mean_psnr:.1f} dB, SSIM={mean_ssim:.4f})")
if mean_psnr < 30:
print(" ⚠ Videos differ significantly — likely a bug or different generation seed")
print(
" ⚠ Videos differ significantly — likely a bug or different generation seed"
)
print("=" * 72)
@@ -242,9 +260,7 @@ def main():
parser.add_argument(
"--diff-video", help="Save side-by-side diff visualization to this path"
)
parser.add_argument(
"--max-frames", type=int, help="Compare only first N frames"
)
parser.add_argument("--max-frames", type=int, help="Compare only first N frames")
parser.add_argument(
"--ssim-win", type=int, default=7, help="SSIM window size (default: 7)"
)
@@ -254,26 +270,29 @@ def main():
default=5.0,
help="Diff heatmap amplification (default: 5.0)",
)
parser.add_argument(
"--csv", help="Export per-frame metrics to CSV file"
)
parser.add_argument("--csv", help="Export per-frame metrics to CSV file")
args = parser.parse_args()
print(f"Loading reference: {args.reference}")
ref_frames, ref_fps = load_video(args.reference, args.max_frames)
print(f"{len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}")
print(
f"{len(ref_frames)} frames, {ref_fps:.1f} fps, {ref_frames[0].shape[1]}x{ref_frames[0].shape[0]}"
)
print(f"Loading test: {args.test}")
test_frames, test_fps = load_video(args.test, args.max_frames)
print(f"{len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}")
print(
f"{len(test_frames)} frames, {test_fps:.1f} fps, {test_frames[0].shape[1]}x{test_frames[0].shape[0]}"
)
if ref_frames[0].shape != test_frames[0].shape:
print(f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}")
print(
f"Warning: resolution mismatch {ref_frames[0].shape} vs {test_frames[0].shape}"
)
print("Resizing test frames to match reference...")
h, w = ref_frames[0].shape[:2]
test_frames = [
cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4)
for f in test_frames
cv2.resize(f, (w, h), interpolation=cv2.INTER_LANCZOS4) for f in test_frames
]
print("Computing metrics...")
@@ -282,23 +301,29 @@ def main():
print_report(results, args.reference, args.test)
if args.diff_video:
save_diff_video(ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale)
save_diff_video(
ref_frames, test_frames, args.diff_video, ref_fps, args.diff_scale
)
if args.csv:
import csv
with open(args.csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"])
writer.writerow(
["frame", "psnr", "ssim", "mean_diff", "max_diff", "color_dist"]
)
for i in range(results["num_frames"]):
writer.writerow([
i,
f"{results['psnr'][i]:.4f}",
f"{results['ssim'][i]:.6f}",
f"{results['mean_diff'][i]:.4f}",
f"{results['max_diff'][i]:.1f}",
f"{results['color_dist'][i]:.6f}",
])
writer.writerow(
[
i,
f"{results['psnr'][i]:.4f}",
f"{results['ssim'][i]:.6f}",
f"{results['mean_diff'][i]:.4f}",
f"{results['max_diff'][i]:.1f}",
f"{results['color_dist'][i]:.6f}",
]
)
print(f"Per-frame metrics saved to {args.csv}")


@@ -158,10 +158,14 @@ def analyze_video(frames, chunk_size=None, compute_flow=False):
boundary_metrics = []
for b in boundaries:
if b < n and b > 0:
pre = metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1]
pre = (
metrics["frame_diff"][b - 1] if b > 1 else metrics["frame_diff"][1]
)
at = metrics["frame_diff"][b]
ratio = at / (pre + 1e-10)
brightness_jump = metrics["brightness"][b] - metrics["brightness"][b - 1]
brightness_jump = (
metrics["brightness"][b] - metrics["brightness"][b - 1]
)
contrast_jump = (
(metrics["contrast"][b] - metrics["contrast"][b - 1])
/ (metrics["contrast"][b - 1] + 1e-10)
@@ -198,7 +202,9 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
print("VIDEO QUALITY REPORT")
print("=" * 72)
print(f" File: {path}")
print(f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}")
print(
f" Total frames: {total_frames} Analyzed: {frames_analyzed} FPS: {fps:.1f}"
)
duration = total_frames / fps if fps > 0 else 0
print(f" Duration: {duration:.1f}s")
print()
@@ -211,52 +217,76 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
print("-" * 40)
if n_uniform:
frames_list = np.where(metrics["is_uniform"])[0][:10]
print(f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}")
print(
f" Uniform/blank frames: {n_uniform} — frames {list(frames_list)}{'...' if n_uniform > 10 else ''}"
)
if n_noisy:
frames_list = np.where(metrics["is_noisy"])[0][:10]
print(f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}")
print(
f" Noisy frames: {n_noisy} — frames {list(frames_list)}{'...' if n_noisy > 10 else ''}"
)
print()
print("SHARPNESS")
print("-" * 40)
print(f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}")
print(f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}")
print(
f" Laplacian var: mean={np.mean(sl):8.1f} min={np.min(sl):8.1f} max={np.max(sl):8.1f} std={np.std(sl):.1f}"
)
print(
f" Gradient mag: mean={np.mean(sg):8.2f} min={np.min(sg):8.2f} max={np.max(sg):8.2f} std={np.std(sg):.2f}"
)
if np.std(sl) / (np.mean(sl) + 1e-10) > 0.3:
print(" ⚠ High sharpness variation — possible blur artifacts")
print()
print("BRIGHTNESS & CONTRAST")
print("-" * 40)
print(f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}")
print(f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}")
print(
f" Brightness: mean={np.mean(br):6.1f} min={np.min(br):6.1f} max={np.max(br):6.1f} std={np.std(br):.2f}"
)
print(
f" Contrast (std): mean={np.mean(ct):6.1f} min={np.min(ct):6.1f} max={np.max(ct):6.1f} std={np.std(ct):.2f}"
)
if np.std(br) > 3.0:
print(" ⚠ Brightness instability — may indicate chunk boundary artifacts")
print()
print("COLOR DISTRIBUTION (BGR)")
print("-" * 40)
print(f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}")
print(f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}")
print(f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}")
print(
f" Blue: mean={np.mean(metrics['color_mean_b']):6.1f} std={np.std(metrics['color_mean_b']):.2f}"
)
print(
f" Green: mean={np.mean(metrics['color_mean_g']):6.1f} std={np.std(metrics['color_mean_g']):.2f}"
)
print(
f" Red: mean={np.mean(metrics['color_mean_r']):6.1f} std={np.std(metrics['color_mean_r']):.2f}"
)
print()
print("TEMPORAL STABILITY")
print("-" * 40)
fd_nz = fd[1:] # skip first frame (always 0)
if len(fd_nz) > 0:
print(f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}")
print(
f" Frame diff: mean={np.mean(fd_nz):6.2f} min={np.min(fd_nz):6.2f} max={np.max(fd_nz):6.2f} std={np.std(fd_nz):.2f}"
)
if np.std(fd_nz) / (np.mean(fd_nz) + 1e-10) > 0.5:
print(" ⚠ High diff variance — jitter or discontinuities")
if "flow_mean" in metrics:
fm = metrics["flow_mean"][1:]
print(f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}")
print(
f" Optical flow: mean={np.mean(fm):6.2f} max_frame={np.max(metrics['flow_max'][1:]):.1f}"
)
print()
# Chunk boundaries
if "boundaries" in metrics and metrics["boundaries"]:
print("CHUNK BOUNDARIES")
print("-" * 40)
print(f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}")
print(
f" {'Frame':>6} {'Diff ratio':>10} {'Brightness':>10} {'Contrast %':>10} {'Sharpness %':>11}"
)
for bm in metrics["boundaries"]:
print(
f" {bm['frame']:6d}"
@@ -267,7 +297,9 @@ def print_report(metrics, path, fps, total_frames, frames_analyzed):
)
avg_ratio = np.mean([b["diff_ratio"] for b in metrics["boundaries"]])
if avg_ratio > 2.0:
print(f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions")
print(
f" ⚠ Boundary diff ratio {avg_ratio:.1f}x — visible chunk transitions"
)
print()
# Overall grade
@@ -303,9 +335,7 @@ def main():
type=int,
help="Frames per chunk for boundary analysis (e.g., 32)",
)
parser.add_argument(
"--start", type=int, default=0, help="Start frame (default: 0)"
)
parser.add_argument("--start", type=int, default=0, help="Start frame (default: 0)")
parser.add_argument("--end", type=int, help="End frame (default: all)")
parser.add_argument(
"--flow",
@@ -329,8 +359,14 @@ def main():
import csv
keys = [
"sharpness_lap", "sharpness_grad", "brightness", "contrast",
"color_mean_b", "color_mean_g", "color_mean_r", "frame_diff",
"sharpness_lap",
"sharpness_grad",
"brightness",
"contrast",
"color_mean_b",
"color_mean_g",
"color_mean_r",
"frame_diff",
]
if args.flow:
keys += ["flow_mean", "flow_max"]

443
tests/test_generate_dev.py Normal file

@@ -0,0 +1,443 @@
"""Tests for LTX-2 dev model generation pipeline."""
import mlx.core as mx
import pytest
from mlx_video.generate_dev import (
AUDIO_LATENTS_PER_SECOND,
AUDIO_SAMPLE_RATE,
DEFAULT_NEGATIVE_PROMPT,
cfg_delta,
compute_audio_frames,
create_audio_position_grid,
create_position_grid,
ltx2_scheduler,
)
class TestLTX2Scheduler:
"""Tests for the LTX-2 sigma scheduler."""
def test_scheduler_output_shape(self):
"""Scheduler should return steps+1 sigma values."""
steps = 20
sigmas = ltx2_scheduler(steps=steps)
assert sigmas.shape == (
steps + 1,
), f"Expected ({steps + 1},), got {sigmas.shape}"
def test_scheduler_starts_at_one(self):
"""Sigma schedule should start at 1.0."""
sigmas = ltx2_scheduler(steps=20)
assert (
abs(sigmas[0].item() - 1.0) < 1e-5
), f"Expected 1.0, got {sigmas[0].item()}"
def test_scheduler_ends_at_zero(self):
"""Sigma schedule should end at 0.0."""
sigmas = ltx2_scheduler(steps=20)
assert abs(sigmas[-1].item()) < 1e-5, f"Expected 0.0, got {sigmas[-1].item()}"
def test_scheduler_monotonically_decreasing(self):
"""Sigma values should monotonically decrease."""
sigmas = ltx2_scheduler(steps=20)
sigmas_list = sigmas.tolist()
for i in range(len(sigmas_list) - 1):
assert (
sigmas_list[i] >= sigmas_list[i + 1]
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
def test_scheduler_dtype(self):
"""Scheduler should return float32 array."""
sigmas = ltx2_scheduler(steps=20)
assert sigmas.dtype == mx.float32, f"Expected float32, got {sigmas.dtype}"
def test_scheduler_with_num_tokens(self):
"""Scheduler should accept num_tokens parameter."""
sigmas_default = ltx2_scheduler(steps=20, num_tokens=None)
sigmas_custom = ltx2_scheduler(steps=20, num_tokens=1920)
# Both should be valid arrays
assert sigmas_default.shape == (21,)
assert sigmas_custom.shape == (21,)
def test_scheduler_no_stretch(self):
"""Scheduler without stretching should still work."""
sigmas = ltx2_scheduler(steps=20, stretch=False)
assert sigmas.shape == (21,)
assert sigmas[0].item() > 0
assert sigmas[-1].item() == 0.0
def test_scheduler_different_steps(self):
"""Scheduler should work with different step counts."""
for steps in [5, 10, 20, 40, 50]:
sigmas = ltx2_scheduler(steps=steps)
assert sigmas.shape == (steps + 1,), f"Failed for steps={steps}"
class TestCreatePositionGrid:
"""Tests for position grid creation."""
def test_position_grid_shape(self):
"""Position grid should have correct shape (B, 3, num_patches, 2)."""
batch_size = 1
num_frames = 5
height = 16
width = 24
positions = create_position_grid(batch_size, num_frames, height, width)
num_patches = num_frames * height * width
expected_shape = (batch_size, 3, num_patches, 2)
assert (
positions.shape == expected_shape
), f"Expected {expected_shape}, got {positions.shape}"
def test_position_grid_dtype(self):
"""Position grid should be float32 for RoPE precision."""
positions = create_position_grid(1, 5, 16, 24)
assert (
positions.dtype == mx.float32
), f"Expected float32 for RoPE precision, got {positions.dtype}"
def test_position_grid_batch_size(self):
"""Position grid should respect batch size."""
for batch_size in [1, 2, 4]:
positions = create_position_grid(batch_size, 5, 16, 24)
assert positions.shape[0] == batch_size
def test_position_grid_temporal_dimension(self):
"""Temporal dimension should have values scaled by fps."""
positions = create_position_grid(1, 5, 16, 24, fps=24.0)
temporal = positions[0, 0, :, :] # (num_patches, 2)
# Values should be in seconds (divided by fps)
max_temporal = mx.max(temporal).item()
# For 5 latent frames at scale 8, max pixel frame ~ 40, divided by 24 fps ~ 1.67s
assert max_temporal < 10, f"Temporal values too large: {max_temporal}"
def test_position_grid_spatial_dimensions(self):
"""Spatial dimensions should have pixel-space values."""
positions = create_position_grid(1, 5, 16, 24, spatial_scale=32)
# Height dimension
height_vals = positions[0, 1, :, :]
max_height = mx.max(height_vals).item()
# 16 latent * 32 scale = 512 pixels
assert max_height <= 512, f"Height values too large: {max_height}"
# Width dimension
width_vals = positions[0, 2, :, :]
max_width = mx.max(width_vals).item()
# 24 latent * 32 scale = 768 pixels
assert max_width <= 768, f"Width values too large: {max_width}"
def test_position_grid_causal_fix(self):
"""Causal fix should adjust first frame temporal values."""
positions_causal = create_position_grid(1, 5, 16, 24, causal_fix=True)
positions_no_causal = create_position_grid(1, 5, 16, 24, causal_fix=False)
# They should be different due to causal fix
diff = mx.abs(positions_causal - positions_no_causal)
assert mx.max(diff).item() > 0, "Causal fix should change position values"
def test_position_grid_no_nan_or_inf(self):
"""Position grid should not contain NaN or Inf values."""
positions = create_position_grid(1, 5, 16, 24)
assert not mx.any(mx.isnan(positions)).item(), "Position grid contains NaN"
assert not mx.any(mx.isinf(positions)).item(), "Position grid contains Inf"
class TestCFGDelta:
"""Tests for CFG (Classifier-Free Guidance) delta calculation."""
def test_cfg_delta_shape(self):
"""CFG delta should have same shape as inputs."""
shape = (1, 1920, 128)
cond = mx.random.normal(shape)
uncond = mx.random.normal(shape)
delta = cfg_delta(cond, uncond, scale=4.0)
assert delta.shape == shape
def test_cfg_delta_scale_one(self):
"""CFG with scale=1.0 should return zero delta."""
shape = (1, 1920, 128)
cond = mx.random.normal(shape)
uncond = mx.random.normal(shape)
mx.eval(cond, uncond)
delta = cfg_delta(cond, uncond, scale=1.0)
mx.eval(delta)
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
assert (
mx.max(mx.abs(delta)).item() < 1e-6
), "CFG delta with scale=1.0 should be zero"
def test_cfg_delta_formula(self):
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
cond = mx.array([[[1.0, 2.0, 3.0]]])
uncond = mx.array([[[0.5, 1.0, 1.5]]])
scale = 4.0
delta = cfg_delta(cond, uncond, scale)
expected = (scale - 1.0) * (cond - uncond)
mx.eval(delta, expected)
diff = mx.max(mx.abs(delta - expected)).item()
assert diff < 1e-6, f"CFG delta formula mismatch: diff={diff}"
def test_cfg_delta_dtype_preservation(self):
"""CFG delta should preserve input dtype."""
for dtype in [mx.float32, mx.bfloat16]:
cond = mx.random.normal((1, 100, 64)).astype(dtype)
uncond = mx.random.normal((1, 100, 64)).astype(dtype)
delta = cfg_delta(cond, uncond, scale=4.0)
assert delta.dtype == dtype, f"Expected {dtype}, got {delta.dtype}"
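The formula these tests pin down, `(scale - 1) * (cond - uncond)`, can be sketched in NumPy as below. `cfg_delta_np` is a hypothetical stand-in for the real `cfg_delta`, and adding the delta to the conditional prediction is the usual classifier-free guidance form; how the pipeline applies the delta is an assumption, since this diff does not show it.

```python
import numpy as np


def cfg_delta_np(cond, uncond, scale):
    """Illustrative CFG delta: zero at scale=1, grows linearly with scale."""
    return (scale - 1.0) * (cond - uncond)


cond = np.array([[[1.0, 2.0, 3.0]]], dtype=np.float32)
uncond = np.array([[[0.5, 1.0, 1.5]]], dtype=np.float32)

delta = cfg_delta_np(cond, uncond, scale=4.0)
# Assumed usage: cond + delta is algebraically uncond + scale * (cond - uncond).
guided = cond + delta
```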
class TestDefaultNegativePrompt:
"""Tests for the default negative prompt."""
def test_default_negative_prompt_exists(self):
"""Default negative prompt should be defined."""
assert DEFAULT_NEGATIVE_PROMPT is not None
assert len(DEFAULT_NEGATIVE_PROMPT) > 0
def test_default_negative_prompt_contains_quality_terms(self):
"""Default negative prompt should contain quality-related terms."""
prompt_lower = DEFAULT_NEGATIVE_PROMPT.lower()
# Check for common negative quality terms
assert "blurry" in prompt_lower, "Should contain 'blurry'"
assert (
"low quality" in prompt_lower or "low contrast" in prompt_lower
), "Should contain quality-related terms"
class TestInputValidation:
"""Tests for input validation in generate_video_dev."""
def test_height_divisible_by_32(self):
"""Height must be divisible by 32."""
# Validation itself happens inside generate_video_dev; this only checks the divisibility rule
valid_heights = [256, 384, 512, 640, 768]
invalid_heights = [100, 300, 500, 700]
for h in valid_heights:
assert h % 32 == 0, f"Height {h} should be valid"
for h in invalid_heights:
assert h % 32 != 0, f"Height {h} should be invalid"
def test_width_divisible_by_32(self):
"""Width must be divisible by 32."""
valid_widths = [256, 384, 512, 640, 768, 1024]
invalid_widths = [100, 300, 500, 700]
for w in valid_widths:
assert w % 32 == 0, f"Width {w} should be valid"
for w in invalid_widths:
assert w % 32 != 0, f"Width {w} should be invalid"
def test_num_frames_formula(self):
"""Number of frames should be 1 + 8*k."""
valid_frames = [1, 9, 17, 25, 33, 41, 49, 57, 65]
for f in valid_frames:
assert (f - 1) % 8 == 0, f"Frames {f} should be valid (1 + 8*k)"
def test_num_frames_adjustment(self):
"""Invalid frame counts should be adjusted to nearest valid value."""
# Test the adjustment logic
test_cases = [
(30, 33), # 30 -> nearest valid is 33
(35, 33), # 35 -> nearest valid is 33
(40, 41), # 40 -> nearest valid is 41
(1, 1), # 1 is already valid
(33, 33), # 33 is already valid
]
for input_frames, expected in test_cases:
if input_frames % 8 != 1:
adjusted = round((input_frames - 1) / 8) * 8 + 1
assert (
adjusted == expected
), f"Expected {expected} for input {input_frames}, got {adjusted}"
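The adjustment rule exercised above can be written as a small helper. This is a sketch under the assumption that the pipeline uses the same expression as the test; `snap_num_frames` and `check_resolution` are hypothetical names, not functions from the package.

```python
def snap_num_frames(num_frames: int) -> int:
    """Snap a frame count to the nearest value of the form 1 + 8*k."""
    if num_frames % 8 == 1:
        return num_frames  # already valid
    # Python's round() resolves exact ties to the even k (e.g. 5 -> 1, 13 -> 17).
    return round((num_frames - 1) / 8) * 8 + 1


def check_resolution(height: int, width: int, multiple: int = 32) -> None:
    """Raise if either spatial dimension is not a multiple of `multiple`."""
    for name, value in (("height", height), ("width", width)):
        if value % multiple != 0:
            raise ValueError(f"{name}={value} must be divisible by {multiple}")


snapped = [snap_num_frames(n) for n in (30, 35, 40, 1, 33)]
check_resolution(512, 768)  # passes silently
```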
class TestDenoiseWithCFGMocked:
"""Tests for denoise_with_cfg with mocked transformer."""
def test_sigmas_list_conversion(self):
"""Sigmas should be convertible to list."""
sigmas = ltx2_scheduler(steps=5)
sigmas_list = sigmas.tolist()
assert isinstance(sigmas_list, list)
assert len(sigmas_list) == 6 # steps + 1
class TestTilingDefault:
"""Tests for tiling default behavior."""
def test_tiling_default_is_none(self):
"""Default tiling should be 'none' for performance."""
import inspect
from mlx_video.generate_dev import generate_video_dev
sig = inspect.signature(generate_video_dev)
tiling_param = sig.parameters.get("tiling")
assert tiling_param is not None
assert (
tiling_param.default == "none"
), f"Expected default tiling='none', got '{tiling_param.default}'"
class TestLatentDimensions:
"""Tests for latent dimension calculations."""
def test_latent_height_calculation(self):
"""Latent height should be height // 32."""
test_cases = [(512, 16), (768, 24), (1024, 32)]
for height, expected_latent_h in test_cases:
latent_h = height // 32
assert (
latent_h == expected_latent_h
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
def test_latent_width_calculation(self):
"""Latent width should be width // 32."""
test_cases = [(512, 16), (768, 24), (1024, 32)]
for width, expected_latent_w in test_cases:
latent_w = width // 32
assert (
latent_w == expected_latent_w
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
def test_latent_frames_calculation(self):
"""Latent frames should be 1 + (num_frames - 1) // 8."""
test_cases = [(1, 1), (9, 2), (17, 3), (33, 5), (65, 9)]
for num_frames, expected_latent_f in test_cases:
latent_f = 1 + (num_frames - 1) // 8
assert (
latent_f == expected_latent_f
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
def test_num_tokens_calculation(self):
"""Number of tokens should be latent_f * latent_h * latent_w."""
# For 33 frames at 512x768
num_frames, height, width = 33, 512, 768
latent_f = 1 + (num_frames - 1) // 8 # 5
latent_h = height // 32 # 16
latent_w = width // 32 # 24
num_tokens = latent_f * latent_h * latent_w
expected = 5 * 16 * 24 # 1920
assert num_tokens == expected, f"Expected {expected} tokens, got {num_tokens}"
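The three latent-dimension formulas combine into one token count. A minimal sketch, using the temporal factor 8 and spatial factor 32 from the tests above; `latent_shape` is a hypothetical helper name.

```python
def latent_shape(num_frames: int, height: int, width: int) -> tuple[int, int, int]:
    """Return (latent_frames, latent_height, latent_width) for a pixel-space video."""
    latent_f = 1 + (num_frames - 1) // 8  # first frame kept, then 8x temporal compression
    latent_h = height // 32               # 32x spatial compression
    latent_w = width // 32
    return latent_f, latent_h, latent_w


latent_f, latent_h, latent_w = latent_shape(33, 512, 768)
num_tokens = latent_f * latent_h * latent_w  # one token per latent voxel
```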
class TestAudioPositionGrid:
"""Tests for audio position grid creation."""
def test_audio_position_grid_shape(self):
"""Audio position grid should have correct shape (B, 1, T, 2)."""
batch_size = 1
audio_frames = 34 # ~1.36 seconds at 25 latent frames/sec
positions = create_audio_position_grid(batch_size, audio_frames)
expected_shape = (batch_size, 1, audio_frames, 2)
assert (
positions.shape == expected_shape
), f"Expected {expected_shape}, got {positions.shape}"
def test_audio_position_grid_dtype(self):
"""Audio position grid should be float32."""
positions = create_audio_position_grid(1, 34)
assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
def test_audio_position_grid_batch_size(self):
"""Audio position grid should respect batch size."""
for batch_size in [1, 2, 4]:
positions = create_audio_position_grid(batch_size, 34)
assert positions.shape[0] == batch_size
def test_audio_position_grid_temporal_values(self):
"""Audio positions should be in seconds."""
positions = create_audio_position_grid(1, 34)
# Values should be in seconds (small values for ~1 second of audio)
max_val = mx.max(positions).item()
assert max_val < 10, f"Audio positions seem too large: {max_val}"
assert max_val > 0, "Audio positions should be positive"
def test_audio_position_grid_no_nan_or_inf(self):
"""Audio position grid should not contain NaN or Inf."""
positions = create_audio_position_grid(1, 34)
assert not mx.any(
mx.isnan(positions)
).item(), "Audio position grid contains NaN"
assert not mx.any(
mx.isinf(positions)
).item(), "Audio position grid contains Inf"
class TestComputeAudioFrames:
"""Tests for audio frame count calculation."""
def test_audio_frames_basic(self):
"""Audio frames should be proportional to video duration."""
# 33 frames at 24 fps = ~1.375 seconds
# At 25 latent frames/sec = ~34 audio frames
audio_frames = compute_audio_frames(33, 24.0)
assert audio_frames > 0
assert isinstance(audio_frames, int)
def test_audio_frames_scales_with_video(self):
"""More video frames should produce more audio frames."""
audio_33 = compute_audio_frames(33, 24.0)
audio_65 = compute_audio_frames(65, 24.0)
assert (
audio_65 > audio_33
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
def test_audio_frames_formula(self):
"""Audio frames should match expected formula."""
num_video_frames = 33
fps = 24.0
duration = num_video_frames / fps # ~1.375 seconds
expected = round(duration * AUDIO_LATENTS_PER_SECOND)
actual = compute_audio_frames(num_video_frames, fps)
assert actual == expected, f"Expected {expected}, got {actual}"
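The formula the last test checks can be worked through by hand. The sketch below assumes `compute_audio_frames` is exactly `round(duration * AUDIO_LATENTS_PER_SECOND)`, which is what the test expects but may not be the full implementation; `compute_audio_frames_sketch` is a hypothetical name.

```python
AUDIO_LATENTS_PER_SECOND = 25.0  # value asserted in TestAudioConstants


def compute_audio_frames_sketch(num_video_frames: int, fps: float) -> int:
    """Audio latent frames covering the same duration as the video."""
    duration = num_video_frames / fps  # seconds
    return round(duration * AUDIO_LATENTS_PER_SECOND)


# 33 frames at 24 fps: 1.375 s * 25 = 34.375, which rounds to 34.
audio_33 = compute_audio_frames_sketch(33, 24.0)
audio_65 = compute_audio_frames_sketch(65, 24.0)
```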
class TestAudioConstants:
"""Tests for audio constants."""
def test_audio_sample_rate(self):
"""Audio sample rate should be 24000 Hz."""
assert AUDIO_SAMPLE_RATE == 24000
def test_audio_latents_per_second(self):
"""Audio latents per second should be 25."""
assert AUDIO_LATENTS_PER_SECOND == 25.0
if __name__ == "__main__":
pytest.main([__file__, "-v"])


@@ -1,11 +1,9 @@
import pytest
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx.rope import (
precompute_freqs_cis,
)
from mlx_video.models.ltx.config import LTXRopeType
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
def create_video_position_grid(
@@ -20,7 +18,7 @@ def create_video_position_grid(
h_coords = np.arange(0, height)
w_coords = np.arange(0, width)
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij")
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
patch_ends = patch_starts + 1
@@ -36,6 +34,73 @@ def create_video_position_grid(
return mx.array(pixel_coords, dtype=dtype)
def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
"""Compute RoPE cos/sin using NumPy float64 as ground truth reference.
This mirrors the regular (non-double-precision) path in rope.py exactly,
but uses float64 throughout, so we can verify that the float32 MLX path
stays close to the true values.
"""
# positions_np: (B, 3, T, 2) in float64
n_pos_dims = positions_np.shape[1]
n_elem = 2 * n_pos_dims
# Middle-of-interval positions
mid = (positions_np[..., 0] + positions_np[..., 1]) / 2.0 # (B, 3, T)
# Frequency grid — matches generate_freq_grid() in rope.py:
# log_start = log(1)/log(theta) = 0
# log_end = log(theta)/log(theta) = 1
# pow_indices = theta^linspace(0, 1, num_indices) * pi/2
num_indices = dim // n_elem
if num_indices == 0:
num_indices = 1
lin_space = np.linspace(0.0, 1.0, num_indices, dtype=np.float64)
freq_indices = np.power(theta, lin_space) * (np.pi / 2) # (num_indices,)
# Fractional positions and scaling — matches generate_freqs()
# frac = pos / max_pos for each dim, then scale to [-1, 1]
frac_list = []
for d in range(n_pos_dims):
frac = mid[:, d, :] / max_pos[d] # (B, T)
frac_list.append(frac)
fractional = np.stack(frac_list, axis=-1) # (B, T, n_dims)
scaled = fractional * 2 - 1 # [-1, 1]
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
freqs = (
scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
)
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(
freqs.shape[0], freqs.shape[1], -1
) # (B, T, num_indices * n_dims)
cos_ref = np.cos(freqs)
sin_ref = np.sin(freqs)
# Split RoPE: pad to dim//2, reshape to (B, H, T, dim_per_head//2)
expected = dim // 2
pad_size = expected - cos_ref.shape[-1]
if pad_size > 0:
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
cos_ref = np.concatenate(
[np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1
)
sin_ref = np.concatenate(
[np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1
)
B, T, _ = cos_ref.shape
dim_per_head = dim // num_heads
cos_ref = cos_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
sin_ref = sin_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
return cos_ref, sin_ref
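The sizes in the reference function are easier to follow with concrete numbers. The sketch below reproduces only the frequency-grid arithmetic from the comments above, in plain Python; `rope_freq_layout` is a hypothetical helper, and `dim=128` with three position dimensions matches the values used in this test file.

```python
import math


def rope_freq_layout(dim: int, n_pos_dims: int = 3, theta: float = 10000.0):
    """Return (freq_indices, n_freqs, pad) for the split-RoPE layout."""
    n_elem = 2 * n_pos_dims
    num_indices = max(dim // n_elem, 1)
    # Equivalent to theta ** linspace(0, 1, num_indices) * pi / 2.
    if num_indices == 1:
        exponents = [0.0]
    else:
        exponents = [i / (num_indices - 1) for i in range(num_indices)]
    freq_indices = [theta**e * math.pi / 2 for e in exponents]
    n_freqs = num_indices * n_pos_dims     # flattened (num_indices, n_dims)
    pad = max(dim // 2 - n_freqs, 0)       # prepended ones (cos) / zeros (sin)
    return freq_indices, n_freqs, pad


freq_indices, n_freqs, pad = rope_freq_layout(dim=128)
```

For `dim=128` this gives 21 frequencies per position dimension and 63 in total, so one padding column is needed to reach `dim // 2 = 64`.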
class TestRoPEPositionPrecision:
"""Test suite for RoPE position precision requirements."""
@@ -65,10 +130,12 @@ class TestRoPEPositionPrecision:
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
# Verify cos/sin are in valid range [-1, 1]
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
"cos_freq values out of [-1, 1] range"
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
"sin_freq values out of [-1, 1] range"
assert (
mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item()
), "cos_freq values out of [-1, 1] range"
assert (
mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item()
), "sin_freq values out of [-1, 1] range"
def test_bfloat16_positions_cause_precision_loss(self):
"""bfloat16 positions should produce different (less precise) results than float32.
@@ -116,7 +183,9 @@ class TestRoPEPositionPrecision:
# The threshold here is intentionally low to catch the issue
precision_threshold = 1e-6
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
has_precision_loss = (
max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
)
# Document the precision loss (this is expected behavior)
if has_precision_loss:
@@ -125,18 +194,14 @@ class TestRoPEPositionPrecision:
print(f" Max sin difference: {max_sin_diff:.6e}")
# This assertion documents the issue - bfloat16 positions cause precision loss
assert has_precision_loss, \
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
assert (
has_precision_loss
), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
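Why bfloat16 positions lose precision can be shown without MLX. bfloat16 keeps 8 significant bits, so integers above 256 are not all representable. The sketch below emulates the float32-to-bfloat16 rounding with the standard library; `round_to_bfloat16` is a hypothetical helper, and the phase calculation is a simplified stand-in for the real RoPE path, which uses interval midpoints and per-dimension `max_pos`.

```python
import math
import struct


def round_to_bfloat16(value: float) -> float:
    """Round a float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias on the low 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def phase(position: float, max_pos: float = 2048.0, theta: float = 10000.0) -> float:
    """Highest-frequency phase for a position scaled to [-1, 1] (simplified)."""
    scaled = 2.0 * position / max_pos - 1.0
    return scaled * theta * math.pi / 2


pixel_position = 1001.0
rounded_position = round_to_bfloat16(pixel_position)  # 1000.0
phase_error = abs(phase(pixel_position) - phase(rounded_position))
```

A one-pixel rounding error becomes a phase error of many radians at the highest frequency, which is consistent with the large cos/sin differences the test documents.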
def test_double_precision_converts_to_float32_internally(self):
"""Verify that double_precision mode converts bfloat16 to float32 first."""
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
# The double precision path in rope.py line 434:
# indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
# This means bfloat16 -> float32 -> float64
# The precision is already lost at the bfloat16 -> float32 step
cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions_bf16,
dim=128,
@@ -161,20 +226,127 @@ class TestRoPEPositionPrecision:
# Recommended: create positions in float32
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
assert positions.dtype == mx.float32, \
"Position grids should be created in float32 for RoPE precision"
assert (
positions.dtype == mx.float32
), "Position grids should be created in float32 for RoPE precision"
# Verify the position values are reasonable
# Temporal positions should be small (seconds)
temporal_positions = positions[:, 0, :, :]
assert mx.max(temporal_positions).item() < 100, \
"Temporal positions should be in seconds (small values)"
assert (
mx.max(temporal_positions).item() < 100
), "Temporal positions should be in seconds (small values)"
# Spatial positions should be larger (pixels)
spatial_h = positions[:, 1, :, :]
spatial_w = positions[:, 2, :, :]
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
assert (
mx.max(spatial_h).item() > 0
), "Spatial height positions should be positive"
assert (
mx.max(spatial_w).item() > 0
), "Spatial width positions should be positive"
def test_float32_positions_match_numpy_float64_reference(self):
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
This is the key correctness test. We compute RoPE with NumPy in float64
(ground truth) and verify that the MLX float32 path produces nearly
identical results. The max allowed diff (0.01) is well below the error
we saw with bfloat16 positions (~2.0 max diff, cosine sim 0.88).
"""
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
positions_np = np.array(positions).astype(np.float64)
dim = 128
theta = 10000.0
max_pos = [20, 2048, 2048]
num_heads = 32
# MLX result (float32 path, non-double-precision)
cos_mlx, sin_mlx = precompute_freqs_cis(
indices_grid=positions,
dim=dim,
theta=theta,
max_pos=max_pos,
use_middle_indices_grid=True,
num_attention_heads=num_heads,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
)
# NumPy float64 reference
cos_ref, sin_ref = _numpy_reference_rope(
positions_np, dim, theta, max_pos, num_heads
)
cos_mlx_np = np.array(cos_mlx)
sin_mlx_np = np.array(sin_mlx)
max_cos_diff = np.max(np.abs(cos_mlx_np - cos_ref))
max_sin_diff = np.max(np.abs(sin_mlx_np - sin_ref))
# Cosine similarity (flatten for single scalar)
cos_flat = cos_mlx_np.flatten()
ref_flat = cos_ref.flatten()
cosine_sim = np.dot(cos_flat, ref_flat) / (
np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)
)
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
assert (
max_cos_diff < 0.01
), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert (
max_sin_diff < 0.01
), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert (
cosine_sim > 0.9999
), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
def test_high_frequency_amplification_regression(self):
"""Regression test for the specific failure mode: high-frequency index amplification.
With production-sized grids (5x16x16 = 1280 tokens), fractional positions
like 0.000391 get multiplied by frequency indices up to ~15708. In bfloat16
the fractional part is quantized, producing raw freq errors of ~6.14 and
cos/sin sign flips (max_diff ~2.0). Float32 must keep max_diff < 0.01.
"""
# Use a production-like grid size
positions = create_video_position_grid(1, 5, 16, 16, dtype=mx.float32)
positions_np = np.array(positions).astype(np.float64)
dim = 128
theta = 10000.0
max_pos = [20, 2048, 2048]
num_heads = 32
cos_mlx, sin_mlx = precompute_freqs_cis(
indices_grid=positions,
dim=dim,
theta=theta,
max_pos=max_pos,
use_middle_indices_grid=True,
num_attention_heads=num_heads,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
)
cos_ref, sin_ref = _numpy_reference_rope(
positions_np, dim, theta, max_pos, num_heads
)
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
assert (
max_cos_diff < 0.01
), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
assert (
max_sin_diff < 0.01
), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
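The failure mode described in the docstring above can be reproduced without MLX. The following is a NumPy-only sketch, not the project's code: `to_bfloat16` is a hypothetical helper that emulates bfloat16 rounding by keeping the top 16 bits of a float32, and the position `1.000391` and multiplier `15708` are assumptions chosen to mirror the figures quoted in the docstring.

```python
import numpy as np


def to_bfloat16(x: np.ndarray) -> np.ndarray:
    """Emulate bfloat16 rounding (hypothetical helper); the result stays float32."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Round to nearest even on the 16 low bits that bfloat16 drops.
    bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    return ((bits + bias) & np.uint32(0xFFFF0000)).view(np.float32)


# A position just above 1.0: bfloat16 spacing near 1.0 is 2**-7, so the
# small fractional part cannot be represented.
position = np.array([1.000391], dtype=np.float32)
frequency = 15708.0  # assumed high-frequency multiplier, taken from the docstring

phase_f32 = position.astype(np.float64) * frequency
phase_bf16 = to_bfloat16(position).astype(np.float64) * frequency
phase_error = float(np.abs(phase_f32 - phase_bf16)[0])

print(f"bfloat16 position: {to_bfloat16(position)[0]}")
print(f"phase error (radians): {phase_error:.2f}")
```

A phase error of this size is close to a full turn, which is why cos/sin computed from bfloat16 positions can land anywhere in [-1, 1] relative to the float32 result.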
class TestRoPEInterleaved:
@@ -201,43 +373,144 @@ class TestRoPEInterleaved:
assert not mx.any(mx.isnan(sin_freq)).item()
class TestRoPEWarnings:
"""Tests for RoPE warnings."""
class TestRoPEInputCasting:
"""Tests that precompute_freqs_cis casts positions to float32 internally.
def test_bfloat16_positions_trigger_warning(self):
"""Verify that bfloat16 positions trigger a UserWarning."""
positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
The fix in rope.py ensures that regardless of the input dtype, positions are
cast to float32 before any computation. This class verifies that behavior
for both the regular and double-precision paths.
"""
with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"):
precompute_freqs_cis(
indices_grid=positions_bf16,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
def test_float32_positions_no_warning(self):
"""Verify that float32 positions do NOT trigger a warning."""
def test_regular_path_outputs_float32(self):
"""Regular path: both float32 and bfloat16 inputs produce float32 output."""
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
positions_bf16 = positions_f32.astype(mx.bfloat16)
# This should not raise any warnings
import warnings
with warnings.catch_warnings():
warnings.simplefilter("error") # Turn warnings into errors
precompute_freqs_cis(
indices_grid=positions_f32,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
kwargs = dict(
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
)
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
cos_bf16, sin_bf16 = precompute_freqs_cis(indices_grid=positions_bf16, **kwargs)
# Both produce float32 output regardless of input dtype
assert cos_f32.dtype == mx.float32
assert cos_bf16.dtype == mx.float32
assert sin_f32.dtype == mx.float32
assert sin_bf16.dtype == mx.float32
# No NaN/Inf in either
assert not mx.any(mx.isnan(cos_bf16)).item()
assert not mx.any(mx.isinf(cos_bf16)).item()
def test_double_precision_path_outputs_float32(self):
"""Double-precision path: both float32 and bfloat16 inputs produce float32 output."""
positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
positions_bf16 = positions_f32.astype(mx.bfloat16)
kwargs = dict(
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
cos_bf16, sin_bf16 = precompute_freqs_cis(indices_grid=positions_bf16, **kwargs)
assert cos_f32.dtype == mx.float32
assert cos_bf16.dtype == mx.float32
assert sin_f32.dtype == mx.float32
assert sin_bf16.dtype == mx.float32
assert not mx.any(mx.isnan(cos_bf16)).item()
assert not mx.any(mx.isinf(cos_bf16)).item()
def test_float16_input_also_cast_to_float32(self):
"""Float16 input should also be handled correctly."""
positions_f16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float16)
cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions_f16,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
)
assert cos_freq.dtype == mx.float32
assert sin_freq.dtype == mx.float32
assert not mx.any(mx.isnan(cos_freq)).item()
class TestDoublePrecisionRopeConfig:
"""Tests for the conditional double_precision_rope logic in LTXModelConfig."""
def test_ltx2_forces_double_precision_rope_false(self):
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
assert (
config.double_precision_rope is False
), "LTX-2 should force double_precision_rope=False regardless of input"
def test_ltx23_preserves_double_precision_rope_true(self):
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
assert (
config.double_precision_rope is True
), "LTX-2.3 should preserve double_precision_rope=True"
def test_ltx23_preserves_double_precision_rope_false(self):
"""LTX-2.3 with double_precision_rope=False should stay False."""
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
assert (
config.double_precision_rope is False
), "LTX-2.3 should respect double_precision_rope=False when explicitly set"
def test_ltx2_default_double_precision_rope(self):
"""LTX-2 default (double_precision_rope not set) should be False."""
config = LTXModelConfig(has_prompt_adaln=False)
assert config.double_precision_rope is False
def test_ltx23_default_double_precision_rope(self):
"""LTX-2.3 default (double_precision_rope not set) should be False (field default)."""
config = LTXModelConfig(has_prompt_adaln=True)
# The field default is False and __post_init__ doesn't override for LTX-2.3
assert config.double_precision_rope is False
def test_config_from_dict_ltx2(self):
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
config = LTXModelConfig.from_dict(
{
"has_prompt_adaln": False,
"double_precision_rope": True,
"rope_type": "split",
}
)
assert config.double_precision_rope is False
def test_config_from_dict_ltx23(self):
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
config = LTXModelConfig.from_dict(
{
"has_prompt_adaln": True,
"double_precision_rope": True,
"rope_type": "split",
}
)
assert config.double_precision_rope is True
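The behaviour these config tests pin down can be summarised with a small stand-in dataclass. This is a sketch under assumptions: `SketchConfig` is a hypothetical class written for illustration, and the real `LTXModelConfig` has more fields and may implement the rule differently. Only the rule itself (no prompt AdaLN forces `double_precision_rope=False`) comes from the tests above.

```python
from dataclasses import dataclass, fields


@dataclass
class SketchConfig:
    """Hypothetical stand-in for LTXModelConfig, reduced to two fields."""

    has_prompt_adaln: bool = False
    double_precision_rope: bool = False

    def __post_init__(self):
        # LTX-2 style configs (no prompt AdaLN) never use double-precision RoPE.
        if not self.has_prompt_adaln:
            self.double_precision_rope = False

    @classmethod
    def from_dict(cls, d: dict) -> "SketchConfig":
        # Ignore keys this sketch does not model (e.g. "rope_type").
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in known})


ltx2 = SketchConfig(has_prompt_adaln=False, double_precision_rope=True)
ltx23 = SketchConfig.from_dict(
    {"has_prompt_adaln": True, "double_precision_rope": True, "rope_type": "split"}
)
print(ltx2.double_precision_rope, ltx23.double_precision_rope)
```

Putting the override in `__post_init__` means it applies equally to direct construction and to `from_dict`, which is what the two `from_dict` tests check.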
class TestRoPESplit:
@@ -270,10 +543,12 @@ class TestRoPESplit:
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
dim_per_head = dim // num_heads
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
assert cos_freq.shape == expected_shape, \
f"Expected shape {expected_shape}, got {cos_freq.shape}"
assert sin_freq.shape == expected_shape, \
f"Expected shape {expected_shape}, got {sin_freq.shape}"
assert (
cos_freq.shape == expected_shape
), f"Expected shape {expected_shape}, got {cos_freq.shape}"
assert (
sin_freq.shape == expected_shape
), f"Expected shape {expected_shape}, got {sin_freq.shape}"
if __name__ == "__main__":


@@ -1,11 +1,11 @@
"""Tests for VAE streaming and chunked conv features."""
import pytest
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import (
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
compute_trapezoidal_mask_1d,
decode_with_tiling,
@@ -50,7 +50,7 @@ class TestChunkedConv:
np.array(out_chunked),
rtol=1e-5,
atol=1e-5,
err_msg="Chunked conv output differs from regular output"
err_msg="Chunked conv output differs from regular output",
)
def test_chunked_conv_small_input_passthrough(self):
@@ -117,13 +117,17 @@ class TestProgressiveFrameSaving:
frames_received = []
def on_frames_ready(frames: mx.array, start_idx: int):
frames_received.append({
'shape': frames.shape,
'start_idx': start_idx,
})
frames_received.append(
{
"shape": frames.shape,
"start_idx": start_idx,
}
)
# Create a mock decoder that just returns scaled input
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
def mock_decoder(
x, causal=False, timestep=None, debug=False, chunked_conv=False
):
# Simulate VAE output: upsample 8x temporal, 32x spatial
b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8
@@ -154,7 +158,9 @@ class TestProgressiveFrameSaving:
# All received frames should have correct channel count
for received in frames_received:
assert received['shape'][1] == 3, f"Expected 3 channels, got {received['shape'][1]}"
assert (
received["shape"][1] == 3
), f"Expected 3 channels, got {received['shape'][1]}"
def test_on_frames_ready_covers_all_frames(self):
"""Verify all frames are emitted via callbacks."""
@@ -165,7 +171,9 @@ class TestProgressiveFrameSaving:
for i in range(num_frames):
all_frame_indices.add(start_idx + i)
def mock_decoder(x, causal=False, timestep=None, debug=False, chunked_conv=False):
def mock_decoder(
x, causal=False, timestep=None, debug=False, chunked_conv=False
):
b, c, f, h, w = x.shape
out_f = 1 + (f - 1) * 8
out_h = h * 32
@@ -191,24 +199,29 @@ class TestProgressiveFrameSaving:
expected_frames = 1 + (12 - 1) * 8 # 89 frames
# All frames should have been emitted
assert len(all_frame_indices) == expected_frames, \
f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
assert all_frame_indices == set(range(expected_frames)), \
"Not all frame indices were covered"
assert (
len(all_frame_indices) == expected_frames
), f"Expected {expected_frames} frames, got {len(all_frame_indices)}"
assert all_frame_indices == set(
range(expected_frames)
), "Not all frame indices were covered"
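The frame arithmetic used by the mock decoder and the coverage check can be written out on their own. This is a minimal sketch with hypothetical helper names (`decoded_frame_count`, `chunks_cover_all_frames`); the 8x temporal factor is taken from the mock decoder above, and the chunk boundaries in the usage lines are made up for illustration.

```python
def decoded_frame_count(latent_frames: int, temporal_scale: int = 8) -> int:
    """Frames produced by a causal decoder: the first latent frame maps to one
    output frame, every later latent frame maps to `temporal_scale` frames."""
    return 1 + (latent_frames - 1) * temporal_scale


def chunks_cover_all_frames(chunks, total_frames: int) -> bool:
    """Check that (start_idx, num_frames) callbacks cover 0..total_frames-1."""
    seen = set()
    for start_idx, num_frames in chunks:
        seen.update(range(start_idx, start_idx + num_frames))
    return seen == set(range(total_frames))


total = decoded_frame_count(12)
print(total)
print(chunks_cover_all_frames([(0, 40), (40, 49)], total))
```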
class TestAutoChunkedConv:
"""Tests for auto-enabling chunked_conv based on tiling mode."""
@pytest.mark.parametrize("tiling_mode,should_enable", [
("conservative", True),
("none", True),
("auto", True),
("default", True),
("spatial", True),
("aggressive", False),
("temporal", False),
])
@pytest.mark.parametrize(
"tiling_mode,should_enable",
[
("conservative", True),
("none", True),
("auto", True),
("default", True),
("spatial", True),
("aggressive", False),
("temporal", False),
],
)
def test_chunked_conv_auto_enable(self, tiling_mode: str, should_enable: bool):
"""Verify chunked_conv is auto-enabled for correct tiling modes."""
# The logic is: tiling_mode in ("conservative", "none", "auto", "default", "spatial")
@@ -216,8 +229,9 @@ class TestAutoChunkedConv:
use_chunked_conv = tiling_mode in expected_modes
assert use_chunked_conv == should_enable, \
f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
assert (
use_chunked_conv == should_enable
), f"For tiling_mode='{tiling_mode}', expected chunked_conv={should_enable}"
class TestTrapezoidalMask:
@@ -250,7 +264,9 @@ class TestTrapezoidalMask:
# Right ramp should be decreasing
right_ramp = mask_np[-8:]
assert np.all(np.diff(right_ramp) <= 0), "Right ramp not monotonically decreasing"
assert np.all(
np.diff(right_ramp) <= 0
), "Right ramp not monotonically decreasing"
def test_temporal_mask_starts_from_zero(self):
"""Verify temporal mask (left_starts_from_0=True) starts from 0."""


@@ -2,31 +2,33 @@
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# RoPE Tests
# ---------------------------------------------------------------------------
class TestRoPE:
"""Tests for 3-way factorized RoPE."""
def test_rope_params_shape(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(1024, 64)
mx.eval(freqs)
assert freqs.shape == (1024, 32, 2) # [max_seq_len, dim//2, 2]
def test_rope_params_different_dims(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
for dim in [32, 64, 128]:
freqs = rope_params(512, dim)
mx.eval(freqs)
assert freqs.shape == (512, dim // 2, 2)
def test_rope_params_cos_sin_range(self):
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(256, 64)
mx.eval(freqs)
cos_vals = np.array(freqs[:, :, 0])
@@ -36,14 +38,16 @@ class TestRoPE:
def test_rope_params_position_zero(self):
"""At position 0, cos should be 1 and sin should be 0."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
freqs = rope_params(10, 64)
mx.eval(freqs)
np.testing.assert_allclose(np.array(freqs[0, :, 0]), 1.0, atol=1e-6)
np.testing.assert_allclose(np.array(freqs[0, :, 1]), 0.0, atol=1e-6)
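The shape and position-zero properties checked above follow directly from how a cos/sin table is usually built. The sketch below is a NumPy reimplementation for illustration only: `rope_params_sketch` is a hypothetical name, `theta=10000.0` is an assumed default, and the real `rope_params` in `mlx_video.models.wan_2.rope` may differ in detail.

```python
import numpy as np


def rope_params_sketch(max_seq_len: int, dim: int, theta: float = 10000.0) -> np.ndarray:
    """Return a [max_seq_len, dim // 2, 2] table holding (cos, sin) per position."""
    assert dim % 2 == 0, "RoPE rotates pairs of channels, so dim must be even"
    # One inverse frequency per channel pair.
    inv_freq = 1.0 / theta ** (np.arange(0, dim, 2, dtype=np.float64) / dim)
    # angle[p, k] = position p times inverse frequency k.
    angles = np.outer(np.arange(max_seq_len, dtype=np.float64), inv_freq)
    return np.stack([np.cos(angles), np.sin(angles)], axis=-1)


freqs = rope_params_sketch(1024, 64)
print(freqs.shape)
```

At position 0 every angle is zero, so the cosine column is exactly 1 and the sine column exactly 0, which is the property `test_rope_params_position_zero` asserts.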
def test_rope_apply_output_shape(self):
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, L, N, D = 1, 24, 4, 32 # batch, seq, heads, head_dim
x = mx.random.normal((B, L, N, D))
freqs = rope_params(1024, D)
@@ -54,7 +58,8 @@ class TestRoPE:
def test_rope_apply_preserves_norm(self):
"""RoPE rotation should preserve vector norms."""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 3, 4
L = F * H * W
@@ -74,7 +79,8 @@ class TestRoPE:
def test_rope_apply_with_padding(self):
"""When seq_len < L, extra tokens should be preserved unchanged."""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 1, 2, 16
F, H, W = 2, 2, 2
seq_len = F * H * W # 8
@@ -94,7 +100,8 @@ class TestRoPE:
def test_rope_apply_batch(self):
"""Test with batch_size > 1 and different grid sizes."""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan_2.rope import rope_apply, rope_params
B, N, D = 2, 2, 16
grids = [(2, 3, 4), (2, 3, 4)]
L = 2 * 3 * 4
@@ -122,9 +129,11 @@ class TestRoPE:
# Attention Tests
# ---------------------------------------------------------------------------
class TestWanRMSNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
@@ -133,7 +142,8 @@ class TestWanRMSNorm:
def test_zero_mean_variance(self):
"""RMS norm should make RMS ≈ 1 before scaling."""
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(64)
x = mx.random.normal((1, 5, 64)) * 10.0
out = norm(x)
@@ -146,7 +156,8 @@ class TestWanRMSNorm:
def test_dtype_preservation(self):
"""RMSNorm weight is float32, so output is promoted to float32."""
from mlx_video.models.wan.attention import WanRMSNorm
from mlx_video.models.wan_2.attention import WanRMSNorm
norm = WanRMSNorm(32)
x = mx.random.normal((1, 4, 32)).astype(mx.bfloat16)
out = norm(x)
@@ -157,7 +168,8 @@ class TestWanRMSNorm:
class TestWanLayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
@@ -165,7 +177,8 @@ class TestWanLayerNorm:
assert out.shape == (2, 10, 64)
def test_without_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(64, elementwise_affine=False)
x = mx.random.normal((1, 4, 64))
out = norm(x)
@@ -177,7 +190,8 @@ class TestWanLayerNorm:
np.testing.assert_allclose(np.std(out_np[i]), 1.0, rtol=0.1)
def test_with_affine(self):
from mlx_video.models.wan.attention import WanLayerNorm
from mlx_video.models.wan_2.attention import WanLayerNorm
norm = WanLayerNorm(32, elementwise_affine=True)
assert hasattr(norm, "weight")
assert hasattr(norm, "bias")
@@ -194,8 +208,9 @@ class TestWanSelfAttention:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
B, L = 1, 24
F, H, W = 2, 3, 4
@@ -206,21 +221,24 @@ class TestWanSelfAttention:
assert out.shape == (B, L, self.dim)
def test_with_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan_2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=True)
assert attn.norm_q is not None
assert attn.norm_k is not None
def test_without_qk_norm(self):
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan_2.attention import WanSelfAttention
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
assert attn.norm_q is None
assert attn.norm_k is None
def test_masking(self):
"""Test that masking works: shorter seq_lens should mask later tokens."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads, qk_norm=False)
B, L = 1, 24
F, H, W = 2, 3, 4
@@ -244,7 +262,8 @@ class TestWanCrossAttention:
self.num_heads = 4
def test_output_shape(self):
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 24, 16
x = mx.random.normal((B, L_q, self.dim))
@@ -254,7 +273,8 @@ class TestWanCrossAttention:
assert out.shape == (B, L_q, self.dim)
def test_with_context_mask(self):
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
B, L_q, L_kv = 1, 12, 16
x = mx.random.normal((B, L_q, self.dim))
@@ -268,6 +288,7 @@ class TestWanCrossAttention:
# bfloat16 Autocast Tests
# ---------------------------------------------------------------------------
class TestBFloat16Autocast:
"""Tests that attention and FFN cast inputs to weight dtype (bfloat16)
for efficient matmul, matching official PyTorch autocast behavior."""
@@ -290,8 +311,9 @@ class TestBFloat16Autocast:
def test_self_attn_casts_to_weight_dtype(self):
"""Self-attention should cast input to weight dtype for QKV projections."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -304,7 +326,8 @@ class TestBFloat16Autocast:
def test_cross_attn_casts_to_weight_dtype(self):
"""Cross-attention should cast input to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -317,7 +340,8 @@ class TestBFloat16Autocast:
def test_cross_attn_kv_cache_uses_weight_dtype(self):
"""prepare_kv should cast context to weight dtype."""
from mlx_video.models.wan.attention import WanCrossAttention
from mlx_video.models.wan_2.attention import WanCrossAttention
attn = WanCrossAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -329,7 +353,8 @@ class TestBFloat16Autocast:
def test_ffn_casts_to_weight_dtype(self):
"""FFN should cast input to weight dtype for linear layers."""
from mlx_video.models.wan.transformer import WanFFN
from mlx_video.models.wan_2.transformer import WanFFN
ffn = WanFFN(self.dim, 128)
ffn.update(self._to_bf16(ffn.parameters()))
@@ -341,8 +366,9 @@ class TestBFloat16Autocast:
def test_self_attn_rope_in_float32(self):
"""RoPE should be applied in float32 for precision, even with bf16 weights."""
from mlx_video.models.wan.attention import WanSelfAttention
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.attention import WanSelfAttention
from mlx_video.models.wan_2.rope import rope_params
attn = WanSelfAttention(self.dim, self.num_heads)
attn.update(self._to_bf16(attn.parameters()))
@@ -355,8 +381,9 @@ class TestBFloat16Autocast:
def test_block_float32_residual_with_bf16_weights(self):
"""Full block: residual stream stays float32, matmuls use bf16 weights."""
from mlx_video.models.wan.transformer import WanAttentionBlock
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
from mlx_video.models.wan_2.transformer import WanAttentionBlock
block = WanAttentionBlock(self.dim, 128, self.num_heads, cross_attn_norm=True)
block.update(self._to_bf16(block.parameters()))


@@ -1,17 +1,17 @@
"""Tests for Wan model configuration."""
import pytest
# ---------------------------------------------------------------------------
# Config Tests
# ---------------------------------------------------------------------------
class TestWanModelConfig:
"""Tests for WanModelConfig dataclass."""
def test_default_values(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
assert config.dim == 5120
assert config.ffn_dim == 13824
@@ -32,12 +32,14 @@ class TestWanModelConfig:
assert config.text_len == 512
def test_head_dim_property(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
assert config.head_dim == 128 # 5120 // 40
def test_to_dict_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
d = config.to_dict()
assert isinstance(d, dict)
@@ -46,7 +48,8 @@ class TestWanModelConfig:
assert d["boundary"] == 0.875
def test_t5_config_values(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
assert config.t5_vocab_size == 256384
assert config.t5_dim == 4096
@@ -61,11 +64,13 @@ class TestWanModelConfig:
# Wan2.1 Config Tests
# ---------------------------------------------------------------------------
class TestWan21Config:
"""Tests for Wan2.1 config presets."""
def test_wan21_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
assert config.model_version == "2.1"
assert config.dual_model is False
@@ -80,7 +85,8 @@ class TestWan21Config:
assert config.boundary == 0.0
def test_wan21_1_3b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
assert config.model_version == "2.1"
assert config.dual_model is False
@@ -92,7 +98,8 @@ class TestWan21Config:
assert config.sample_guide_scale == 5.0
def test_wan22_14b_factory(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_t2v_14b()
assert config.model_version == "2.2"
assert config.dual_model is True
@@ -103,7 +110,8 @@ class TestWan21Config:
assert config.boundary == 0.875
def test_wan21_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
assert d["model_version"] == "2.1"
@@ -111,7 +119,8 @@ class TestWan21Config:
assert d["sample_guide_scale"] == 5.0
def test_wan21_1_3b_config_to_dict(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
d = config.to_dict()
assert d["dim"] == 1536
@@ -119,7 +128,8 @@ class TestWan21Config:
def test_default_config_is_wan22(self):
"""Default WanModelConfig() should be Wan2.2 14B."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig()
assert config.model_version == "2.2"
assert config.dual_model is True


@@ -3,17 +3,16 @@
import logging
import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Transformer Weight Conversion Tests
# ---------------------------------------------------------------------------
class TestSanitizeTransformerWeights:
def test_patch_embedding_reshape(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
@@ -24,7 +23,8 @@ class TestSanitizeTransformerWeights:
assert out["patch_embedding_proj.weight"].shape == (5120, 16 * 1 * 2 * 2)
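The expected shape in this test is consistent with flattening the Conv3d patch-embedding kernel into a linear projection. The sketch below shows that reshape in NumPy with a small output dimension; it is an assumption about what the sanitizer does, inferred only from the asserted shape, and `flatten_patch_embedding` is a hypothetical name.

```python
import numpy as np


def flatten_patch_embedding(weight: np.ndarray) -> np.ndarray:
    """Flatten a [O, I, D, H, W] conv kernel into an [O, I*D*H*W] matrix."""
    out_channels = weight.shape[0]
    return weight.reshape(out_channels, -1)


# Small stand-in for the (5120, 16, 1, 2, 2) kernel used in the test.
kernel = np.arange(8 * 16 * 1 * 2 * 2, dtype=np.float32).reshape(8, 16, 1, 2, 2)
flat = flatten_patch_embedding(kernel)
print(flat.shape)
```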
def test_text_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"text_embedding.0.weight": mx.zeros((64, 32)),
"text_embedding.0.bias": mx.zeros((64,)),
@@ -38,7 +38,8 @@ class TestSanitizeTransformerWeights:
assert "text_embedding_1.bias" in out
def test_time_embedding_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"time_embedding.0.weight": mx.zeros((64, 32)),
"time_embedding.2.weight": mx.zeros((64, 64)),
@@ -48,7 +49,8 @@ class TestSanitizeTransformerWeights:
assert "time_embedding_1.weight" in out
def test_time_projection_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"time_projection.1.weight": mx.zeros((384, 64)),
"time_projection.1.bias": mx.zeros((384,)),
@@ -58,7 +60,8 @@ class TestSanitizeTransformerWeights:
assert "time_projection.bias" in out
def test_ffn_rename(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"blocks.0.ffn.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.0.bias": mx.zeros((128,)),
@@ -72,7 +75,8 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.ffn.fc2.bias" in out
def test_freqs_skipped(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"freqs": mx.zeros((1024, 64, 2)),
"blocks.0.norm1.weight": mx.zeros((64,)),
@@ -82,7 +86,8 @@ class TestSanitizeTransformerWeights:
assert "blocks.0.norm1.weight" in out
def test_passthrough_keys(self):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"blocks.0.self_attn.q.weight": mx.zeros((64, 64)),
"blocks.0.self_attn.k.weight": mx.zeros((64, 64)),
@@ -97,7 +102,8 @@ class TestSanitizeTransformerWeights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_transformer_weights
from mlx_video.models.wan_2.convert import sanitize_wan_transformer_weights
weights = {
"patch_embedding.weight": mx.random.normal((5120, 16, 1, 2, 2)),
"patch_embedding.bias": mx.random.normal((5120,)),
@@ -113,14 +119,15 @@ class TestSanitizeTransformerWeights:
"head.head.weight": mx.zeros((64, 64)),
"freqs": mx.zeros((1024, 64, 2)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_transformer_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeT5Weights:
def test_gate_rename(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = {
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
"blocks.0.ffn.fc1.weight": mx.zeros((128, 64)),
@@ -132,7 +139,8 @@ class TestSanitizeT5Weights:
assert "blocks.0.ffn.fc2.weight" in out
def test_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.attn.q.weight": mx.zeros((64, 64)),
@@ -143,7 +151,8 @@ class TestSanitizeT5Weights:
assert key in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_t5_weights
from mlx_video.models.wan_2.convert import sanitize_wan_t5_weights
weights = {
"token_embedding.weight": mx.zeros((100, 64)),
"blocks.0.ffn.gate.0.weight": mx.zeros((128, 64)),
@@ -151,14 +160,15 @@ class TestSanitizeT5Weights:
"blocks.0.ffn.fc2.weight": mx.zeros((64, 128)),
"norm.weight": mx.zeros((64,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_t5_weights(weights)
assert "Unconsumed" not in caplog.text
class TestSanitizeVAEWeights:
def test_conv3d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)), # [O, I, D, H, W]
}
@@ -166,7 +176,8 @@ class TestSanitizeVAEWeights:
assert out["decoder.conv1.weight"].shape == (8, 3, 3, 3, 4) # [O, D, H, W, I]
def test_conv2d_transpose(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)), # [O, I, H, W]
}
@@ -174,7 +185,8 @@ class TestSanitizeVAEWeights:
assert out["decoder.proj.weight"].shape == (16, 3, 3, 8) # [O, H, W, I]
def test_non_conv_passthrough(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"decoder.norm.weight": mx.zeros((64,)), # 1D, no transpose
"decoder.bias": mx.zeros((16,)),
@@ -184,7 +196,8 @@ class TestSanitizeVAEWeights:
assert out["decoder.bias"].shape == (16,)
def test_mixed_weights(self):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"conv3d.weight": mx.zeros((8, 4, 3, 3, 3)), # 5D
"conv2d.weight": mx.zeros((8, 4, 3, 3)), # 4D
@@ -198,14 +211,15 @@ class TestSanitizeVAEWeights:
assert out["norm.weight"].shape == (8,)
def test_no_unconsumed_keys(self, caplog):
from mlx_video.convert_wan import sanitize_wan_vae_weights
from mlx_video.models.wan_2.convert import sanitize_wan_vae_weights
weights = {
"decoder.conv1.weight": mx.zeros((8, 4, 3, 3, 3)),
"decoder.proj.weight": mx.zeros((16, 8, 3, 3)),
"decoder.norm.weight": mx.zeros((64,)),
"decoder.bias": mx.zeros((16,)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.convert_wan"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.convert"):
sanitize_wan_vae_weights(weights)
assert "Unconsumed" not in caplog.text
@@ -214,6 +228,7 @@ class TestSanitizeVAEWeights:
# Wan2.1 Conversion Tests
# ---------------------------------------------------------------------------
class TestWan21Convert:
"""Tests for Wan2.1 conversion support."""
@@ -222,7 +237,7 @@ class TestWan21Convert:
# Create a Wan2.1-style directory (no low_noise_model subdir)
(tmp_path / "dummy.safetensors").touch()
# The auto-detect logic: no low_noise_model dir → 2.1
from pathlib import Path
low = tmp_path / "low_noise_model"
assert not low.exists()
# Simulates auto detection
@@ -233,7 +248,7 @@ class TestWan21Convert:
"""Auto-detect dual-model directory as Wan2.2."""
(tmp_path / "low_noise_model").mkdir()
(tmp_path / "high_noise_model").mkdir()
from pathlib import Path
low = tmp_path / "low_noise_model"
assert low.exists()
version = "2.2" if low.exists() else "2.1"
@@ -241,7 +256,8 @@ class TestWan21Convert:
def test_wan21_config_saved_correctly(self):
"""Verify config dict has correct fields for Wan2.1."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
d = config.to_dict()
assert d["model_version"] == "2.1"
@@ -254,11 +270,12 @@ class TestWan21Convert:
# Encoder Weight Sanitization Tests
# ---------------------------------------------------------------------------
class TestSanitizeEncoderWeights:
"""Tests for sanitize_wan22_vae_weights with include_encoder."""
def test_exclude_encoder_by_default(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -270,7 +287,7 @@ class TestSanitizeEncoderWeights:
assert not any("encoder" in k or k.startswith("conv1") for k in out)
def test_include_encoder(self):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
@@ -283,25 +300,25 @@ class TestSanitizeEncoderWeights:
assert "conv2.weight" in out
def test_no_unconsumed_keys(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=True)
assert "Unconsumed" not in caplog.text
def test_no_unconsumed_keys_exclude_encoder(self, caplog):
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
weights = {
"encoder.conv1.weight": mx.zeros((8, 1, 3, 3, 3)),
"conv1.weight": mx.zeros((8, 1, 1, 1, 8)),
"conv2.weight": mx.zeros((8, 1, 1, 1, 8)),
}
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan.vae22"):
with caplog.at_level(logging.WARNING, logger="mlx_video.models.wan_2.vae22"):
sanitize_wan22_vae_weights(weights, include_encoder=False)
assert "Unconsumed" not in caplog.text
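
The VAE sanitization tests above check that 5D and 4D convolution weights move from a channels-first layout (`[O, I, D, H, W]`, `[O, I, H, W]`) to a channels-last one, while 1D tensors pass through unchanged. A minimal sketch of that layout change follows. It uses NumPy instead of MLX so it runs anywhere, and `to_channels_last` is a hypothetical helper name, not the function in `mlx_video.models.wan_2.convert`.

```python
import numpy as np


def to_channels_last(weight: np.ndarray) -> np.ndarray:
    """Move the input-channel axis (axis 1) to the end for 4D/5D conv weights.

    Hypothetical stand-in for the transpose the sanitizer tests expect;
    tensors of any other rank are returned unchanged.
    """
    if weight.ndim in (4, 5):
        return np.moveaxis(weight, 1, -1)
    return weight


# Shapes taken from the tests above.
conv3d_out = to_channels_last(np.zeros((8, 4, 3, 3, 3)))  # [O, I, D, H, W]
conv2d_out = to_channels_last(np.zeros((16, 8, 3, 3)))    # [O, I, H, W]
norm_out = to_channels_last(np.zeros((64,)))              # 1D, no transpose

# Values follow their indices: out[o, d, h, w, i] == weight[o, i, d, h, w].
probe = np.arange(8 * 4 * 3 * 3 * 3).reshape(8, 4, 3, 3, 3)
probe_out = to_channels_last(probe)
```

`np.moveaxis` keeps the remaining axes in their original order, which is what distinguishes this from a full axis reversal.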


@@ -2,22 +2,20 @@
import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Integration: end-to-end tiny model forward pass
# ---------------------------------------------------------------------------
class TestEndToEnd:
"""End-to-end test with tiny model (no real weights needed)."""
def test_tiny_model_denoise_step(self):
"""Simulate one denoising step with tiny model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(42)
config = _make_tiny_config()
@@ -45,8 +43,8 @@ class TestEndToEnd:
def test_tiny_model_full_loop(self):
"""Run a complete (tiny) diffusion loop."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(123)
config = _make_tiny_config()
@@ -78,11 +76,12 @@ class TestEndToEnd:
# I2V Mask Tests
# ---------------------------------------------------------------------------
class TestI2VMask:
"""Tests for _build_i2v_mask."""
def test_mask_shapes(self):
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4) # C, T, H, W
patch_size = (1, 2, 2)
@@ -92,7 +91,7 @@ class TestI2VMask:
assert mask_tokens.shape == (1, 20)
def test_first_frame_zero(self):
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (48, 5, 4, 4)
mask, mask_tokens = _build_i2v_mask(z_shape, (1, 2, 2))
@@ -112,7 +111,8 @@ class TestI2VMaskAlignment:
def test_mask_with_ti2v_dimensions(self):
"""Mask should work with TI2V-5B typical dimensions."""
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan_2.generate import _build_i2v_mask
# TI2V: z_dim=48, vae_stride=(4,16,16), patch=(1,2,2)
# 704x1280 → latent 44x80, t_latent=21 for 81 frames
z_shape = (48, 21, 44, 80)
@@ -132,7 +132,8 @@ class TestI2VMaskAlignment:
def test_mask_per_token_timestep(self):
"""Per-token timesteps: first-frame tokens get t=0, rest get t=sigma."""
from mlx_video.generate_wan import _build_i2v_mask
from mlx_video.models.wan_2.generate import _build_i2v_mask
z_shape = (4, 3, 4, 4)
patch_size = (1, 2, 2)
_, mask_tokens = _build_i2v_mask(z_shape, patch_size)
@@ -144,13 +145,16 @@ class TestI2VMaskAlignment:
first_tokens = 1 * 2 * 2 # pt * (H/ph) * (W/pw)
np.testing.assert_allclose(np.array(t_tokens[0, :first_tokens]), 0.0, atol=1e-7)
np.testing.assert_allclose(np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7)
np.testing.assert_allclose(
np.array(t_tokens[0, first_tokens:]), timestep_val, atol=1e-7
)
# ---------------------------------------------------------------------------
# Dimension Alignment Tests
# ---------------------------------------------------------------------------
class TestDimensionAlignment:
"""Tests for automatic dimension alignment in generate_wan."""
@@ -197,7 +201,8 @@ class TestDimensionAlignment:
def test_patchify_valid_after_alignment(self):
"""After alignment, patchify should succeed without reshape errors."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -222,11 +227,16 @@ class TestDimensionAlignment:
patches, grid_size = model._patchify(vid)
mx.eval(patches)
assert patches.ndim == 3 # [1, L, dim]
assert grid_size == (t_latent, h_latent // patch_size[1], w_latent // patch_size[2])
assert grid_size == (
t_latent,
h_latent // patch_size[1],
w_latent // patch_size[2],
)
def test_alignment_with_ti2v_config(self):
"""TI2V-5B uses vae_stride=(4,16,16), patch_size=(1,2,2) → align=32."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_ti2v_5b()
align_h = config.patch_size[1] * config.vae_stride[1]
align_w = config.patch_size[2] * config.vae_stride[2]
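
The I2V mask tests in this file expect one mask token per `(pt, ph, pw)` patch, with first-frame tokens pinned to timestep 0 and all other tokens receiving the current timestep. The sketch below reproduces that behaviour in NumPy under stated assumptions: `build_first_frame_mask` is a hypothetical name, and the real `_build_i2v_mask` may build its mask differently while producing the same token shapes.

```python
import numpy as np


def build_first_frame_mask(z_shape, patch_size):
    """Return (mask, mask_tokens): 0 on the first latent frame, 1 elsewhere.

    Hypothetical sketch; z_shape is (C, T, H, W) and patch_size is (pt, ph, pw).
    """
    _, t, h, w = z_shape
    pt, ph, pw = patch_size
    mask = np.ones((1, t, h, w), dtype=np.float32)
    mask[:, 0] = 0.0
    # One token per patch: sample the mask on the patch grid, then flatten.
    mask_tokens = mask[:, ::pt, ::ph, ::pw].reshape(1, -1)
    return mask, mask_tokens


# Shapes from test_mask_shapes: 5 * (4 // 2) * (4 // 2) = 20 tokens.
_, tokens_20 = build_first_frame_mask((48, 5, 4, 4), (1, 2, 2))

# Per-token timesteps as in test_mask_per_token_timestep.
_, mask_tokens = build_first_frame_mask((4, 3, 4, 4), (1, 2, 2))
timestep_val = 500.0  # illustrative value
t_tokens = mask_tokens * timestep_val
first_tokens = 1 * 2 * 2  # pt * (H/ph) * (W/pw)
```

Multiplying the token mask by the timestep keeps the conditioning frame "clean" (t = 0) while every other token is denoised at the scheduler's current timestep.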


@@ -1,9 +1,6 @@
"""Tests for Wan2.2 I2V-14B support."""
import mlx.core as mx
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
@@ -26,7 +23,7 @@ class TestI2VConfig:
"""Test I2V-14B config preset."""
def test_wan22_i2v_14b_preset(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
assert config.model_type == "i2v"
@@ -42,7 +39,7 @@ class TestI2VConfig:
assert config.vae_z_dim == 16
def test_i2v_vs_t2v_differences(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
i2v = WanModelConfig.wan22_i2v_14b()
t2v = WanModelConfig.wan22_t2v_14b()
@@ -54,7 +51,7 @@ class TestI2VConfig:
assert i2v.sample_shift == 5.0 and t2v.sample_shift == 12.0
def test_i2v_serialization_roundtrip(self):
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan22_i2v_14b()
d = config.to_dict()
@@ -69,7 +66,7 @@ class TestModelYParameter:
def test_forward_without_y(self):
"""Standard T2V forward pass (no y) still works."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -88,7 +85,7 @@ class TestModelYParameter:
def test_forward_with_y(self):
"""I2V forward pass with y channel concatenation."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -111,7 +108,7 @@ class TestModelYParameter:
def test_y_none_is_noop(self):
"""Passing y=None should be identical to not passing y."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -132,7 +129,7 @@ class TestModelYParameter:
def test_batched_cfg_with_y(self):
"""Batched CFG (B=2) with y should work."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_i2v_config()
model = WanModel(config)
@@ -145,7 +142,10 @@ class TestModelYParameter:
latents = mx.random.normal((C_noise, F, H, W))
y = mx.random.normal((C_y, F, H, W))
t = mx.array([500.0, 500.0])
ctx = [mx.random.normal((6, config.text_dim)), mx.random.normal((6, config.text_dim))]
ctx = [
mx.random.normal((6, config.text_dim)),
mx.random.normal((6, config.text_dim)),
]
out = model([latents, latents], t, ctx, seq_len, y=[y, y])
mx.eval(out[0], out[1])
@@ -158,16 +158,18 @@ class TestVAEEncoder:
"""Test Wan2.1 VAE encoder."""
def test_encoder3d_instantiation(self):
from mlx_video.models.wan.vae import Encoder3d
from mlx_video.models.wan_2.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
enc = Encoder3d(
dim=32, z_dim=8
) # z_dim=8 (will output 8ch, but WanVAE wraps with z*2)
assert enc.conv1 is not None
assert len(enc.downsamples) > 0
assert len(enc.middle) == 3
def test_encoder3d_output_shape(self):
"""Encoder should downsample spatially by 8x and temporally by 4x."""
from mlx_video.models.wan.vae import Encoder3d
from mlx_video.models.wan_2.vae import Encoder3d
enc = Encoder3d(dim=32, z_dim=8)
# Random input: [B=1, 3, T=5, H=32, W=32]
@@ -184,7 +186,7 @@ class TestVAEEncoder:
def test_wan_vae_encode(self):
"""WanVAE with encoder=True should produce normalized latents."""
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae = WanVAE(z_dim=16, encoder=True)
# Input: [B=1, 3, T=5, H=32, W=32]
@@ -196,20 +198,20 @@ class TestVAEEncoder:
def test_wan_vae_encoder_flag(self):
"""WanVAE without encoder flag should not have encoder attribute."""
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan_2.vae import WanVAE
vae_no_enc = WanVAE(z_dim=4, encoder=False)
assert not hasattr(vae_no_enc, 'encoder')
assert not hasattr(vae_no_enc, "encoder")
vae_enc = WanVAE(z_dim=4, encoder=True)
assert hasattr(vae_enc, 'encoder')
assert hasattr(vae_enc, "encoder")
class TestResampleDownsample:
"""Test downsample modes in Resample."""
def test_downsample2d(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="downsample2d")
x = mx.random.normal((1, 16, 2, 8, 8))
@@ -219,7 +221,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4)
def test_downsample3d(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="downsample3d")
x = mx.random.normal((1, 16, 4, 8, 8))
@@ -229,7 +231,7 @@ class TestResampleDownsample:
assert out.shape == (1, 16, 2, 4, 4)
def test_upsample2d_still_works(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="upsample2d")
x = mx.random.normal((1, 16, 2, 4, 4))
@@ -238,7 +240,7 @@ class TestResampleDownsample:
assert out.shape == (1, 8, 2, 8, 8)
def test_upsample3d_still_works(self):
from mlx_video.models.wan.vae import Resample
from mlx_video.models.wan_2.vae import Resample
r = Resample(dim=16, mode="upsample3d")
x = mx.random.normal((1, 16, 2, 4, 4))
@@ -258,7 +260,9 @@ class TestI2VMaskConstruction:
# Build mask following reference logic
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
@@ -272,7 +276,9 @@ class TestI2VMaskConstruction:
t_latent = (num_frames - 1) // 4 + 1 # = 3
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0]
@@ -301,9 +307,9 @@ class TestI2VEndToEndPipeline:
def test_full_i2v_pipeline(self):
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.vae import WanVAE
mx.random.seed(0)
@@ -311,7 +317,9 @@ class TestI2VEndToEndPipeline:
config = _make_tiny_i2v_config()
config.vae_z_dim = 16
config.out_dim = 16 # must match VAE z_dim for decode
config.in_dim = 16 + 4 + 16 # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
config.in_dim = (
16 + 4 + 16
) # noise(out_dim=16) + mask(4) + image(z_dim=16) = 36
model = WanModel(config)
# --- Tiny VAE (with encoder) ---
@@ -323,10 +331,13 @@ class TestI2VEndToEndPipeline:
img = mx.random.uniform(-1, 1, (1, 3, 1, height, width))
# Build video: first frame = image, rest = zeros -> [1, 3, F, H, W]
video = mx.concatenate([
img,
mx.zeros((1, 3, num_frames - 1, height, width)),
], axis=2)
video = mx.concatenate(
[
img,
mx.zeros((1, 3, num_frames - 1, height, width)),
],
axis=2,
)
# --- VAE encode ---
z_video = vae.encode(video) # [1, z_dim, T_lat, H_lat, W_lat]
@@ -341,7 +352,9 @@ class TestI2VEndToEndPipeline:
# --- Build I2V mask (4 channels) ---
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
msk = mx.concatenate([mx.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
@@ -397,8 +410,8 @@ class TestDualModelSwitching:
def test_model_selection_by_timestep(self):
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(1)
config = _make_tiny_i2v_config()
@@ -453,7 +466,9 @@ class TestDualModelSwitching:
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
0
)
mx.eval(latents)
# With shift=5.0, early timesteps should be high (>=900), later ones low
@@ -461,17 +476,17 @@ class TestDualModelSwitching:
assert len(low_used_steps) > 0, "Low-noise model was never selected"
# High-noise steps should come before low-noise steps (timesteps decrease)
if high_used_steps and low_used_steps:
assert max(high_used_steps) < min(low_used_steps) or \
min(high_used_steps) < max(low_used_steps), \
"Model switching should happen during the loop"
assert max(high_used_steps) < min(low_used_steps) or min(
high_used_steps
) < max(low_used_steps), "Model switching should happen during the loop"
assert latents.shape == (C_noise, F, H, W)
assert not mx.any(mx.isnan(latents)).item()
def test_guide_scale_tuple_applied_per_model(self):
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(2)
config = _make_tiny_i2v_config()
@@ -515,7 +530,9 @@ class TestDualModelSwitching:
y=[y_i2v, y_i2v],
)
noise_pred = pred[1] + gs * (pred[0] - pred[1])
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(
0
)
mx.eval(latents)
# Verify both guide scales were used
@@ -528,8 +545,8 @@ class TestDualModelSwitching:
def test_single_model_fallback_with_tuple_guide_scale(self):
"""When dual_model=False, guide_scale tuple should use first element."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
mx.random.seed(3)
config = _make_tiny_config()
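
The mask-construction tests above repeat the same five-step recipe: mark the first frame, repeat it four times, then fold groups of four frames into four mask channels. The following NumPy sketch mirrors those steps line for line so the resulting `[4, T_lat, H_lat, W_lat]` layout can be inspected without MLX. `build_i2v_channel_mask` is a hypothetical helper name, and the sketch assumes `(num_frames - 1)` is divisible by 4.

```python
import numpy as np


def build_i2v_channel_mask(num_frames, h_latent, w_latent):
    """Build the 4-channel I2V mask, following the steps used in the tests."""
    # 1 on the first (conditioning) frame, 0 on every generated frame.
    msk = np.ones((1, num_frames, h_latent, w_latent), dtype=np.float32)
    msk[:, 1:] = 0.0
    # Repeat the first frame 4x so the frame count becomes a multiple of 4.
    msk = np.concatenate([np.repeat(msk[:, :1], 4, axis=1), msk[:, 1:]], axis=1)
    # Fold each group of 4 frames into 4 channels of one latent frame.
    msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
    return msk.transpose(0, 2, 1, 3, 4)[0]  # [4, T_lat, H_lat, W_lat]


num_frames, h_latent, w_latent = 9, 4, 4
t_latent = (num_frames - 1) // 4 + 1  # = 3
mask = build_i2v_channel_mask(num_frames, h_latent, w_latent)
```

With 9 input frames the first latent frame is all ones across its four channels and the remaining two latent frames are all zeros.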


@@ -4,7 +4,6 @@ import tempfile
from pathlib import Path
import mlx.core as mx
import numpy as np
import pytest
@@ -40,7 +39,9 @@ class TestLoRATypes:
lora_a = mx.ones((2, 4))
lora_b = mx.ones((8, 2))
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
applied = AppliedLoRA(weights=w, strength=0.5)
delta = applied.compute_delta()
# scale=1.0, strength=0.5, B@A = [[2,2,2,2]]*8 (each row sum of 2 ones)
@@ -51,7 +52,9 @@ class TestLoRATypes:
class TestLoRALoader:
"""Test LoRA weight loading from safetensors."""
def _make_lora_file(self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"):
def _make_lora_file(
self, tmp_dir, module_names, rank=4, in_dim=64, out_dim=128, key_format="AB"
):
"""Helper to create a mock LoRA safetensors file."""
weights = {}
for name in module_names:
@@ -133,8 +136,16 @@ class TestWanKeyNormalization:
"""Simulate typical Wan2.2 MLX model weight keys."""
keys = set()
for i in range(2):
for layer in ["self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
"cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o"]:
for layer in [
"self_attn.q",
"self_attn.k",
"self_attn.v",
"self_attn.o",
"cross_attn.q",
"cross_attn.k",
"cross_attn.v",
"cross_attn.o",
]:
keys.add(f"blocks.{i}.{layer}.weight")
keys.add(f"blocks.{i}.ffn.fc1.weight")
keys.add(f"blocks.{i}.ffn.fc2.weight")
@@ -150,7 +161,10 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("blocks.0.self_attn.q", keys) == "blocks.0.self_attn.q"
assert (
_normalize_wan_lora_key("blocks.0.self_attn.q", keys)
== "blocks.0.self_attn.q"
)
def test_strip_diffusion_model_prefix(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
@@ -163,7 +177,9 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
result = _normalize_wan_lora_key("model.diffusion_model.blocks.0.self_attn.k", keys)
result = _normalize_wan_lora_key(
"model.diffusion_model.blocks.0.self_attn.k", keys
)
assert result == "blocks.0.self_attn.k"
def test_ffn_key_mapping(self):
@@ -197,7 +213,9 @@ class TestWanKeyNormalization:
from mlx_video.lora.apply import _normalize_wan_lora_key
keys = self._wan_model_keys()
assert _normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
assert (
_normalize_wan_lora_key("patch_embedding", keys) == "patch_embedding_proj"
)
def test_combined_prefix_and_ffn(self):
from mlx_video.lora.apply import _normalize_wan_lora_key
@@ -219,7 +237,9 @@ class TestApplyLoRA:
# LoRA weights in float32 (typical when loaded from safetensors)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.bfloat16, f"Expected bfloat16, got {result.dtype}"
@@ -230,7 +250,9 @@ class TestApplyLoRA:
original = mx.ones((8, 4), dtype=mx.float16)
lora_a = mx.ones((2, 4), dtype=mx.float32) * 0.1
lora_b = mx.ones((8, 2), dtype=mx.float32) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
assert result.dtype == mx.float16, f"Expected float16, got {result.dtype}"
@@ -241,7 +263,9 @@ class TestApplyLoRA:
original = mx.ones((8, 4))
lora_a = mx.ones((2, 4)) * 0.1
lora_b = mx.ones((8, 2)) * 0.1
w = LoRAWeights(lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test")
w = LoRAWeights(
lora_A=lora_a, lora_B=lora_b, rank=2, alpha=2.0, module_name="test"
)
result = apply_lora_to_linear(original, [(w, 1.0)])
# delta = 1.0 * (B @ A) = ones(8,2)*0.1 @ ones(2,4)*0.1 = 0.02 * ones(8,4)
expected = original + 0.02 * mx.ones((8, 4))
@@ -255,12 +279,16 @@ class TestApplyLoRA:
w1 = LoRAWeights(
lora_A=mx.ones((2, 4)),
lora_B=mx.ones((8, 2)),
rank=2, alpha=2.0, module_name="a",
rank=2,
alpha=2.0,
module_name="a",
)
w2 = LoRAWeights(
lora_A=mx.ones((2, 4)) * 2,
lora_B=mx.ones((8, 2)) * 2,
rank=2, alpha=4.0, module_name="b",
rank=2,
alpha=4.0,
module_name="b",
)
result = apply_lora_to_linear(original, [(w1, 1.0), (w2, 0.5)])
# w1 delta: 1.0 * 1.0 * (ones(8,2) @ ones(2,4)) = 2 * ones(8,4)
@@ -282,7 +310,9 @@ class TestApplyLoRA:
w = LoRAWeights(
lora_A=mx.ones((4, 64)) * 0.01,
lora_B=mx.ones((128, 4)) * 0.01,
rank=4, alpha=4.0, module_name="blocks.0.self_attn.q",
rank=4,
alpha=4.0,
module_name="blocks.0.self_attn.q",
)
module_to_loras = {"blocks.0.self_attn.q": [(w, 1.0)]}
result = apply_loras_to_weights(model_weights, module_to_loras)
@@ -301,7 +331,7 @@ class TestEndToEnd:
"""End-to-end LoRA loading and application."""
def test_load_and_apply_loras(self):
from mlx_video.convert_wan import load_and_apply_loras
from mlx_video.models.wan_2.convert import load_and_apply_loras
with tempfile.TemporaryDirectory() as tmp:
# Create mock LoRA safetensors
@@ -319,9 +349,7 @@ class TestEndToEnd:
"blocks.0.self_attn.k.weight": mx.ones((128, 64)),
}
result = load_and_apply_loras(
model_weights, [(str(lora_path), 1.0)]
)
result = load_and_apply_loras(model_weights, [(str(lora_path), 1.0)])
# q weight should be modified, k unchanged
assert not mx.array_equal(
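
The LoRA tests above rely on a merged weight of the form `W + strength * scale * (B @ A)`, and their comments imply `scale = 1.0` for `alpha=2.0, rank=2`. The NumPy sketch below assumes the conventional `scale = alpha / rank`, which is consistent with those comments but is not confirmed by the diff. `apply_lora` is a hypothetical stand-in for `apply_lora_to_linear`, not its actual implementation.

```python
import numpy as np


def apply_lora(original, loras):
    """Merge LoRA deltas into a weight matrix and keep the original dtype.

    loras: list of (lora_a, lora_b, rank, alpha, strength) tuples.
    Assumes scale = alpha / rank (a common convention, not taken from the diff).
    """
    merged = original.astype(np.float32)
    for lora_a, lora_b, rank, alpha, strength in loras:
        scale = alpha / rank
        merged = merged + strength * scale * (lora_b @ lora_a)
    return merged.astype(original.dtype)


# Single LoRA, values from the tests: delta = 0.1 * 0.1 * rank(2) = 0.02.
single = apply_lora(
    np.ones((8, 4), dtype=np.float32),
    [(np.ones((2, 4)) * 0.1, np.ones((8, 2)) * 0.1, 2, 2.0, 1.0)],
)

# Two LoRAs: deltas of 2 (scale 1, strength 1) and 8 (scale 2, strength 0.5).
stacked = apply_lora(
    np.ones((8, 4), dtype=np.float32),
    [
        (np.ones((2, 4)), np.ones((8, 2)), 2, 2.0, 1.0),
        (np.ones((2, 4)) * 2, np.ones((8, 2)) * 2, 2, 4.0, 0.5),
    ],
)

# Half-precision weights keep their dtype after merging.
half = apply_lora(
    np.ones((8, 4), dtype=np.float16),
    [(np.ones((2, 4)) * 0.1, np.ones((8, 2)) * 0.1, 2, 2.0, 1.0)],
)
```

Accumulating in float32 and casting back at the end matches what the dtype-preservation tests check for, at least in outcome.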


@@ -3,18 +3,17 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Sinusoidal Embedding Tests
# ---------------------------------------------------------------------------
class TestSinusoidalEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.arange(10).astype(mx.float32)
emb = sinusoidal_embedding_1d(256, pos)
mx.eval(emb)
@@ -22,7 +21,8 @@ class TestSinusoidalEmbedding:
def test_position_zero(self):
"""Position 0 should have cos=1 for all dims and sin=0."""
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0])
emb = sinusoidal_embedding_1d(64, pos)
mx.eval(emb)
@@ -33,7 +33,8 @@ class TestSinusoidalEmbedding:
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
def test_different_positions_differ(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 999.0])
emb = sinusoidal_embedding_1d(128, pos)
mx.eval(emb)
@@ -46,9 +47,11 @@ class TestSinusoidalEmbedding:
# Head Tests
# ---------------------------------------------------------------------------
class TestHead:
def test_output_shape(self):
from mlx_video.models.wan.model import Head
from mlx_video.models.wan_2.wan_2 import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
B, L = 1, 24
x = mx.random.normal((B, L, 64))
@@ -59,7 +62,8 @@ class TestHead:
assert out.shape == (B, L, expected_proj_dim)
def test_modulation_shape(self):
from mlx_video.models.wan.model import Head
from mlx_video.models.wan_2.wan_2 import Head
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
assert head.modulation.shape == (1, 2, 64)
@@ -68,19 +72,22 @@ class TestHead:
# WanModel (Tiny) Tests
# ---------------------------------------------------------------------------
class TestWanModel:
def setup_method(self):
mx.random.seed(42)
def test_instantiation(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
num_params = sum(p.size for _, p in nn.utils.tree_flatten(model.parameters()))
assert num_params > 0
def test_patchify_shape(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
# Input: [C=4, F=1, H=4, W=4]
@@ -92,7 +99,8 @@ class TestWanModel:
assert patches.shape == (1, 1 * 2 * 2, config.dim)
def test_patchify_various_sizes(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
for f, h, w in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
@@ -107,7 +115,8 @@ class TestWanModel:
def test_unpatchify_inverse(self):
"""Patchify then unpatchify should reconstruct original spatial dims."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 2, 4, 6
@@ -122,7 +131,8 @@ class TestWanModel:
assert out[0].shape == (config.out_dim, F, H, W)
def test_forward_pass(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
@@ -139,7 +149,8 @@ class TestWanModel:
assert out[0].shape == (C, F, H, W)
def test_forward_batch(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
@@ -148,7 +159,10 @@ class TestWanModel:
x_list = [mx.random.normal((C, F, H, W)), mx.random.normal((C, F, H, W))]
t = mx.array([500.0, 200.0])
context = [mx.random.normal((6, config.text_dim)), mx.random.normal((4, config.text_dim))]
context = [
mx.random.normal((6, config.text_dim)),
mx.random.normal((4, config.text_dim)),
]
out = model(x_list, t, context, seq_len)
mx.eval(out[0], out[1])
@@ -157,13 +171,18 @@ class TestWanModel:
assert o.shape == (C, F, H, W)
def test_output_is_float32(self):
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
C, F, H, W = config.in_dim, 1, 4, 4
seq_len = (F // 1) * (H // 2) * (W // 2)
out = model([mx.random.normal((C, F, H, W))], mx.array([100.0]),
[mx.random.normal((4, config.text_dim))], seq_len)
out = model(
[mx.random.normal((C, F, H, W))],
mx.array([100.0]),
[mx.random.normal((4, config.text_dim))],
seq_len,
)
mx.eval(out[0])
assert out[0].dtype == mx.float32
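The `seq_len` expression used in these tests, `(F // 1) * (H // 2) * (W // 2)`, is the token count after patchifying the latent grid. A minimal sketch, assuming a patch size of `(1, 2, 2)` as implied by the tests; `token_count` is a hypothetical helper and not part of mlx_video:

```python
def token_count(frames, height, width, patch_size=(1, 2, 2)):
    """Number of tokens after patchifying an (F, H, W) latent grid."""
    pt, ph, pw = patch_size  # assumed temporal/height/width patch sizes
    return (frames // pt) * (height // ph) * (width // pw)


# The same (F, H, W) sizes exercised by test_patchify_various_sizes
for shape in [(1, 4, 4), (2, 6, 8), (3, 4, 6)]:
    print(shape, token_count(*shape))
```

For `(1, 4, 4)` this gives 4 tokens, which agrees with the `1 * 2 * 2` in the `test_patchify_shape` assertion.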
@@ -172,6 +191,7 @@ class TestWanModel:
# Wan2.1 Model Tests
# ---------------------------------------------------------------------------
class TestWan21Model:
"""Test tiny Wan2.1-style model (single model mode)."""
@@ -180,7 +200,8 @@ class TestWan21Model:
def _make_tiny_wan21_config(self):
"""Create a tiny config mimicking Wan2.1 (single model)."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_14b()
# Override to tiny values
config.dim = 64
@@ -196,7 +217,8 @@ class TestWan21Model:
def _make_tiny_wan21_1_3b_config(self):
"""Create a tiny config mimicking Wan2.1 1.3B."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
config = WanModelConfig.wan21_t2v_1_3b()
# Override to tiny values (preserve 1.3B head structure: 12 heads)
config.dim = 48
@@ -212,7 +234,7 @@ class TestWan21Model:
def test_wan21_tiny_model_forward(self):
"""Forward pass with Wan2.1 tiny config."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = self._make_tiny_wan21_config()
model = WanModel(config)
@@ -230,7 +252,7 @@ class TestWan21Model:
def test_wan21_1_3b_tiny_model_forward(self):
"""Forward pass with Wan2.1 1.3B tiny config."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = self._make_tiny_wan21_1_3b_config()
model = WanModel(config)
@@ -248,8 +270,8 @@ class TestWan21Model:
def test_wan21_single_model_loop(self):
"""Full diffusion loop with single model (Wan2.1 style)."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.wan_2 import WanModel
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
config = self._make_tiny_wan21_config()
model = WanModel(config)
@@ -271,7 +293,9 @@ class TestWan21Model:
for i in range(3):
t = sched.timesteps[i]
pred_cond = model([latents], mx.array([t.item()]), [context], seq_len)[0]
pred_uncond = model([latents], mx.array([t.item()]), [context_null], seq_len)[0]
pred_uncond = model(
[latents], mx.array([t.item()]), [context_null], seq_len
)[0]
pred = pred_uncond + gs * (pred_cond - pred_uncond)
latents = sched.step(pred[None], t, latents[None]).squeeze(0)
mx.eval(latents)
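The guidance line in the loop above, `pred = pred_uncond + gs * (pred_cond - pred_uncond)`, is the classifier-free guidance combination. A minimal sketch on plain Python lists, assuming element-wise arithmetic; `cfg_combine` and the sample values are hypothetical and only illustrate the formula:

```python
def cfg_combine(pred_cond, pred_uncond, guidance_scale):
    """Classifier-free guidance: uncond + scale * (cond - uncond), element-wise."""
    return [u + guidance_scale * (c - u) for c, u in zip(pred_cond, pred_uncond)]


cond = [1.0, 2.0, 3.0]
uncond = [0.5, 2.0, 4.0]

# Scale 1.0 reproduces the conditional prediction, scale 0.0 the unconditional one
same_as_cond = cfg_combine(cond, uncond, 1.0)
same_as_uncond = cfg_combine(cond, uncond, 0.0)

# A larger scale pushes the result further along (cond - uncond)
guided = cfg_combine(cond, uncond, 3.0)
print(guided)
```

Each denoising step therefore needs two forward passes, one per context, which is why the loop calls the model twice before the scheduler step.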
@@ -281,7 +305,7 @@ class TestWan21Model:
def test_wan21_vs_wan22_config_differences(self):
"""Verify key differences between Wan2.1 and Wan2.2 configs."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan_2.config import WanModelConfig
c21 = WanModelConfig.wan21_t2v_14b()
c22 = WanModelConfig.wan22_t2v_14b()
@@ -304,25 +328,26 @@ class TestWan21Model:
# Per-Token Timestep Tests
# ---------------------------------------------------------------------------
class TestPerTokenTimestep:
"""Tests for per-token sinusoidal embedding."""
def test_1d_unchanged(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([0.0, 100.0, 500.0])
emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (3, 256)
def test_2d_per_token(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
emb = sinusoidal_embedding_1d(256, pos)
assert emb.shape == (2, 3, 256)
def test_consistency(self):
from mlx_video.models.wan.model import sinusoidal_embedding_1d
from mlx_video.models.wan_2.wan_2 import sinusoidal_embedding_1d
pos_1d = mx.array([0.0, 100.0])
emb_1d = sinusoidal_embedding_1d(256, pos_1d)


@@ -1,67 +1,82 @@
"""Tests for Wan model quantization pipeline."""
import json
import mlx.core as mx
import mlx.nn as nn
import mlx.utils
import numpy as np
import pytest
from wan_test_helpers import _make_tiny_config
# ---------------------------------------------------------------------------
# Quantize Predicate Tests
# ---------------------------------------------------------------------------
class TestQuantizePredicate:
def test_matches_self_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.self_attn.{suffix}"
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_cross_attention_layers(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for suffix in ["q", "k", "v", "o"]:
path = f"blocks.0.cross_attn.{suffix}"
assert _quantize_predicate(path, mock_linear), f"Should match {path}"
def test_matches_ffn_layers(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
assert _quantize_predicate("blocks.0.ffn.fc1", mock_linear)
assert _quantize_predicate("blocks.0.ffn.fc2", mock_linear)
def test_rejects_embeddings(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
for path in ["patch_embedding_proj", "text_embedding_fc1", "time_embedding.fc1"]:
for path in [
"patch_embedding_proj",
"text_embedding_fc1",
"time_embedding.fc1",
]:
assert not _quantize_predicate(path, mock_linear), f"Should reject {path}"
def test_rejects_norms(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_norm = nn.RMSNorm(64)
assert not _quantize_predicate("blocks.0.self_attn.norm_q", mock_norm)
def test_rejects_non_quantizable_modules(self):
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_norm = nn.RMSNorm(64)
# Even if path matches, module must have to_quantized
assert not _quantize_predicate("blocks.0.self_attn.q", mock_norm)
def test_all_10_patterns_covered(self):
"""Verify exactly 10 layer patterns are targeted."""
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
mock_linear = nn.Linear(64, 64)
patterns = [
"blocks.0.self_attn.q", "blocks.0.self_attn.k",
"blocks.0.self_attn.v", "blocks.0.self_attn.o",
"blocks.0.cross_attn.q", "blocks.0.cross_attn.k",
"blocks.0.cross_attn.v", "blocks.0.cross_attn.o",
"blocks.0.ffn.fc1", "blocks.0.ffn.fc2",
"blocks.0.self_attn.q",
"blocks.0.self_attn.k",
"blocks.0.self_attn.v",
"blocks.0.self_attn.o",
"blocks.0.cross_attn.q",
"blocks.0.cross_attn.k",
"blocks.0.cross_attn.v",
"blocks.0.cross_attn.o",
"blocks.0.ffn.fc1",
"blocks.0.ffn.fc2",
]
matched = [p for p in patterns if _quantize_predicate(p, mock_linear)]
assert len(matched) == 10
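The behaviour these predicate tests pin down can be summarised as: accept only the ten attention and FFN projections inside `blocks.*`, and only if the module exposes `to_quantized`. The sketch below is a hypothetical re-implementation of that contract for illustration; it is not the repository's `_quantize_predicate`, and `FakeLinear`, `FakeNorm`, and `sketch_predicate` are invented names:

```python
# Hypothetical stand-in for a path-based quantization predicate.
QUANTIZED_SUFFIXES = (
    "self_attn.q", "self_attn.k", "self_attn.v", "self_attn.o",
    "cross_attn.q", "cross_attn.k", "cross_attn.v", "cross_attn.o",
    "ffn.fc1", "ffn.fc2",
)


class FakeLinear:
    """Stand-in for a layer that supports quantization."""

    def to_quantized(self, group_size=64, bits=4):
        return self


class FakeNorm:
    """Stand-in for a layer without a to_quantized method."""


def sketch_predicate(path, module):
    # Only modules inside transformer blocks are candidates.
    if not path.startswith("blocks."):
        return False
    # The module itself must be quantizable.
    if not hasattr(module, "to_quantized"):
        return False
    # The path must end in one of the ten targeted projections.
    return path.endswith(QUANTIZED_SUFFIXES)


print(sketch_predicate("blocks.0.ffn.fc1", FakeLinear()))     # accepted
print(sketch_predicate("time_embedding.fc1", FakeLinear()))   # rejected: not in a block
print(sketch_predicate("blocks.0.self_attn.q", FakeNorm()))   # rejected: not quantizable
```

A callable of this shape, taking `(path, module)` and returning a bool, is what `mlx.nn.quantize` accepts as its `class_predicate` argument.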
@@ -71,11 +86,12 @@ class TestQuantizePredicate:
# Quantize Round-Trip Tests
# ---------------------------------------------------------------------------
class TestQuantizeRoundTrip:
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
"""Helper: create model, quantize, save to tmp_path."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config)
nn.quantize(
@@ -100,9 +116,11 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 4, "group_size": 64},
)
@@ -118,9 +136,11 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, saved_weights = self._quantize_and_save(config, tmp_path, bits=8)
from mlx_video.models.wan.loading import load_wan_model
from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 8, "group_size": 64},
)
@@ -131,9 +151,11 @@ class TestQuantizeRoundTrip:
config = _make_tiny_config()
model_path, _ = self._quantize_and_save(config, tmp_path, bits=4)
from mlx_video.models.wan.loading import load_wan_model
from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(
model_path, config,
model_path,
config,
quantization={"bits": 4, "group_size": 64},
)
@@ -142,7 +164,7 @@ class TestQuantizeRoundTrip:
def test_loading_without_quantization_flag(self, tmp_path):
"""Loading a non-quantized model should have standard Linear layers."""
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -150,7 +172,8 @@ class TestQuantizeRoundTrip:
model_path = tmp_path / "model.safetensors"
mx.save_safetensors(str(model_path), weights_dict)
from mlx_video.models.wan.loading import load_wan_model
from mlx_video.models.wan_2.utils import load_wan_model
loaded = load_wan_model(model_path, config, quantization=None)
assert isinstance(loaded.blocks[0].self_attn.q, nn.Linear)
@@ -161,10 +184,11 @@ class TestQuantizeRoundTrip:
# Quantized Inference Tests
# ---------------------------------------------------------------------------
class TestQuantizedInference:
def _make_quantized_model(self, config, bits=4):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan_2.wan_2 import WanModel
model = WanModel(config)
nn.quantize(
@@ -214,8 +238,8 @@ class TestQuantizedInference:
def test_quantized_output_differs_from_unquantized(self):
"""Sanity check: quantization should change the weights."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_predicate
from mlx_video.models.wan_2.convert import _quantize_predicate
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
mx.random.seed(42)
@@ -243,11 +267,12 @@ class TestQuantizedInference:
# Config Metadata Tests
# ---------------------------------------------------------------------------
class TestQuantizationConfig:
def test_config_metadata_written(self, tmp_path):
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -270,8 +295,8 @@ class TestQuantizationConfig:
assert cfg["quantization"]["group_size"] == 64
def test_config_metadata_8bit(self, tmp_path):
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()
model = WanModel(config)
@@ -291,8 +316,8 @@ class TestQuantizationConfig:
def test_dual_model_quantization(self, tmp_path):
"""Verify dual-model quantization writes both model files."""
from mlx_video.models.wan.model import WanModel
from mlx_video.convert_wan import _quantize_saved_model
from mlx_video.models.wan_2.convert import _quantize_saved_model
from mlx_video.models.wan_2.wan_2 import WanModel
config = _make_tiny_config()


@@ -27,8 +27,8 @@ class TestRoPEFrequencyConstruction:
def _get_model_freqs(self, dim=64, num_heads=4):
"""Instantiate a tiny WanModel and return its .freqs tensor."""
from mlx_video.models.wan.config import WanModelConfig
from mlx_video.models.wan.model import WanModel
from mlx_video.models.wan_2.config import WanModelConfig
from mlx_video.models.wan_2.wan_2 import WanModel
config = WanModelConfig()
config.dim = dim
@@ -51,22 +51,27 @@ class TestRoPEFrequencyConstruction:
def test_three_call_vs_single_call_differ(self):
"""Three separate rope_params calls must differ from single call."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128 # head_dim for all Wan models
# Reference: three separate calls
correct = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
correct = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Wrong: single call
wrong = rope_params(1024, d)
mx.eval(correct, wrong)
assert correct.shape == wrong.shape
diff = np.abs(np.array(correct) - np.array(wrong)).max()
assert diff > 0.1, f"Three-call and single-call should differ significantly, got max diff {diff}"
assert (
diff > 0.1
), f"Three-call and single-call should differ significantly, got max diff {diff}"
def test_each_axis_starts_at_frequency_one(self):
"""Each axis (temporal/height/width) should have cos=1, sin=0 at position 0.
@@ -74,14 +79,17 @@ class TestRoPEFrequencyConstruction:
This verifies each axis gets its own independent frequency range
starting from theta^0 = 1.0 (i.e., exponent 0/dim).
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -95,29 +103,35 @@ class TestRoPEFrequencyConstruction:
# At position 1, each axis should have its FIRST frequency near cos(1/theta^0)=cos(1)
# Temporal axis first freq
np.testing.assert_allclose(f[1, 0, 0], np.cos(1.0), atol=1e-5,
err_msg="temporal[0] cos at pos 1")
np.testing.assert_allclose(
f[1, 0, 0], np.cos(1.0), atol=1e-5, err_msg="temporal[0] cos at pos 1"
)
# Height axis first freq (starts at index d_t)
np.testing.assert_allclose(f[1, d_t, 0], np.cos(1.0), atol=1e-5,
err_msg="height[0] cos at pos 1")
np.testing.assert_allclose(
f[1, d_t, 0], np.cos(1.0), atol=1e-5, err_msg="height[0] cos at pos 1"
)
# Width axis first freq (starts at index d_t + d_h)
np.testing.assert_allclose(f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5,
err_msg="width[0] cos at pos 1")
np.testing.assert_allclose(
f[1, d_t + d_h, 0], np.cos(1.0), atol=1e-5, err_msg="width[0] cos at pos 1"
)
def test_height_width_frequencies_identical(self):
"""Height and width axes should have identical frequency tables.
Both use rope_params(1024, 2*(d//6)) = rope_params(1024, 42).
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
d_h_dim = 2 * (d // 6) # 42
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, d_h_dim),
rope_params(1024, d_h_dim),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, d_h_dim),
rope_params(1024, d_h_dim),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -125,8 +139,8 @@ class TestRoPEFrequencyConstruction:
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
height_freqs = f[:, d_t:d_t + d_h]
width_freqs = f[:, d_t + d_h:]
height_freqs = f[:, d_t : d_t + d_h]
width_freqs = f[:, d_t + d_h :]
np.testing.assert_array_equal(height_freqs, width_freqs)
def test_frequency_range_per_axis(self):
@@ -136,14 +150,17 @@ class TestRoPEFrequencyConstruction:
axis should be 1.0 (theta^0). A single-call approach would give height
starting at ~0.04 and width at ~0.002 instead of 1.0.
"""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs)
f = np.array(freqs)
@@ -157,37 +174,46 @@ class TestRoPEFrequencyConstruction:
pos1_h = f[1, d_t, 0] # height first freq
pos1_w = f[1, d_t + d_h, 0] # width first freq
assert pos1_t > 0.5, f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
assert (
pos1_t > 0.5
), f"Temporal first freq at pos 1 should be >0.5, got {pos1_t}"
assert pos1_h > 0.5, f"Height first freq at pos 1 should be >0.5, got {pos1_h}"
assert pos1_w > 0.5, f"Width first freq at pos 1 should be >0.5, got {pos1_w}"
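The "~0.04" and "~0.002" figures in the docstring above can be reproduced without MLX. The sketch below assumes the inverse-frequency formula `theta ** (-2i / dim)` with `theta = 10000`, which is what the PyTorch reference in these tests computes; `inv_freqs` is a hypothetical helper:

```python
THETA = 10000.0


def inv_freqs(dim, theta=THETA):
    """Inverse frequencies theta^(-2i/dim) for i in range(dim // 2)."""
    return [theta ** (-(2 * i) / dim) for i in range(dim // 2)]


d = 128
d_t, d_h, d_w = d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)  # 44, 42, 42

# Three independent tables: each axis restarts at frequency 1.0
three_call = inv_freqs(d_t) + inv_freqs(d_h) + inv_freqs(d_w)
# One table spanning the whole head dimension
single_call = inv_freqs(d)

h_start = d_t // 2             # index of the first height frequency (22)
w_start = d_t // 2 + d_h // 2  # index of the first width frequency (43)

print(three_call[h_start], three_call[w_start])    # both 1.0
print(single_call[h_start], single_call[w_start])  # roughly 0.042 and 0.002
```

With a single table, the height and width axes inherit the low-frequency tail of the temporal axis, so neighbouring spatial positions receive nearly identical rotations.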
def test_model_freqs_match_manual_construction(self):
"""WanModel.freqs should match manually constructed three-call freqs."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
freqs_model, head_dim = self._get_model_freqs(dim=64, num_heads=4)
d = head_dim # 16
freqs_manual = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs_manual = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(freqs_model, freqs_manual)
np.testing.assert_array_equal(
np.array(freqs_model), np.array(freqs_manual),
err_msg="WanModel.freqs should use three-call construction"
np.array(freqs_model),
np.array(freqs_manual),
err_msg="WanModel.freqs should use three-call construction",
)
def test_model_freqs_14b_dimensions(self):
"""Verify freq dimensions for 14B-scale head_dim=128."""
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)), # dim=44 → 22 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
rope_params(1024, 2 * (d // 6)), # dim=42 → 21 freq pairs
],
axis=1,
)
mx.eval(freqs)
assert freqs.shape == (1024, 64, 2)
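The `(1024, 64, 2)` shape follows from how the head dimension is divided among the temporal, height, and width axes. A small sketch of that arithmetic, using the same expressions as the tests; `split_head_dim` is a hypothetical helper name:

```python
def split_head_dim(head_dim):
    """Return (temporal, height, width) rotary dims for a given head_dim."""
    d_hw = 2 * (head_dim // 6)   # height and width share the same size
    d_t = head_dim - 2 * d_hw    # temporal axis takes the remainder
    return d_t, d_hw, d_hw


for head_dim in (16, 128):
    dims = split_head_dim(head_dim)
    pairs = [x // 2 for x in dims]  # each (cos, sin) pair covers two channels
    print(head_dim, dims, pairs, sum(pairs))
```

The pair counts always sum to `head_dim // 2`, which is the middle axis of the frequency table; the trailing axis of size 2 holds the cosine and sine values.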
@@ -206,7 +232,8 @@ class TestRoPEFrequencyMatchesReference:
@pytest.fixture
def has_torch(self):
try:
import torch
pass
return True
except ImportError:
pytest.skip("PyTorch not installed")
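Note that with `pass` as the only statement in the `try` body, `ImportError` cannot be raised there, so the fixture returns `True` whether or not PyTorch is installed, and the dependent test would error on its own `import torch` instead of being skipped. One way to keep the skip behaviour without an unused import is to probe for the module; pytest also provides `pytest.importorskip("torch")` for this purpose. A minimal standard-library sketch, where `module_available` is a hypothetical helper:

```python
import importlib.util


def module_available(name):
    """Return True if the top-level module `name` can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


print(module_available("math"))
print(module_available("surely_not_installed_pkg_xyz"))
```

A fixture could call `pytest.skip(...)` when `module_available("torch")` is false, or the test could start with `torch = pytest.importorskip("torch")` and drop the fixture entirely.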
@@ -214,7 +241,8 @@ class TestRoPEFrequencyMatchesReference:
def test_freqs_match_pytorch_reference(self, has_torch):
"""Numerically compare MLX and PyTorch frequency tables."""
import torch
from mlx_video.models.wan.rope import rope_params
from mlx_video.models.wan_2.rope import rope_params
d = 128
@@ -222,22 +250,30 @@ class TestRoPEFrequencyMatchesReference:
def pt_rope_params(max_seq_len, dim, theta=10000):
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
1.0
/ torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
ref = torch.cat([
pt_rope_params(1024, d - 4 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
], dim=1)
ref = torch.cat(
[
pt_rope_params(1024, d - 4 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
pt_rope_params(1024, 2 * (d // 6)),
],
dim=1,
)
# MLX
ours = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
ours = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
mx.eval(ours)
our_cos = np.array(ours[:, :, 0])
@@ -245,10 +281,12 @@ class TestRoPEFrequencyMatchesReference:
ref_cos = ref.real.float().numpy()
ref_sin = ref.imag.float().numpy()
np.testing.assert_allclose(our_cos, ref_cos, atol=1e-6,
err_msg="cos mismatch vs PyTorch reference")
np.testing.assert_allclose(our_sin, ref_sin, atol=1e-6,
err_msg="sin mismatch vs PyTorch reference")
np.testing.assert_allclose(
our_cos, ref_cos, atol=1e-6, err_msg="cos mismatch vs PyTorch reference"
)
np.testing.assert_allclose(
our_sin, ref_sin, atol=1e-6, err_msg="sin mismatch vs PyTorch reference"
)
class TestRoPEApplyWithCorrectFreqs:
@@ -260,14 +298,17 @@ class TestRoPEApplyWithCorrectFreqs:
This is the key property that was broken by the single-call bug:
height/width frequencies were too low to distinguish nearby positions.
"""
from mlx_video.models.wan.rope import rope_params, rope_apply
from mlx_video.models.wan_2.rope import rope_apply, rope_params
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
B, N = 1, 4
F, H, W = 1, 4, 4
@@ -289,30 +330,37 @@ class TestRoPEApplyWithCorrectFreqs:
# Max diff should be >0.5 for both axes. With the bug, height was ~0.04
# and width was ~0.002. With correct freqs, both are ~1.3.
assert height_diff > 0.5, (
f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
)
assert width_diff > 0.5, (
f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
)
assert (
height_diff > 0.5
), f"Adjacent height positions should differ significantly, got {height_diff:.4f}"
assert (
width_diff > 0.5
), f"Adjacent width positions should differ significantly, got {width_diff:.4f}"
# Height and width should have identical frequency tables → same diffs
np.testing.assert_allclose(height_diff, width_diff, rtol=1e-5,
err_msg="Height and width should use identical frequency tables")
np.testing.assert_allclose(
height_diff,
width_diff,
rtol=1e-5,
err_msg="Height and width should use identical frequency tables",
)
def test_precomputed_matches_online(self):
"""rope_precompute_cos_sin + rope_apply should match non-precomputed path."""
from mlx_video.models.wan.rope import (
from mlx_video.models.wan_2.rope import (
rope_apply,
rope_params,
rope_precompute_cos_sin,
)
d = 128
freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
B, N = 2, 4
F, H, W = 2, 3, 4
@@ -329,6 +377,8 @@ class TestRoPEApplyWithCorrectFreqs:
mx.eval(out_online, out_precomp)
np.testing.assert_allclose(
np.array(out_online), np.array(out_precomp), atol=1e-5,
err_msg="Precomputed and online RoPE should match"
np.array(out_online),
np.array(out_precomp),
atol=1e-5,
err_msg="Precomputed and online RoPE should match",
)


@@ -6,21 +6,23 @@ import mlx.core as mx
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# Euler Scheduler Tests
# ---------------------------------------------------------------------------
class TestFlowMatchEulerScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
assert sched.num_train_timesteps == 1000
assert sched.timesteps is None
assert sched.sigmas is None
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -28,7 +30,8 @@ class TestFlowMatchEulerScheduler:
assert sched.sigmas.shape == (41,) # 40 steps + terminal
def test_timesteps_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(40, shift=12.0)
mx.eval(sched.timesteps)
@@ -37,7 +40,8 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(ts) < 0), f"Timesteps not decreasing: {ts[:5]}..."
def test_sigmas_decreasing(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=1.0)
mx.eval(sched.sigmas)
@@ -45,7 +49,8 @@ class TestFlowMatchEulerScheduler:
assert np.all(np.diff(sigmas) <= 0), "Sigmas not decreasing"
def test_terminal_sigma_is_zero(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(20, shift=5.0)
mx.eval(sched.sigmas)
@@ -53,7 +58,8 @@ class TestFlowMatchEulerScheduler:
def test_shift_effect(self):
"""Larger shift should push sigmas toward higher values."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched1 = FlowMatchEulerScheduler()
sched2 = FlowMatchEulerScheduler()
sched1.set_timesteps(20, shift=1.0)
@@ -64,7 +70,8 @@ class TestFlowMatchEulerScheduler:
assert mean2 > mean1, "Higher shift should push sigmas higher"
def test_step_euler(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(10, shift=1.0)
mx.eval(sched.sigmas)
@@ -82,11 +89,14 @@ class TestFlowMatchEulerScheduler:
# Euler: x_next = x + (sigma_next - sigma) * v
expected = 1.0 + (sigma_next - sigma) * 0.5
np.testing.assert_allclose(
np.array(result).flatten()[0], expected, rtol=1e-4,
np.array(result).flatten()[0],
expected,
rtol=1e-4,
)
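The update checked by `test_step_euler`, `x_next = x + (sigma_next - sigma) * v`, can be sketched on scalars. The linearly spaced schedule below is a hypothetical stand-in for the scheduler's sigmas, and `euler_step` is not the library's API:

```python
def euler_step(x, v, sigma, sigma_next):
    """One flow-matching Euler update: x_next = x + (sigma_next - sigma) * v."""
    return x + (sigma_next - sigma) * v


# Hypothetical linearly spaced schedule from 1.0 down to 0.0
num_steps = 10
sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]

# Same setup as the test: sample of 1.0, constant velocity of 0.5
x = euler_step(1.0, 0.5, sigmas[0], sigmas[1])
print(x)  # close to 0.95

# With zero velocity the sample never moves, whatever the schedule
y = 1.0
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    y = euler_step(y, 0.0, s, s_next)
print(y)
```

Because sigmas decrease, `sigma_next - sigma` is negative and each step moves the sample against the predicted velocity.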
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
assert sched._step_index == 0
@@ -98,7 +108,8 @@ class TestFlowMatchEulerScheduler:
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -110,7 +121,8 @@ class TestFlowMatchEulerScheduler:
@pytest.mark.parametrize("steps", [10, 20, 40, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(steps, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -119,7 +131,8 @@ class TestFlowMatchEulerScheduler:
def test_full_denoise_loop(self):
"""Run a complete denoise loop with zero velocity -> sample unchanged."""
from mlx_video.models.wan.scheduler import FlowMatchEulerScheduler
from mlx_video.models.wan_2.scheduler import FlowMatchEulerScheduler
sched = FlowMatchEulerScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -140,23 +153,27 @@ class TestComputeSigmas:
"""Tests for the shared _compute_sigmas helper."""
def test_length(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert len(sigmas) == 21 # num_steps + terminal
def test_terminal_zero(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
assert sigmas[-1] == 0.0
def test_starts_near_one(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
# Reference applies shift twice, so sigma[0] ≈ 0.99996 (not exactly 1.0)
np.testing.assert_allclose(sigmas[0], 1.0, atol=1e-3)
def test_decreasing(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(20, shift=5.0)
assert np.all(np.diff(sigmas) <= 0)
@@ -168,7 +185,8 @@ class TestComputeSigmas:
sigma_max/sigma_min come from the *unshifted* training schedule, and the
shift is applied only once (single-shift).
"""
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
steps, shift, N = 50, 5.0, 1000
sigmas = _compute_sigmas(steps, shift, N)
# Official single-shift: unshifted bounds, then shift once
@@ -182,7 +200,8 @@ class TestComputeSigmas:
np.testing.assert_allclose(sigmas, official, atol=1e-6)
def test_shift_one_is_near_linear(self):
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(10, shift=1.0)
# With shift=1, f(sigma)=sigma, but sigma_max = 0.999 (from alpha schedule)
# so schedule is nearly linear from ~0.999 to 0
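The properties these `_compute_sigmas` tests check (length `num_steps + 1`, terminal zero, monotone decrease, identity at `shift=1`) are consistent with the time shift commonly used in flow-matching schedulers, `shift * sigma / (1 + (shift - 1) * sigma)`. The sketch below is an assumption-laden illustration, not the repository's implementation: the bounds `sigma_max = 0.999` and `sigma_min = 0.0` are assumed, and `shift_sigma` and `sketch_sigmas` are hypothetical names:

```python
def shift_sigma(sigma, shift):
    """Time shift: shift * sigma / (1 + (shift - 1) * sigma)."""
    return shift * sigma / (1.0 + (shift - 1.0) * sigma)


def sketch_sigmas(num_steps, shift, sigma_max=0.999, sigma_min=0.0):
    # Linearly spaced unshifted values from sigma_max towards sigma_min,
    # excluding the endpoint (assumed bounds, see the note above).
    step = (sigma_max - sigma_min) / num_steps
    base = [sigma_max - i * step for i in range(num_steps)]
    # Apply the shift once, then append the terminal zero.
    return [shift_sigma(s, shift) for s in base] + [0.0]


linear = sketch_sigmas(20, shift=1.0)
shifted = sketch_sigmas(20, shift=5.0)
print(linear[0], shifted[0])
print(sum(linear) / len(linear), sum(shifted) / len(shifted))
```

With `shift=1` the function is the identity, so the schedule stays nearly linear; larger shifts raise every intermediate sigma, spending more steps at high noise levels.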
@@ -191,11 +210,12 @@ class TestComputeSigmas:
def test_all_schedulers_same_sigmas(self):
"""All three schedulers should produce identical sigma schedules."""
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = [
FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000),
@@ -209,11 +229,12 @@ class TestComputeSigmas:
np.testing.assert_allclose(np.array(s.sigmas), ref, atol=1e-6)
def test_all_schedulers_same_timesteps(self):
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
)
scheds = [
FlowMatchEulerScheduler(1000),
FlowDPMPP2MScheduler(1000),
@@ -234,13 +255,15 @@ class TestComputeSigmas:
class TestFlowDPMPP2MScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
assert sched.num_train_timesteps == 1000
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -248,7 +271,8 @@ class TestFlowDPMPP2MScheduler:
assert sched.sigmas.shape == (21,)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 4, 1, 2, 2))
@@ -260,7 +284,8 @@ class TestFlowDPMPP2MScheduler:
assert sched._step_index == 2
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -271,7 +296,8 @@ class TestFlowDPMPP2MScheduler:
def test_full_loop_finite(self):
"""Full loop with constant velocity should produce finite output."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -283,7 +309,8 @@ class TestFlowDPMPP2MScheduler:
def test_first_step_is_first_order(self):
"""First step should use 1st-order (no prev_x0 available)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 2, 4, 4))
@@ -297,7 +324,8 @@ class TestFlowDPMPP2MScheduler:
def test_second_step_uses_correction(self):
"""After first step, DPM++ should have stored prev_x0 for correction."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2))
@@ -314,11 +342,14 @@ class TestFlowDPMPP2MScheduler:
x0_after_second = sched._prev_x0
assert x0_after_second is not None
# The stored x0 should differ from the first step's
assert not np.allclose(np.array(x0_after_first), np.array(x0_after_second), atol=1e-6)
assert not np.allclose(
np.array(x0_after_first), np.array(x0_after_second), atol=1e-6
)
def test_denoise_to_target(self):
"""Perfect oracle should denoise to target with any solver."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4))
@@ -332,7 +363,8 @@ class TestFlowDPMPP2MScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -341,7 +373,8 @@ class TestFlowDPMPP2MScheduler:
def test_terminal_sigma_produces_x0(self):
"""When sigma_next=0 the scheduler should return x0 directly."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sched = FlowDPMPP2MScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1)) * 3.0
@@ -361,14 +394,16 @@ class TestFlowDPMPP2MScheduler:
class TestFlowUniPCScheduler:
def test_initialization(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched.num_train_timesteps == 1000
assert sched.solver_order == 2
assert sched.lower_order_final is True
def test_set_timesteps(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(30, shift=12.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -376,7 +411,8 @@ class TestFlowUniPCScheduler:
assert sched.sigmas.shape == (31,)
def test_step_index_increments(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -386,7 +422,8 @@ class TestFlowUniPCScheduler:
assert sched._step_index == 1
def test_reset(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 1, 1))
@@ -398,7 +435,8 @@ class TestFlowUniPCScheduler:
assert all(m is None for m in sched._model_outputs)
def test_full_loop_finite(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(10, shift=1.0)
sample = mx.ones((1, 2, 1, 2, 2))
@@ -410,7 +448,8 @@ class TestFlowUniPCScheduler:
def test_corrector_not_applied_first_step(self):
"""First step should skip the corrector (no history)."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 4, 1, 2, 2))
@@ -423,7 +462,8 @@ class TestFlowUniPCScheduler:
def test_corrector_applied_after_first_step(self):
"""Steps after the first should use the corrector when enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 4, 4))
@@ -435,7 +475,8 @@ class TestFlowUniPCScheduler:
assert sched._lower_order_nums >= 2
def test_denoise_to_target(self):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(20, shift=5.0)
target = mx.zeros((1, 2, 1, 4, 4))
@@ -449,7 +490,8 @@ class TestFlowUniPCScheduler:
@pytest.mark.parametrize("steps", [5, 10, 20, 50])
def test_various_step_counts(self, steps):
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
sched.set_timesteps(steps, shift=5.0)
mx.eval(sched.timesteps, sched.sigmas)
@@ -458,7 +500,8 @@ class TestFlowUniPCScheduler:
def test_disable_corrector(self):
"""Disabling corrector on step 0 should still work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(use_corrector=True, disable_corrector=[0])
sched.set_timesteps(5, shift=1.0)
sample = mx.ones((1, 1, 1, 2, 2))
@@ -470,7 +513,8 @@ class TestFlowUniPCScheduler:
def test_solver_order_3(self):
"""Order 3 should work without error."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler(solver_order=3, use_corrector=True)
sched.set_timesteps(10, shift=5.0)
sample = mx.random.normal((1, 2, 1, 2, 2))
@@ -483,10 +527,11 @@ class TestFlowUniPCScheduler:
def test_corrector_rhos_c_not_hardcoded(self):
"""Corrector rhos_c should be computed via linalg.solve, not hardcoded 0.5."""
import math
# For 50-step schedule with shift=5.0, order 2 corrector at step 5:
# rhos_c[0] (history) should be ~0.07, NOT 0.5
# rhos_c[1] (D1_t) should be ~0.45, NOT 0.5
from mlx_video.models.wan.scheduler import _compute_sigmas
from mlx_video.models.wan_2.scheduler import _compute_sigmas
sigmas = _compute_sigmas(50, shift=5.0)
@@ -525,16 +570,23 @@ class TestFlowUniPCScheduler:
rhos_c = np.linalg.solve(R, b)
# History weight should be small (~0.07-0.09), not 0.5
assert rhos_c[0] < 0.15, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
assert rhos_c[0] > 0.0, f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
assert (
rhos_c[0] < 0.15
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} too large"
assert (
rhos_c[0] > 0.0
), f"Step {step_idx}: rhos_c[0]={rhos_c[0]:.4f} should be positive"
# D1_t weight should be ~0.42-0.45, not 0.5
assert 0.3 < rhos_c[1] < 0.5, f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
assert (
0.3 < rhos_c[1] < 0.5
), f"Step {step_idx}: rhos_c[1]={rhos_c[1]:.4f} out of range"
# ---------------------------------------------------------------------------
# Scheduler Coherence Tests
# ---------------------------------------------------------------------------
class TestSchedulerCoherence:
"""Tests that Euler, DPM++, and UniPC schedulers produce coherent results.
@@ -545,7 +597,7 @@ class TestSchedulerCoherence:
@staticmethod
def _make_schedulers(steps=10, shift=5.0):
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowMatchEulerScheduler,
FlowUniPCScheduler,
@@ -599,11 +651,15 @@ class TestSchedulerCoherence:
results[name] = np.array(r)
np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5,
results["dpm++"],
results["euler"],
atol=1e-5,
err_msg="DPM++ step 0 should match Euler",
)
np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5,
results["unipc"],
results["euler"],
atol=1e-5,
err_msg="UniPC step 0 should match Euler",
)
@@ -621,11 +677,15 @@ class TestSchedulerCoherence:
unipc_r = scheds["unipc"].step(vel, scheds["unipc"].timesteps[0], noise)
mx.eval(euler_r, dpm_r, unipc_r)
np.testing.assert_allclose(
np.array(dpm_r), np.array(euler_r), atol=1e-5,
np.array(dpm_r),
np.array(euler_r),
atol=1e-5,
err_msg=f"DPM++ step 0 differs from Euler at shift={shift}",
)
np.testing.assert_allclose(
np.array(unipc_r), np.array(euler_r), atol=1e-5,
np.array(unipc_r),
np.array(euler_r),
atol=1e-5,
err_msg=f"UniPC step 0 differs from Euler at shift={shift}",
)
@@ -644,7 +704,9 @@ class TestSchedulerCoherence:
latents = sched.step(v, sched.timesteps[i], latents)
mx.eval(latents)
np.testing.assert_allclose(
np.array(latents), 0.0, atol=1e-3,
np.array(latents),
0.0,
atol=1e-3,
err_msg=f"{name} did not converge to target with oracle",
)
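
The oracle convergence check above can be illustrated with a plain NumPy sketch. It is not the repository's scheduler: `shifted_sigmas`, `euler_step`, and `denoise_with_oracle` are hypothetical names, and the shift formula is an assumption about how a flow-matching schedule is commonly warped. The point is only that, under the interpolation x = (1 - sigma) * x0 + sigma * noise, a perfect velocity oracle makes even a first-order Euler update land on the target once sigma reaches 0.

```python
import numpy as np


def shifted_sigmas(num_steps: int, shift: float) -> np.ndarray:
    """Assumed schedule: linear sigmas from 1 to 0, warped by the shift factor."""
    s = np.linspace(1.0, 0.0, num_steps + 1)
    return shift * s / (1.0 + (shift - 1.0) * s)


def euler_step(sample, velocity, sigma, sigma_next):
    """First-order flow-matching update: move along the velocity by the sigma delta."""
    return sample + (sigma_next - sigma) * velocity


def denoise_with_oracle(noise, target, num_steps=20, shift=5.0):
    """Run a full loop where the 'model' is an oracle that knows the target."""
    sigmas = shifted_sigmas(num_steps, shift)
    x = noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # With x = (1 - sigma) * x0 + sigma * eps, the true velocity eps - x0
        # equals (x - x0) / sigma. sigma is never 0 here because the last
        # entry of the schedule is only ever used as sigma_next.
        velocity = (x - target) / sigma
        x = euler_step(x, velocity, sigma, sigma_next)
    return x


noise = np.random.default_rng(0).standard_normal((1, 2, 1, 4, 4))
target = np.zeros((1, 2, 1, 4, 4))
result = denoise_with_oracle(noise, target)
```

Each step scales the residual `x - target` by `sigma_next / sigma`, so the final step (where `sigma_next` is 0) removes it entirely, which is why the test can use a tight tolerance for every solver.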
@@ -669,12 +731,12 @@ class TestSchedulerCoherence:
# Higher-order solvers should not be significantly worse than Euler
# (add small epsilon to handle near-zero errors from floating point noise)
eps = 1e-6
assert errors["dpm++"] <= errors["euler"] * 1.5 + eps, (
f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
)
assert errors["unipc"] <= errors["euler"] * 1.5 + eps, (
f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
)
assert (
errors["dpm++"] <= errors["euler"] * 1.5 + eps
), f"DPM++ error {errors['dpm++']:.6f} much worse than Euler {errors['euler']:.6f}"
assert (
errors["unipc"] <= errors["euler"] * 1.5 + eps
), f"UniPC error {errors['unipc']:.6f} much worse than Euler {errors['euler']:.6f}"
def test_multistep_trajectory_similar_magnitude(self):
"""Over a full denoising loop with constant velocity, all solvers
@@ -696,9 +758,9 @@ class TestSchedulerCoherence:
# All solvers should produce results within the same order of magnitude
vals = list(final_means.values())
ratio = max(vals) / max(min(vals), 1e-10)
assert ratio < 10.0, (
f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
)
assert (
ratio < 10.0
), f"Scheduler outputs diverge too much: {final_means}, ratio={ratio:.1f}"
def test_intermediate_values_finite(self):
"""Every intermediate latent value must be finite for all solvers."""
@@ -712,33 +774,33 @@ class TestSchedulerCoherence:
vel = mx.random.normal(shape)
latents = sched.step(vel, sched.timesteps[i], latents)
mx.eval(latents)
assert np.isfinite(np.array(latents)).all(), (
f"{name} produced non-finite values at step {i}"
)
assert np.isfinite(
np.array(latents)
).all(), f"{name} produced non-finite values at step {i}"
def test_lambda_boundary_values(self):
"""_lambda must return -inf at sigma=1.0 and +inf at sigma=0.0."""
from mlx_video.models.wan.scheduler import (
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
assert cls._lambda(1.0) == -math.inf, (
f"{cls.__name__}._lambda(1.0) should be -inf"
)
assert cls._lambda(0.0) == math.inf, (
f"{cls.__name__}._lambda(0.0) should be +inf"
)
assert (
cls._lambda(1.0) == -math.inf
), f"{cls.__name__}._lambda(1.0) should be -inf"
assert (
cls._lambda(0.0) == math.inf
), f"{cls.__name__}._lambda(0.0) should be +inf"
# Interior values should be finite
lam = cls._lambda(0.5)
assert math.isfinite(lam) and lam == 0.0, (
f"{cls.__name__}._lambda(0.5) should be 0.0"
)
assert (
math.isfinite(lam) and lam == 0.0
), f"{cls.__name__}._lambda(0.5) should be 0.0"
def test_lambda_monotonically_decreasing(self):
"""_lambda(sigma) should decrease as sigma increases (more noise → lower SNR)."""
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler
from mlx_video.models.wan_2.scheduler import FlowDPMPP2MScheduler
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
lambdas = [FlowDPMPP2MScheduler._lambda(s) for s in sigmas]
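
The `_lambda` assertions in this hunk are consistent with a half-log-SNR of the form log((1 - sigma) / sigma). The sketch below is a hypothetical stand-in, not the repository's implementation: `half_log_snr` is an illustrative name, and the closed form is an assumption inferred from the tested boundary values (-inf at 1.0, +inf at 0.0, 0.0 at 0.5).

```python
import math


def half_log_snr(sigma: float) -> float:
    """Hypothetical stand-in for `_lambda`: log(alpha / sigma) with alpha = 1 - sigma."""
    if sigma <= 0.0:
        # No noise left: signal-to-noise ratio is unbounded.
        return math.inf
    if sigma >= 1.0:
        # Pure noise: no signal at all.
        return -math.inf
    return math.log(1.0 - sigma) - math.log(sigma)


# Same sigma grid as the monotonicity test.
sigmas = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
lambdas = [half_log_snr(s) for s in sigmas]
```

Handling the two endpoints explicitly avoids calling `math.log(0.0)`, which raises `ValueError` in Python rather than returning an infinity.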
@@ -770,7 +832,9 @@ class TestSchedulerCoherence:
result = scheds[name].step(vel, scheds[name].timesteps[0], sample)
mx.eval(result)
np.testing.assert_allclose(
np.array(result), np.array(expected), atol=5e-4,
np.array(result),
np.array(expected),
atol=5e-4,
err_msg=f"{name} step 0 doesn't match DDIM formula (shift={shift})",
)
@@ -790,10 +854,14 @@ class TestSchedulerCoherence:
results[name] = np.array(r)
np.testing.assert_allclose(
results["dpm++"], results["euler"], atol=1e-5,
results["dpm++"],
results["euler"],
atol=1e-5,
)
np.testing.assert_allclose(
results["unipc"], results["euler"], atol=1e-5,
results["unipc"],
results["euler"],
atol=1e-5,
)
def test_dpmpp_unipc_agree_on_step1(self):
@@ -834,7 +902,10 @@ class TestSchedulerCoherence:
shape = (1, 2, 1, 2, 2)
noise = mx.random.normal(shape)
from mlx_video.models.wan.scheduler import FlowDPMPP2MScheduler, FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import (
FlowDPMPP2MScheduler,
FlowUniPCScheduler,
)
for cls in (FlowDPMPP2MScheduler, FlowUniPCScheduler):
sched = cls()
@@ -857,27 +928,34 @@ class TestSchedulerCoherence:
mx.eval(latents)
result2 = np.array(latents)
np.testing.assert_allclose(result1, result2, atol=1e-5,
err_msg=f"{cls.__name__} not reproducible after reset()")
np.testing.assert_allclose(
result1,
result2,
atol=1e-5,
err_msg=f"{cls.__name__} not reproducible after reset()",
)
# ---------------------------------------------------------------------------
# UniPC Corrector Default Tests
# ---------------------------------------------------------------------------
class TestUniPCCorrectorDefault:
"""Tests that the UniPC corrector is enabled by default,
matching official FlowUniPCMultistepScheduler behavior."""
def test_corrector_enabled_by_default(self):
"""Default construction should have corrector enabled."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
sched = FlowUniPCScheduler()
assert sched._use_corrector is True
def test_corrector_affects_output(self):
"""Corrector should produce different results than no corrector after step 1."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)
@@ -900,7 +978,8 @@ class TestUniPCCorrectorDefault:
def test_corrector_does_not_affect_first_step(self):
"""Step 0 should be identical regardless of corrector setting."""
from mlx_video.models.wan.scheduler import FlowUniPCScheduler
from mlx_video.models.wan_2.scheduler import FlowUniPCScheduler
mx.random.seed(42)
shape = (1, 4, 1, 4, 4)
noise = mx.random.normal(shape)


@@ -3,16 +3,16 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import pytest
# ---------------------------------------------------------------------------
# T5 Encoder Tests
# ---------------------------------------------------------------------------
class TestT5LayerNorm:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5LayerNorm
from mlx_video.models.wan_2.text_encoder import T5LayerNorm
norm = T5LayerNorm(64)
x = mx.random.normal((2, 10, 64))
out = norm(x)
@@ -21,7 +21,8 @@ class TestT5LayerNorm:
def test_rms_normalization(self):
"""After T5LayerNorm with weight=1, RMS should be ~1."""
from mlx_video.models.wan.text_encoder import T5LayerNorm
from mlx_video.models.wan_2.text_encoder import T5LayerNorm
norm = T5LayerNorm(128)
x = mx.random.normal((1, 5, 128)) * 5.0
out = norm(x)
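
The property exercised by `test_rms_normalization` can be reproduced with a small NumPy sketch. It assumes T5-style layer norm is plain RMS normalization (no mean subtraction, no bias); `rms_norm` and the `eps` default are illustrative and are not taken from `T5LayerNorm` itself.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Scale each feature vector to unit root-mean-square, then apply a learned weight."""
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    return weight * x / np.sqrt(mean_sq + eps)


# Mirror the test input: a large-variance tensor with unit weights.
x = np.random.default_rng(0).standard_normal((1, 5, 128)) * 5.0
out = rms_norm(x, np.ones(128))

# RMS over the feature axis of the normalized output.
rms = np.sqrt(np.mean(np.square(out), axis=-1))
```

Because the input is scaled by 5.0, its mean square is far larger than `eps`, so the output RMS sits very close to 1 regardless of the input scale.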
@@ -34,14 +35,16 @@ class TestT5LayerNorm:
class TestT5RelativeEmbedding:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(10, 10)
mx.eval(out)
assert out.shape == (1, 4, 10, 10) # [1, N, lq, lk]
def test_asymmetric_lengths(self):
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=4)
out = rel_emb(8, 12)
mx.eval(out)
@@ -49,7 +52,8 @@ class TestT5RelativeEmbedding:
def test_symmetry(self):
"""Position bias should have structure (not all zeros/random)."""
from mlx_video.models.wan.text_encoder import T5RelativeEmbedding
from mlx_video.models.wan_2.text_encoder import T5RelativeEmbedding
rel_emb = T5RelativeEmbedding(num_buckets=32, num_heads=2)
out = rel_emb(6, 6)
mx.eval(out)
@@ -63,7 +67,8 @@ class TestT5RelativeEmbedding:
class TestT5Attention:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
out = attn(x)
@@ -72,13 +77,15 @@ class TestT5Attention:
def test_no_scaling(self):
"""T5 attention famously has no sqrt(d) scaling. Verify structure."""
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
# No scale attribute (unlike standard attention)
assert not hasattr(attn, "scale")
def test_with_position_bias(self):
from mlx_video.models.wan.text_encoder import T5Attention, T5RelativeEmbedding
from mlx_video.models.wan_2.text_encoder import T5Attention, T5RelativeEmbedding
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
rel_emb = T5RelativeEmbedding(32, 4)
x = mx.random.normal((1, 10, 64))
@@ -88,7 +95,8 @@ class TestT5Attention:
assert out.shape == (1, 10, 64)
def test_with_mask(self):
from mlx_video.models.wan.text_encoder import T5Attention
from mlx_video.models.wan_2.text_encoder import T5Attention
attn = T5Attention(dim=64, dim_attn=64, num_heads=4)
x = mx.random.normal((1, 10, 64))
mask = mx.ones((1, 10))
@@ -100,7 +108,8 @@ class TestT5Attention:
class TestT5FeedForward:
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5FeedForward
from mlx_video.models.wan_2.text_encoder import T5FeedForward
ffn = T5FeedForward(64, 256)
x = mx.random.normal((1, 10, 64))
out = ffn(x)
@@ -109,7 +118,8 @@ class TestT5FeedForward:
def test_gated_structure(self):
"""T5 FFN is gated: gate(x) * fc1(x)."""
from mlx_video.models.wan.text_encoder import T5FeedForward
from mlx_video.models.wan_2.text_encoder import T5FeedForward
ffn = T5FeedForward(32, 64)
assert hasattr(ffn, "gate_proj")
assert hasattr(ffn, "fc1")
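
The gated structure named in the docstring, `gate(x) * fc1(x)`, can be sketched in NumPy as follows. This is an assumption-laden illustration: the tanh-approximated GELU on the gate branch, the absence of biases, and the `w_gate` / `w_fc1` / `w_fc2` parameter names are all hypothetical and are not confirmed by the diff.

```python
import numpy as np


def gelu_tanh(x: np.ndarray) -> np.ndarray:
    """Tanh approximation of GELU (assumed gate activation)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def gated_ffn(x, w_gate, w_fc1, w_fc2):
    """Gated feed-forward: activated gate branch multiplies a linear branch, then projects back."""
    hidden = gelu_tanh(x @ w_gate) * (x @ w_fc1)
    return hidden @ w_fc2


rng = np.random.default_rng(0)
dim, dim_ffn = 64, 256
x = rng.standard_normal((1, 10, dim))
w_gate = rng.standard_normal((dim, dim_ffn)) * 0.02
w_fc1 = rng.standard_normal((dim, dim_ffn)) * 0.02
w_fc2 = rng.standard_normal((dim_ffn, dim)) * 0.02

out = gated_ffn(x, w_gate, w_fc1, w_fc2)
```

With the gate weights set to zero the gate branch outputs zero everywhere, so the whole block returns zeros; that is one quick way to confirm the two branches are multiplied rather than added.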
@@ -121,10 +131,17 @@ class TestT5Encoder:
mx.random.seed(42)
def test_output_shape(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
ids = mx.array([[1, 5, 10, 0, 0]])
mask = mx.array([[1, 1, 1, 0, 0]])
@@ -133,39 +150,67 @@ class TestT5Encoder:
assert out.shape == (1, 5, 64)
def test_shared_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=True,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=True,
)
assert encoder.pos_embedding is not None
for block in encoder.blocks:
assert block.pos_embedding is None
def test_per_layer_pos(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
assert encoder.pos_embedding is None
for block in encoder.blocks:
assert block.pos_embedding is not None
def test_param_count(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
num_params = sum(p.size for _, p in nn.utils.tree_flatten(encoder.parameters()))
assert num_params > 0
def test_without_mask(self):
from mlx_video.models.wan.text_encoder import T5Encoder
from mlx_video.models.wan_2.text_encoder import T5Encoder
encoder = T5Encoder(
vocab_size=100, dim=64, dim_attn=64, dim_ffn=128,
num_heads=4, num_layers=2, num_buckets=32, shared_pos=False,
vocab_size=100,
dim=64,
dim_attn=64,
dim_ffn=128,
num_heads=4,
num_layers=2,
num_buckets=32,
shared_pos=False,
)
ids = mx.array([[1, 5, 10]])
out = encoder(ids)

Some files were not shown because too many files have changed in this diff