Refactor LTX-2 model structure

This commit is contained in:
Prince Canuma
2026-03-16 14:50:01 +01:00
parent decb3eb9e5
commit 3a0da19adb
50 changed files with 3882 additions and 3365 deletions

View File

@@ -0,0 +1,8 @@
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
)
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio

View File

@@ -0,0 +1,161 @@
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import get_timestep_embedding
class AdaLayerNormSingle(nn.Module):
    """Timestep-conditioned modulation head.

    The timestep (and, optionally, extra conditions) is embedded, passed
    through SiLU, and linearly projected to
    ``embedding_coefficient * embedding_dim`` modulation parameters.
    """

    def __init__(
        self,
        embedding_dim: int,
        embedding_coefficient: int = 6,
        use_additional_conditions: bool = False,
    ):
        """Create the embedder and the projection layer.

        Args:
            embedding_dim: Width of the timestep embedding.
            embedding_coefficient: Multiplier for the projected output width.
            use_additional_conditions: When True the embedder is built with a
                size-embedding width of ``embedding_dim // 3``; otherwise 0.
        """
        super().__init__()
        size_emb_dim = embedding_dim // 3 if use_additional_conditions else 0
        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim=embedding_dim,
            size_emb_dim=size_emb_dim,
            use_additional_conditions=use_additional_conditions,
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(
            embedding_dim,
            embedding_coefficient * embedding_dim,
            bias=True,
        )

    def __call__(
        self,
        timestep: mx.array,
        added_cond_kwargs: dict | None = None,
        batch_size: int | None = None,
        hidden_dtype: mx.Dtype | None = None,
    ) -> Tuple[mx.array, mx.array]:
        """Embed ``timestep`` and derive the modulation parameters.

        Args:
            timestep: Timestep values handed to the embedder.
            added_cond_kwargs: Extra keyword arguments forwarded to the
                embedder; ``None`` (or an empty dict) forwards nothing.
            batch_size: Forwarded to the embedder unchanged.
            hidden_dtype: Forwarded to the embedder unchanged.

        Returns:
            ``(scale_shift_params, embedded_timestep)`` where the first item
            is the linear projection of the SiLU-activated embedding.
        """
        extra_conditions = added_cond_kwargs if added_cond_kwargs else {}
        embedded_timestep = self.emb(
            timestep,
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
            **extra_conditions,
        )
        activated = self.silu(embedded_timestep)
        return self.linear(activated), embedded_timestep
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """Timestep embedding with an optional additive condition embedding."""

    def __init__(
        self,
        embedding_dim: int,
        size_emb_dim: int = 0,
        use_additional_conditions: bool = False,
        timestep_proj_dim: int = 256,
    ):
        """Create the timestep projection and embedding layers.

        Args:
            embedding_dim: Output width of the timestep embedding.
            size_emb_dim: Input width for the condition embedder; the embedder
                is only created when this is positive and
                ``use_additional_conditions`` is True.
            use_additional_conditions: Enables the condition embedder.
            timestep_proj_dim: Channel count of the sinusoidal projection.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.size_emb_dim = size_emb_dim
        self.use_additional_conditions = use_additional_conditions

        self.time_proj = Timesteps(
            timestep_proj_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.timestep_embedder = TimestepEmbedding(
            timestep_proj_dim,
            embedding_dim,
            out_dim=embedding_dim,
        )

        if use_additional_conditions and size_emb_dim > 0:
            self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)

    def __call__(
        self,
        timestep: mx.array,
        resolution: mx.array | None = None,
        aspect_ratio: mx.array | None = None,
        batch_size: int | None = None,
        hidden_dtype: mx.Dtype | None = None,
    ) -> mx.array:
        """Return the timestep embedding, plus conditions when available.

        Args:
            timestep: Timestep values to embed.
            resolution: Optional resolution condition.
            aspect_ratio: Optional aspect-ratio condition.
            batch_size: Accepted for interface compatibility; not used here.
            hidden_dtype: If given, the sinusoidal projection is cast to this
                dtype before the embedding MLP.

        Returns:
            The timestep embedding; the condition embedding is added only when
            conditions are enabled and both ``resolution`` and
            ``aspect_ratio`` are provided.
        """
        projected = self.time_proj(timestep)
        if hidden_dtype is not None:
            projected = projected.astype(hidden_dtype)
        embedding = self.timestep_embedder(projected)

        conditions_enabled = self.use_additional_conditions and self.size_emb_dim > 0
        if not conditions_enabled or resolution is None or aspect_ratio is None:
            return embedding

        return embedding + self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
class Timesteps(nn.Module):
    """Parameter-free wrapper around ``get_timestep_embedding``."""

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool = False,
        downscale_freq_shift: float = 1.0,
    ):
        """Store the settings forwarded to ``get_timestep_embedding``.

        Args:
            num_channels: Channel count passed as the embedding dimension.
            flip_sin_to_cos: Forwarded unchanged.
            downscale_freq_shift: Forwarded unchanged.
        """
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def __call__(self, timesteps: mx.array) -> mx.array:
        """Project ``timesteps`` using the stored settings."""
        embedding_options = {
            "flip_sin_to_cos": self.flip_sin_to_cos,
            "downscale_freq_shift": self.downscale_freq_shift,
        }
        return get_timestep_embedding(timesteps, self.num_channels, **embedding_options)
class TimestepEmbedding(nn.Module):
    """Two-layer MLP: linear -> activation -> linear."""

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int | None = None,
    ):
        """Create the two linear layers and the activation.

        Args:
            in_channels: Input width of the first linear layer.
            time_embed_dim: Hidden width between the two layers.
            act_fn: ``"silu"`` selects SiLU; any other value selects GELU.
            out_dim: Output width; a falsy value falls back to
                ``time_embed_dim``.
        """
        super().__init__()
        if not out_dim:
            out_dim = time_embed_dim

        self.linear1 = nn.Linear(in_channels, time_embed_dim)
        if act_fn == "silu":
            self.act = nn.SiLU()
        else:
            self.act = nn.GELU()
        self.linear2 = nn.Linear(time_embed_dim, out_dim)

    def __call__(self, sample: mx.array) -> mx.array:
        """Run ``sample`` through linear1, the activation, then linear2."""
        return self.linear2(self.act(self.linear1(sample)))
class ConditionEmbedding(nn.Module):
    """Sum of separate resolution and aspect-ratio embeddings."""

    def __init__(self, size_emb_dim: int, embedding_dim: int):
        """Create one ``TimestepEmbedding`` MLP per condition.

        Args:
            size_emb_dim: Input width of each condition MLP.
            embedding_dim: Hidden and output width of each condition MLP.
        """
        super().__init__()
        self.resolution_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
        self.aspect_ratio_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)

    def __call__(
        self,
        resolution: mx.array,
        aspect_ratio: mx.array,
        hidden_dtype: mx.Dtype | None = None,
    ) -> mx.array:
        """Embed both conditions and return their sum.

        Args:
            resolution: Input for the resolution embedder.
            aspect_ratio: Input for the aspect-ratio embedder.
            hidden_dtype: If given, both embeddings are cast to this dtype
                before being summed.
        """
        embedded = (
            self.resolution_embedder(resolution),
            self.aspect_ratio_embedder(aspect_ratio),
        )
        if hidden_dtype is not None:
            embedded = tuple(item.astype(hidden_dtype) for item in embedded)
        resolution_emb, aspect_ratio_emb = embedded
        return resolution_emb + aspect_ratio_emb

View File

@@ -0,0 +1,154 @@
"""Attention module for LTX-2."""
import math
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.config import LTXRopeType
from mlx_video.models.ltx_2.rope import apply_rotary_emb
def scaled_dot_product_attention(
    q: mx.array,
    k: mx.array,
    v: mx.array,
    heads: int,
    mask: Optional[mx.array] = None,
) -> mx.array:
    """Multi-head attention over packed ``(B, seq_len, heads * dim_head)`` inputs.

    Args:
        q: Queries of shape (B, q_seq_len, dim).
        k: Keys of shape (B, kv_seq_len, dim).
        v: Values of shape (B, kv_seq_len, dim).
        heads: Number of heads that ``dim`` is split into.
        mask: Optional mask. A 2-D mask gains leading batch and head axes; a
            3-D mask gains a head axis at position 1; other ranks pass through.

    Returns:
        Attention output of shape (B, q_seq_len, heads * dim_head).
    """
    batch, q_len, width = q.shape
    _, kv_len, _ = k.shape
    head_dim = width // heads

    def split_heads(packed: mx.array, length: int) -> mx.array:
        # (B, length, heads * head_dim) -> (B, heads, length, head_dim)
        unpacked = mx.reshape(packed, (batch, length, heads, head_dim))
        return mx.swapaxes(unpacked, 1, 2)

    q = split_heads(q, q_len)
    k = split_heads(k, kv_len)
    v = split_heads(v, kv_len)

    if mask is not None:
        if mask.ndim == 2:
            # No batch axis yet: prepend one.
            mask = mx.expand_dims(mask, axis=0)
        if mask.ndim == 3:
            # No heads axis yet: insert it after the batch axis.
            mask = mx.expand_dims(mask, axis=1)

    attended = mx.fast.scaled_dot_product_attention(
        q,
        k,
        v,
        scale=1.0 / math.sqrt(head_dim),
        mask=mask,
    )

    # (B, heads, q_len, head_dim) -> (B, q_len, heads * head_dim)
    merged = mx.swapaxes(attended, 1, 2)
    return mx.reshape(merged, (batch, q_len, heads * head_dim))
class Attention(nn.Module):
    """Multi-head attention with rotary position embeddings.

    Works as self-attention (no ``context``) or cross-attention, with
    RMS-normalized queries/keys and an optional per-head output gate.
    """

    def __init__(
        self,
        query_dim: int,
        context_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        norm_eps: float = 1e-6,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        has_gate_logits: bool = False,
    ):
        """Create projections, norms and the optional gate.

        Args:
            query_dim: Width of the query input and of the output.
            context_dim: Width of the key/value input; defaults to
                ``query_dim`` when None.
            heads: Number of attention heads.
            dim_head: Width of each head.
            norm_eps: Epsilon for the query/key RMSNorm layers.
            rope_type: Rotary embedding layout passed to ``apply_rotary_emb``.
            has_gate_logits: When True, adds the per-head gating projection
                (LTX-2.3).
        """
        super().__init__()
        self.rope_type = rope_type
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = dim_head * heads
        if context_dim is None:
            context_dim = query_dim

        # Input projections
        self.to_q = nn.Linear(query_dim, inner_dim, bias=True)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=True)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=True)

        # Normalization applied to the projected queries and keys
        self.q_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
        self.k_norm = nn.RMSNorm(inner_dim, eps=norm_eps)

        # Output projection back to the query width
        self.to_out = nn.Linear(inner_dim, query_dim, bias=True)

        # Optional per-head gate (LTX-2.3)
        if has_gate_logits:
            self.to_gate_logits = nn.Linear(query_dim, heads, bias=True)

    def __call__(
        self,
        x: mx.array,
        context: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        pe: Optional[Tuple[mx.array, mx.array]] = None,
        k_pe: Optional[Tuple[mx.array, mx.array]] = None,
        skip_attention: bool = False,
    ) -> mx.array:
        """Forward pass.

        Args:
            x: Query input of shape (B, seq_len, query_dim)
            context: Context for cross-attention. If None, uses x (self-attention)
            mask: Attention mask
            pe: Position embeddings for query (and key if k_pe is None)
            k_pe: Position embeddings for key (optional, uses pe if None)
            skip_attention: If True, bypass Q*K*V attention and use value projection
                only (for STG perturbation). Matches PyTorch all_perturbed=True.

        Returns:
            Attention output of shape (B, seq_len, query_dim)
        """
        # The gate is derived from the untouched input: 2 * sigmoid(logits),
        # shape (B, seq, heads).
        if hasattr(self, "to_gate_logits"):
            gate = 2.0 * mx.sigmoid(self.to_gate_logits(x))
        else:
            gate = None

        if context is None:
            context = x
        v = self.to_v(context)

        if skip_attention:
            # STG perturbation: the value projection stands in for attention.
            out = v
        else:
            q = self.q_norm(self.to_q(x))
            k = self.k_norm(self.to_k(context))

            if pe is not None:
                key_pe = pe if k_pe is None else k_pe
                q = apply_rotary_emb(q, pe, self.rope_type)
                k = apply_rotary_emb(k, key_pe, self.rope_type)

            out = scaled_dot_product_attention(q, k, v, self.heads, mask)

        if gate is not None:
            # Scale every head by its own gate value, then re-pack the heads.
            b, seq_len, _ = out.shape
            per_head = mx.reshape(out, (b, seq_len, self.heads, self.dim_head))
            out = mx.reshape(per_head * gate[..., None], (b, seq_len, -1))

        return self.to_out(out)

View File

@@ -0,0 +1,48 @@
"""Audio VAE module for LTX-2 audio generation."""
from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
from .causal_conv_2d import CausalConv2d, make_conv2d
from ..config import CausalityAxis
from .downsample import Downsample, build_downsampling_path
from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock
from .upsample import Upsample, build_upsampling_path
from .vocoder import Vocoder, load_vocoder
# Public names re-exported by this package, grouped by role.
__all__ = [
    # Main components
    "AudioEncoder",
    "AudioDecoder",
    "Vocoder",
    "load_vocoder",
    "decode_audio",
    # Audio processing
    "load_audio",
    "ensure_stereo",
    "waveform_to_mel",
    # Ops
    "AudioLatentShape",
    "AudioPatchifier",
    "PerChannelStatistics",
    # Building blocks
    "AttentionType",
    "AttnBlock",
    "make_attn",
    "CausalConv2d",
    "make_conv2d",
    "CausalityAxis",
    "Downsample",
    "build_downsampling_path",
    "NormType",
    "PixelNorm",
    "build_normalization_layer",
    "ResBlock1",
    "ResBlock2",
    "ResnetBlock",
    "LRELU_SLOPE",
    "Upsample",
    "build_upsampling_path",
]

View File

@@ -0,0 +1,108 @@
"""Attention blocks for audio VAE."""
from enum import Enum
import mlx.core as mx
import mlx.nn as nn
from .normalization import NormType, build_normalization_layer
class AttentionType(Enum):
    """Enum for specifying the attention mechanism type.

    ``make_attn`` in this module maps each member to a concrete module.
    """

    # Full softmax self-attention (built as ``AttnBlock`` by ``make_attn``).
    VANILLA = "vanilla"
    # Declared but rejected with ``NotImplementedError`` by ``make_attn``.
    LINEAR = "linear"
    # No attention; ``make_attn`` returns a pass-through module.
    NONE = "none"
class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection."""

    def __init__(
        self,
        in_channels: int,
        norm_type: NormType = NormType.GROUP,
    ) -> None:
        """Create the normalization layer and the four 1x1 convolutions.

        Args:
            in_channels: Channel count of the input (and of the output).
            norm_type: Normalization applied before the projections.
        """
        super().__init__()
        self.in_channels = in_channels
        self.norm = build_normalization_layer(in_channels, normtype=norm_type)

        def pointwise_conv() -> nn.Conv2d:
            # 1x1 convolution acting as a per-position linear projection.
            return nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def __call__(self, x: mx.array) -> mx.array:
        """Apply self-attention over all spatial positions.

        Args:
            x: Input tensor of shape (B, H, W, C) in MLX channels-last format

        Returns:
            ``x`` plus the projected attention output (same shape as ``x``).
        """
        normed = self.norm(x)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)

        batch, height, width, channels = q.shape
        positions = height * width

        # Flatten the spatial grid into a sequence: (B, H*W, C)
        q = q.reshape(batch, positions, channels)
        k = k.reshape(batch, positions, channels)
        v = v.reshape(batch, positions, channels)

        # Scores (B, HW, HW) = Q @ K^T scaled by 1/sqrt(C)
        scores = mx.matmul(q, k.transpose(0, 2, 1)) * (float(channels) ** (-0.5))
        weights = mx.softmax(scores, axis=-1)

        # Weighted sum of values, restored to the spatial layout (B, H, W, C)
        attended = mx.matmul(weights, v).reshape(batch, height, width, channels)

        return x + self.proj_out(attended)
class Identity(nn.Module):
    """Identity module that returns input unchanged.

    Returned by ``make_attn`` for ``AttentionType.NONE`` so callers can invoke
    the result like any other attention module.
    """

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``x`` as-is."""
        return x
def make_attn(
    in_channels: int,
    attn_type: AttentionType = AttentionType.VANILLA,
    norm_type: NormType = NormType.GROUP,
) -> nn.Module:
    """Build the attention module selected by ``attn_type``.

    Args:
        in_channels: Number of input channels
        attn_type: Type of attention mechanism
        norm_type: Type of normalization (used by the vanilla block only)

    Returns:
        ``AttnBlock`` for VANILLA, ``Identity`` for NONE.

    Raises:
        NotImplementedError: For ``AttentionType.LINEAR``.
        ValueError: For any other value.
    """
    if attn_type == AttentionType.VANILLA:
        return AttnBlock(in_channels, norm_type=norm_type)

    if attn_type == AttentionType.NONE:
        return Identity()

    if attn_type == AttentionType.LINEAR:
        raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")

    raise ValueError(f"Unknown attention type: {attn_type}")

View File

@@ -0,0 +1,135 @@
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
using librosa for macOS/MLX compatibility.
"""
from pathlib import Path
import numpy as np
import mlx.core as mx
def load_audio(
    path: str,
    target_sr: int = 16000,
    start_time: float = 0.0,
    max_duration: float | None = None,
    mono: bool = False,
) -> tuple[np.ndarray, int]:
    """Read an audio file with librosa and resample it to ``target_sr``.

    Args:
        path: Path to the audio file (any format ``librosa.load`` accepts).
        target_sr: Target sample rate (default 16000 Hz).
        start_time: Offset in seconds at which reading starts.
        max_duration: Maximum duration in seconds. None = read to end.
        mono: If True, down-mix to mono. Default False (preserve channels).

    Returns:
        (waveform, sample_rate) where waveform is a (channels, samples)
        float32 numpy array.
    """
    import librosa

    # ``mono`` is passed explicitly so multi-channel files keep their channels
    # unless the caller asks for a down-mix.
    samples, sample_rate = librosa.load(
        path,
        sr=target_sr,
        mono=mono,
        offset=start_time,
        duration=max_duration,
    )

    # A 1-D result is promoted to (1, samples) so the output is always 2-D.
    if samples.ndim == 1:
        samples = samples[np.newaxis, :]

    return samples.astype(np.float32), sample_rate
def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
    """Return ``waveform`` as a two-channel array where possible.

    A 1-D input is treated as a single channel. One channel is duplicated to
    two; more than two channels are truncated to the first two; anything else
    is returned unchanged.
    """
    if waveform.ndim == 1:
        waveform = waveform[np.newaxis, :]

    channel_count = waveform.shape[0]
    if channel_count == 1:
        return np.concatenate([waveform, waveform], axis=0)
    if channel_count > 2:
        return waveform[:2]
    return waveform
def waveform_to_mel(
    waveform: np.ndarray,
    sample_rate: int = 16000,
    n_fft: int = 1024,
    hop_length: int = 160,
    win_length: int = 1024,
    n_mels: int = 64,
    fmin: float = 0.0,
    fmax: float = 8000.0,
) -> mx.array:
    """Convert waveform to log-mel spectrogram matching PyTorch MelSpectrogram.

    PyTorch reference:
        MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
                       f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
                       mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")

    Args:
        waveform: (channels, samples) float32 numpy array; a 1-D array is
            treated as a single channel.
        sample_rate: Sample rate of the waveform.
        n_fft: FFT size.
        hop_length: Hop length.
        win_length: Window length.
        n_mels: Number of mel bins.
        fmin: Minimum frequency for mel filterbank.
        fmax: Maximum frequency for mel filterbank.

    Returns:
        Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
    """
    import librosa

    # Ensure 2D: (channels, samples)
    if waveform.ndim == 1:
        waveform = waveform[np.newaxis, :]

    # The mel filterbank depends only on the analysis parameters, not on the
    # audio, so it is built once and shared by every channel.
    mel_basis = librosa.filters.mel(
        sr=sample_rate,
        n_fft=n_fft,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
        norm="slaney",
    )

    mels = []
    for channel in waveform:
        # Magnitude spectrogram (power=1.0)
        magnitude = np.abs(
            librosa.stft(
                channel,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                center=True,
                pad_mode="reflect",
            )
        )
        # Clamp before the log so silent bins map to log(1e-5) instead of -inf.
        log_mel = np.log(np.clip(mel_basis @ magnitude, a_min=1e-5, a_max=None))
        # (n_mels, time) -> (time, n_mels)
        mels.append(log_mel.T)

    # Stack channels, then add the batch axis: (1, channels, time, n_mels)
    mel_spec = np.stack(mels, axis=0)[np.newaxis, ...]
    return mx.array(mel_spec, dtype=mx.float32)

View File

@@ -0,0 +1,532 @@
"""Audio VAE encoder and decoder for LTX-2."""
from typing import Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock
from .upsample import build_upsampling_path
# Ratio between spectrogram frames and latent frames: passed to AudioPatchifier
# as audio_latent_downsample_factor and used by the decoder to compute the
# target frame count from the latent frame count.
LATENT_DOWNSAMPLE_FACTOR = 4
def build_mid_block(
    channels: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    add_attention: bool,
) -> dict:
    """Assemble the mid block: ResNet block, optional attention, ResNet block.

    Returns:
        A dict with keys ``"block_1"``, ``"attn_1"`` and ``"block_2"``;
        ``"attn_1"`` is ``None`` when ``add_attention`` is False.
    """

    def resnet_block() -> ResnetBlock:
        # Both ResNet blocks share the same channel-preserving configuration.
        return ResnetBlock(
            in_channels=channels,
            out_channels=channels,
            temb_channels=temb_channels,
            dropout=dropout,
            norm_type=norm_type,
            causality_axis=causality_axis,
        )

    # Dict entries are evaluated top to bottom, so layers are constructed in
    # the same order as they are applied.
    return {
        "block_1": resnet_block(),
        "attn_1": (
            make_attn(channels, attn_type=attn_type, norm_type=norm_type)
            if add_attention
            else None
        ),
        "block_2": resnet_block(),
    }
def run_mid_block(mid: dict, features: mx.array) -> mx.array:
    """Apply ``block_1``, the optional ``attn_1``, then ``block_2``.

    Args:
        mid: Mapping with callables under ``"block_1"`` and ``"block_2"`` and
            either a callable or ``None`` under ``"attn_1"``.
        features: Input passed to the first block.

    Returns:
        The output of ``block_2``.
    """
    hidden = mid["block_1"](features, temb=None)

    attention = mid["attn_1"]
    if attention is not None:
        hidden = attention(hidden)

    return mid["block_2"](hidden, temb=None)
class AudioEncoder(nn.Module):
    """Encoder that compresses audio spectrograms into latent representations.

    The forward pass applies an input convolution, the downsampling stages,
    the mid block and a norm/SiLU/convolution head, then normalizes the
    latent means with the stored per-channel statistics.
    """

    def __init__(self, config: AudioEncoderModelConfig) -> None:
        """Build the encoder layers described by ``config``."""
        super().__init__()
        # Sized with the base channel count ``config.ch``.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
        self.sample_rate = config.sample_rate
        self.mel_hop_length = config.mel_hop_length
        self.is_causal = config.is_causal
        self.mel_bins = config.mel_bins
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=config.sample_rate,
            hop_length=config.mel_hop_length,
            is_causal=config.is_causal,
        )
        self.ch = config.ch
        # No timestep embedding is used; every block is called with temb=None.
        self.temb_ch = 0
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.in_channels = config.in_channels
        self.z_channels = config.z_channels
        self.double_z = config.double_z
        self.norm_type = config.norm_type
        self.causality_axis = config.causality_axis
        self.attn_type = config.attn_type

        self.conv_in = make_conv2d(
            config.in_channels,
            self.ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )
        self.down, block_in = build_downsampling_path(
            ch=config.ch,
            ch_mult=config.ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=config.num_res_blocks,
            resolution=config.resolution,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=config.attn_resolutions or set(),
            resamp_with_conv=config.resamp_with_conv,
        )
        self.mid = build_mid_block(
            channels=block_in,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=config.mid_block_add_attention,
        )
        self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
        # With double_z the head emits twice the latent channels; only the
        # first z_channels are consumed by _normalize_latents.
        out_channels = 2 * config.z_channels if config.double_z else config.z_channels
        self.conv_out = make_conv2d(
            block_in,
            out_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Sanitize audio encoder weights from PyTorch format.

        Keeps encoder weights (stripping the ``audio_vae.encoder.`` or
        ``encoder.`` prefix) and the per-channel mean/std statistics (renamed
        to ``per_channel_statistics.*``); every other key is dropped.
        4-D conv weights are transposed to channels-last unless
        ``check_array_shape`` reports they already are.

        Args:
            weights: Dictionary of weights with PyTorch naming.

        Returns:
            Dictionary with MLX-compatible naming for the audio encoder.
        """
        sanitized = {}
        for key, value in weights.items():
            if key.startswith("audio_vae.encoder."):
                # Strip only the leading prefix, never a later occurrence.
                new_key = key.removeprefix("audio_vae.encoder.")
            elif key.startswith("encoder."):
                new_key = key.removeprefix("encoder.")
            elif key.startswith("audio_vae.per_channel_statistics."):
                if "mean-of-means" in key:
                    new_key = "per_channel_statistics.mean_of_means"
                elif "std-of-means" in key:
                    new_key = "per_channel_statistics.std_of_means"
                else:
                    continue  # Skip other statistics keys
            elif "per_channel_statistics" in key:
                if "mean-of-means" in key or "latents_mean" in key:
                    new_key = "per_channel_statistics.mean_of_means"
                elif "std-of-means" in key or "latents_std" in key:
                    new_key = "per_channel_statistics.std_of_means"
                else:
                    continue
            elif key == "latents_mean":
                new_key = "per_channel_statistics.mean_of_means"
            elif key == "latents_std":
                new_key = "per_channel_statistics.std_of_means"
            else:
                continue  # Skip non-encoder keys

            # PyTorch conv weights are (out, in, H, W); MLX wants (out, H, W, in).
            # NOTE(review): only keys containing "conv" are converted — confirm
            # that no other 4-D weights (e.g. attention 1x1 convs) need it.
            if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
                value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
            sanitized[new_key] = value
        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
        """Load audio encoder from pretrained weights.

        Args:
            model_path: Directory containing ``config.json`` and
                ``model.safetensors`` (a ``str`` is accepted as well).

        Returns:
            An encoder with the weights loaded strictly (key mismatches raise).
        """
        import json

        model_path = Path(model_path)
        # Context manager so the config file handle is closed deterministically.
        with open(model_path / "config.json", encoding="utf-8") as config_file:
            config = AudioEncoderModelConfig.from_dict(json.load(config_file))

        encoder = cls(config)
        weights = mx.load(str(model_path / "model.safetensors"))
        encoder.load_weights(list(weights.items()), strict=True)
        return encoder

    def __call__(self, spectrogram: mx.array) -> mx.array:
        """Encode audio spectrogram into normalized latent representation.

        Args:
            spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.

        Returns:
            Normalized latent (B, T', F', z_channels) in MLX channels-last format.
        """
        # Layout heuristic: a 4-D input whose axis 1 equals in_channels is
        # treated as channels-first and moved to channels-last.
        if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels:
            spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))

        h = self.conv_in(spectrogram)
        h = self._run_downsampling_path(h)
        h = run_mid_block(self.mid, h)
        h = self._finalize_output(h)
        return self._normalize_latents(h)

    def _run_downsampling_path(self, h: mx.array) -> mx.array:
        """Run the ResNet/attention blocks of every level, downsampling between levels."""
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx in range(self.num_res_blocks):
                h = stage["block"][block_idx](h, temb=None)
                if block_idx in stage["attn"]:
                    h = stage["attn"][block_idx](h)
            # The final level keeps its resolution.
            if level != last_level and "downsample" in stage:
                h = stage["downsample"](h)
        return h

    def _finalize_output(self, h: mx.array) -> mx.array:
        """Apply the output normalization, SiLU and output convolution."""
        h = self.norm_out(h)
        h = nn.silu(h)
        return self.conv_out(h)

    def _normalize_latents(self, h: mx.array) -> mx.array:
        """Normalize encoder output using per-channel statistics.

        Takes the first ``z_channels`` channels (the means when
        ``double_z=True``), then patchifies, normalizes, and unpatchifies.
        """
        # h shape: (B, T', F', C) in MLX format; C is 2*z_channels with double_z.
        means = h[..., : self.z_channels]
        latent_shape = AudioLatentShape(
            batch=means.shape[0],
            channels=means.shape[3],
            frames=means.shape[1],
            mel_bins=means.shape[2],
        )
        patched = self.patchifier.patchify(means)
        normalized = self.per_channel_statistics.normalize(patched)
        return self.patchifier.unpatchify(normalized, latent_shape)
class AudioDecoder(nn.Module):
    """
    Symmetric decoder that reconstructs audio spectrograms from latent features.

    The decoder mirrors the encoder structure with configurable channel multipliers,
    attention resolutions, and causal convolutions.
    """

    def __init__(
        self,
        config: AudioDecoderModelConfig,
    ) -> None:
        """
        Initialize the AudioDecoder.

        Args:
            config: Decoder configuration. The fields read here are ``ch``,
                ``out_ch``, ``ch_mult``, ``num_res_blocks``,
                ``attn_resolutions``, ``resolution``, ``z_channels``,
                ``norm_type``, ``causality_axis``, ``attn_type``, ``dropout``,
                ``mid_block_add_attention``, ``resamp_with_conv``,
                ``give_pre_end``, ``tanh_out``, ``sample_rate``,
                ``mel_hop_length``, ``is_causal`` and ``mel_bins``.
        """
        super().__init__()

        # Per-channel statistics for denormalizing latents.
        # Uses ch (base channel count) to match the patchified latent dimension:
        # Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
        # After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
        # ch=128 matches this dimension, so use ch for per_channel_statistics.
        self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
        self.sample_rate = config.sample_rate
        self.mel_hop_length = config.mel_hop_length
        self.is_causal = config.is_causal
        self.mel_bins = config.mel_bins
        self.patchifier = AudioPatchifier(
            patch_size=1,
            audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
            sample_rate=config.sample_rate,
            hop_length=config.mel_hop_length,
            is_causal=config.is_causal,
        )
        self.ch = config.ch
        # No timestep embedding is used; every block is called with temb=None.
        self.temb_ch = 0
        self.num_resolutions = len(config.ch_mult)
        self.num_res_blocks = config.num_res_blocks
        self.resolution = config.resolution
        self.out_ch = config.out_ch
        self.give_pre_end = config.give_pre_end
        self.tanh_out = config.tanh_out
        self.norm_type = config.norm_type
        self.z_channels = config.z_channels
        self.channel_multipliers = config.ch_mult
        self.attn_resolutions = config.attn_resolutions
        self.causality_axis = config.causality_axis
        self.attn_type = config.attn_type

        # Channel count and resolution at the coarsest (latent) level.
        base_block_channels = config.ch * self.channel_multipliers[-1]
        base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, config.z_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            config.z_channels,
            base_block_channels,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )
        self.mid = build_mid_block(
            channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            add_attention=config.mid_block_add_attention,
        )
        self.up, final_block_channels = build_upsampling_path(
            ch=config.ch,
            ch_mult=config.ch_mult,
            num_resolutions=self.num_resolutions,
            num_res_blocks=config.num_res_blocks,
            resolution=config.resolution,
            temb_channels=self.temb_ch,
            dropout=config.dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
            attn_type=self.attn_type,
            attn_resolutions=config.attn_resolutions,
            resamp_with_conv=config.resamp_with_conv,
            initial_block_channels=base_block_channels,
        )
        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels,
            config.out_ch,
            kernel_size=3,
            stride=1,
            causality_axis=self.causality_axis,
        )

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Sanitize audio VAE weight names from PyTorch format to MLX format.

        Keeps ``audio_vae.decoder.*`` weights (prefix stripped) and the
        ``audio_vae.per_channel_statistics.*`` mean/std entries; every other
        key is dropped.

        Args:
            weights: Dictionary of weights with PyTorch naming

        Returns:
            Dictionary with MLX-compatible naming for audio VAE decoder
        """
        sanitized = {}
        for key, value in weights.items():
            if key.startswith("audio_vae.decoder."):
                # Strip only the leading prefix, never a later occurrence.
                new_key = key.removeprefix("audio_vae.decoder.")
            elif key.startswith("audio_vae.per_channel_statistics."):
                if "mean-of-means" in key:
                    new_key = "per_channel_statistics.mean_of_means"
                elif "std-of-means" in key:
                    new_key = "per_channel_statistics.std_of_means"
                else:
                    continue  # Skip other statistics keys
            else:
                continue  # Skip non-decoder keys

            # Handle Conv2d weight shape conversion
            # PyTorch: (out_channels, in_channels, H, W)
            # MLX: (out_channels, H, W, in_channels)
            if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
                value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
            sanitized[new_key] = value
        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
        """Load audio VAE decoder from pretrained model.

        Args:
            model_path: Directory containing ``config.json`` and
                ``model.safetensors`` (a ``str`` is accepted as well).

        Returns:
            A decoder with the weights loaded strictly (key mismatches raise).
        """
        import json

        # Coerce so that plain string paths work with the ``/`` operator.
        model_path = Path(model_path)
        # Context manager so the config file handle is closed deterministically.
        with open(model_path / "config.json", encoding="utf-8") as config_file:
            config = AudioDecoderModelConfig.from_dict(json.load(config_file))

        decoder = cls(config)
        weights = mx.load(str(model_path / "model.safetensors"))
        # NOTE(review): weights are loaded without sanitize(), so the file is
        # presumably already in MLX naming/layout — confirm with the converter.
        decoder.load_weights(list(weights.items()), strict=True)
        return decoder

    def __call__(self, sample: mx.array) -> mx.array:
        """
        Decode latent features back to audio spectrograms.

        Args:
            sample: Encoded latent representation of shape (B, H, W, C) in MLX format
                or (B, C, H, W) in PyTorch format (will be transposed)

        Returns:
            Reconstructed audio spectrogram in (B, C, H, W) layout.
        """
        # Layout heuristic: a 4-D input whose axis 1 equals z_channels is
        # treated as channels-first. Rank is checked first so the shape lookup
        # cannot fail on low-rank inputs.
        if sample.ndim == 4 and sample.shape[1] == self.z_channels:
            # PyTorch format (B, C, H, W) -> MLX format (B, H, W, C)
            sample = mx.transpose(sample, (0, 2, 3, 1))

        sample, target_shape = self._denormalize_latents(sample)
        h = self.conv_in(sample)
        h = run_mid_block(self.mid, h)
        h = self._run_upsampling_path(h)
        h = self._finalize_output(h)
        return self._adjust_output_shape(h, target_shape)

    def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
        """Denormalize latents using per-channel statistics.

        Returns:
            The denormalized latent and the shape the decoded spectrogram
            should be cropped/padded to.
        """
        # sample shape: (B, H, W, C) in MLX format
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[3],  # channels last
            frames=sample.shape[1],  # height = frames
            mel_bins=sample.shape[2],  # width = mel_bins
        )
        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
        if self.causality_axis != CausalityAxis.NONE:
            # Causal decoding yields (factor - 1) fewer frames; never go below 1.
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )
        return sample, target_shape

    def _adjust_output_shape(
        self,
        decoded_output: mx.array,
        target_shape: AudioLatentShape,
    ) -> mx.array:
        """
        Adjust output shape to match target dimensions for variable-length audio.

        Args:
            decoded_output: Tensor of shape (B, H, W, C) in MLX format
            target_shape: AudioLatentShape describing target dimensions

        Returns:
            Tensor adjusted to match target_shape exactly, transposed to
            (B, C, H, W).
        """
        # Current output shape: (batch, frames, mel_bins, channels) in MLX format
        _, current_time, current_freq, _ = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        # Step 1: Crop first to avoid exceeding target dimensions
        decoded_output = decoded_output[
            :, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
        ]

        # Step 2: Calculate padding needed for time and frequency dimensions
        time_padding_needed = target_time - decoded_output.shape[1]
        freq_padding_needed = target_freq - decoded_output.shape[2]

        # Step 3: Apply padding (at the end of each axis) if needed
        if time_padding_needed > 0 or freq_padding_needed > 0:
            # MLX pad: [(before_0, after_0), ...]
            # For (B, H, W, C): H=time, W=freq
            padding = [
                (0, 0),  # batch
                (0, max(time_padding_needed, 0)),  # time
                (0, max(freq_padding_needed, 0)),  # freq
                (0, 0),  # channels
            ]
            decoded_output = mx.pad(decoded_output, padding)

        # Step 4: Final safety crop to ensure exact target shape
        decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels]

        # Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility
        return mx.transpose(decoded_output, (0, 3, 1, 2))

    def _run_upsampling_path(self, h: mx.array) -> mx.array:
        """Run through upsampling path, from the coarsest level down to level 0."""
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx in range(len(stage["block"])):
                h = stage["block"][block_idx](h, temb=None)
                if block_idx in stage["attn"]:
                    h = stage["attn"][block_idx](h)
            # Level 0 is the output resolution and is not upsampled further.
            if level != 0 and "upsample" in stage:
                h = stage["upsample"](h)
        return h

    def _finalize_output(self, h: mx.array) -> mx.array:
        """Apply final normalization and convolution.

        With ``give_pre_end`` the features are returned untouched; with
        ``tanh_out`` the convolution output is passed through tanh.
        """
        if self.give_pre_end:
            return h
        h = self.norm_out(h)
        h = nn.silu(h)
        h = self.conv_out(h)
        return mx.tanh(h) if self.tanh_out else h
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
    """Turn an audio latent into a float32 waveform.

    The latent is first decoded by ``audio_decoder`` and the result is then
    fed to ``vocoder``.

    Args:
        latent: Input audio latent tensor.
        audio_decoder: Callable model mapping the latent to the vocoder input.
        vocoder: Callable model mapping that representation to a waveform.

    Returns:
        The vocoder output cast to ``mx.float32``. If its leading axis has
        size 1, that axis is removed first.
    """
    waveform = vocoder(audio_decoder(latent))
    # A leading axis of length 1 is treated as a singleton batch and dropped.
    has_singleton_batch = waveform.shape[0] == 1
    result = waveform[0] if has_singleton_batch else waveform
    return result.astype(mx.float32)

View File

@@ -0,0 +1,146 @@
"""Causal 2D convolutions for audio VAE."""
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from ..config import CausalityAxis
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
"""Convert int or tuple to tuple pair."""
if isinstance(x, int):
return (x, x)
return x
class CausalConv2d(nn.Module):
    """
    2D convolution with one-sided (causal) padding along a chosen axis.

    The input is zero-padded manually before an unpadded ``nn.Conv2d`` is
    applied. On the causal axis all of the padding is placed *before* the
    data, so an output position never sees later input positions on that axis.
    The other axis is padded (near-)symmetrically.

    Inputs are channels-last: (N, H, W, C).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size, an int or an (H, W) pair.
            stride: Convolution stride.
            dilation: Dilation, an int or an (H, W) pair.
            groups: Number of convolution groups.
            bias: Whether the convolution has a bias term.
            causality_axis: Axis that receives one-sided padding.

        Raises:
            ValueError: If ``causality_axis`` is not a recognised value.
        """
        super().__init__()
        self.causality_axis = causality_axis

        k_h, k_w = _pair(kernel_size)
        dilation = _pair(dilation)

        # Total padding required on each axis to preserve length at stride 1.
        total_h = (k_h - 1) * dilation[0]
        total_w = (k_w - 1) * dilation[1]

        # (Near-)symmetric split; any odd remainder goes on the trailing side.
        sym_h = (total_h // 2, total_h - total_h // 2)
        sym_w = (total_w // 2, total_w - total_w // 2)

        # Stored as (h_before, h_after, w_before, w_after).
        if causality_axis == CausalityAxis.NONE:
            self.padding = (*sym_h, *sym_w)
        elif causality_axis == CausalityAxis.HEIGHT:
            self.padding = (total_h, 0, *sym_w)
        elif causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
            self.padding = (*sym_h, total_w, 0)
        else:
            raise ValueError(f"Invalid causality_axis: {causality_axis}")

        # Padding is applied by hand in __call__, so the conv itself uses none.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            (k_h, k_w),
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Pad the input and convolve.

        Args:
            x: Input tensor of shape (N, H, W, C).

        Returns:
            The convolved tensor.
        """
        top, bottom, left, right = self.padding
        if max(self.padding) > 0:
            # mx.pad takes one (before, after) pair per axis: N, H, W, C.
            x = mx.pad(x, [(0, 0), (top, bottom), (left, right), (0, 0)])
        return self.conv(x)
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, int]],
    stride: int = 1,
    padding: Union[int, Tuple[int, int], None] = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: CausalityAxis | None = None,
) -> nn.Module:
    """
    Build either a ``CausalConv2d`` or a plain ``nn.Conv2d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, an int or an (H, W) pair.
        stride: Convolution stride.
        padding: Padding for the plain convolution. ``None`` selects
            ``kernel_size // 2`` per axis. Ignored when ``causality_axis`` is given.
        dilation: Dilation rate.
        groups: Number of convolution groups.
        bias: Whether the convolution has a bias term.
        causality_axis: When not ``None``, a ``CausalConv2d`` configured with
            this axis is returned and it handles its own padding.

    Returns:
        The constructed convolution layer.
    """
    if causality_axis is not None:
        return CausalConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causality_axis=causality_axis,
        )

    if padding is None:
        # Default to half the kernel size on each axis.
        padding = (
            kernel_size // 2
            if isinstance(kernel_size, int)
            else tuple(k // 2 for k in kernel_size)
        )

    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )

View File

@@ -0,0 +1,127 @@
"""Downsampling layers for audio VAE."""
from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from ..config import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock
class Downsample(nn.Module):
    """
    2x spatial downsampling by strided convolution or by average pooling.

    In convolutional mode the input is zero-padded by hand (asymmetrically,
    depending on ``causality_axis``) and passed through a 3x3, stride-2
    convolution. In pooling mode a 2x2 mean with stride 2 is applied.
    """

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.WIDTH,
    ) -> None:
        """
        Args:
            in_channels: Number of channels (unchanged by this layer).
            with_conv: Use the strided convolution when True, mean pooling otherwise.
            causality_axis: Axis along which padding is one-sided.

        Raises:
            ValueError: If a causal axis is requested together with ``with_conv=False``.
        """
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
            raise ValueError("causality is only supported when `with_conv=True`.")
        if self.with_conv:
            # The conv itself is unpadded; the asymmetric padding is applied in __call__.
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def __call__(self, x: mx.array) -> mx.array:
        """
        Downsample the input.

        Args:
            x: Input tensor of shape (N, H, W, C).

        Returns:
            The downsampled tensor.

        Raises:
            ValueError: If ``causality_axis`` is not a recognised value.
        """
        if self.with_conv:
            # One (before, after) pair per axis, in N, H, W, C order.
            if self.causality_axis == CausalityAxis.NONE:
                pad = [(0, 0), (0, 1), (0, 1), (0, 0)]
            elif self.causality_axis == CausalityAxis.WIDTH:
                pad = [(0, 0), (0, 1), (2, 0), (0, 0)]
            elif self.causality_axis == CausalityAxis.HEIGHT:
                pad = [(0, 0), (2, 0), (0, 1), (0, 0)]
            elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
                pad = [(0, 0), (0, 1), (1, 0), (0, 0)]
            else:
                raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
            x = mx.pad(x, pad, constant_values=0)
            return self.conv(x)

        # 2x2 mean pooling with stride 2, done via reshape + mean.
        n, h, w, c = x.shape
        out_h, out_w = h // 2, w // 2
        # Drop a trailing odd row/column so the reshape below is always valid.
        # This gives floor semantics instead of raising on odd H or W.
        x = x[:, : out_h * 2, : out_w * 2, :]
        x = x.reshape(n, out_h, 2, out_w, 2, c)
        return mx.mean(x, axis=(2, 4))
def build_downsampling_path(
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
) -> tuple[dict, int]:
    """Assemble the encoder stages.

    Each stage is a dict with ``"block"`` (residual blocks keyed by index),
    ``"attn"`` (attention layers keyed by the block index they follow, present
    only when the current resolution is in ``attn_resolutions``) and, for all
    but the last stage, ``"downsample"``. The tracked resolution is halved
    after every downsample.

    Returns:
        ``(stages, channels)``: the stages keyed by level, and the channel
        count leaving the final stage.
    """
    down_modules = {}
    curr_res = resolution
    # Input multiplier of level i is the output multiplier of level i - 1 (1 for level 0).
    in_ch_mult = (1, *tuple(ch_mult))
    block_in = ch
    last_level = num_resolutions - 1

    for i_level in range(num_resolutions):
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]
        blocks = {}
        attns = {}
        for i_block in range(num_res_blocks):
            blocks[i_block] = ResnetBlock(
                in_channels=block_in,
                out_channels=block_out,
                temb_channels=temb_channels,
                dropout=dropout,
                norm_type=norm_type,
                causality_axis=causality_axis,
            )
            # Only the first block of a stage changes the channel count.
            block_in = block_out
            if curr_res in attn_resolutions:
                attns[i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)

        stage = {"block": blocks, "attn": attns}
        if i_level != last_level:
            stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res = curr_res // 2
        down_modules[i_level] = stage

    return down_modules, block_in

View File

@@ -0,0 +1,59 @@
"""Normalization layers for audio VAE."""
from enum import Enum
import mlx.core as mx
import mlx.nn as nn
class NormType(Enum):
    """Selector for the normalization layer used by the audio VAE blocks."""

    # Group normalization (nn.GroupNorm).
    GROUP = "group"
    # Per-location RMS normalization (PixelNorm).
    PIXEL = "pixel"
class PixelNorm(nn.Module):
    """
    RMS normalization over a single axis, with no learned parameters.

    Computes ``x / sqrt(mean(x**2, axis=dim, keepdims=True) + eps)``.
    """

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        """
        Args:
            dim: Axis over which the mean of squares is taken.
            eps: Constant added under the square root for numerical stability.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        """Divide ``x`` by its root-mean-square along ``self.dim``."""
        denom = mx.sqrt(mx.mean(x**2, axis=self.dim, keepdims=True) + self.eps)
        return x / denom
def build_normalization_layer(
    in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
) -> nn.Module:
    """
    Create the normalization layer selected by ``normtype``.

    Args:
        in_channels: Number of input channels.
        num_groups: Number of groups (used by ``NormType.GROUP`` only).
        normtype: ``NormType.GROUP`` or ``NormType.PIXEL``.

    Returns:
        An ``nn.GroupNorm`` or a ``PixelNorm`` over the last (channel) axis.

    Raises:
        ValueError: If ``normtype`` is not a recognised ``NormType``.
    """
    if normtype == NormType.GROUP:
        # MLX's default GroupNorm assigns channels to groups differently from
        # PyTorch. These layers are fed PyTorch-trained affine weights, so the
        # PyTorch-compatible grouping is required for matching outputs.
        return nn.GroupNorm(
            num_groups=num_groups,
            dims=in_channels,
            eps=1e-6,
            affine=True,
            pytorch_compatible=True,
        )
    if normtype == NormType.PIXEL:
        # Tensors are channels-last (B, H, W, C), so normalize over the last axis.
        return PixelNorm(dim=-1, eps=1e-6)
    raise ValueError(f"Invalid normalization type: {normtype}")

View File

@@ -0,0 +1,98 @@
"""Audio processing utilities for audio VAE."""
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
@dataclass
class AudioLatentShape:
    """Dimensions of an audio latent: batch, channels, time frames and mel bins."""

    # Number of items in the batch.
    batch: int
    # Number of latent channels.
    channels: int
    # Number of time frames.
    frames: int
    # Number of mel-frequency bins.
    mel_bins: int
class PerChannelStatistics(nn.Module):
    """
    Holds per-channel mean/std vectors and applies (de)normalization with them.

    The vectors start as identity statistics (std 1, mean 0) and are meant to
    be overwritten by checkpoint weights. Both are 1-D of length
    ``latent_channels``, so they broadcast against the last axis of the input.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        """
        Args:
            latent_channels: Length of the statistics vectors.
        """
        super().__init__()
        self.latent_channels = latent_channels
        # Placeholders until real statistics are loaded from the checkpoint.
        self.std_of_means = mx.ones((latent_channels,))
        self.mean_of_means = mx.zeros((latent_channels,))

    def un_normalize(self, x: mx.array) -> mx.array:
        """Map a normalized latent back: ``x * std + mean`` (statistics cast to ``x.dtype``)."""
        scale = self.std_of_means.astype(x.dtype)
        shift = self.mean_of_means.astype(x.dtype)
        return x * scale + shift

    def normalize(self, x: mx.array) -> mx.array:
        """Normalize a latent: ``(x - mean) / std`` (statistics cast to ``x.dtype``)."""
        scale = self.std_of_means.astype(x.dtype)
        shift = self.mean_of_means.astype(x.dtype)
        return (x - shift) / scale
class AudioPatchifier:
    """
    Converts audio latents to and from a flattened per-frame layout.

    ``patchify`` merges the channel and mel-bin axes into one feature axis
    (channel-major), and ``unpatchify`` splits it back.
    """

    def __init__(
        self,
        patch_size: int = 1,
        audio_latent_downsample_factor: int = 4,
        sample_rate: int = 16000,
        hop_length: int = 160,
        is_causal: bool = True,
    ):
        """Store the configuration; none of these values are used by the methods below."""
        self.patch_size = patch_size
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.is_causal = is_causal

    def patchify(self, x: mx.array) -> mx.array:
        """Flatten (B, T, F, C) latents to (B, T, C*F).

        The merged axis is channel-major, i.e. index ``c * F + f``, matching
        the einops pattern ``"b c t f -> b t (c f)"``.
        """
        batch, frames, mel_bins, channels = x.shape
        # Move channels ahead of mel bins so the flatten is channel-major.
        channel_major = mx.transpose(x, (0, 1, 3, 2))
        return channel_major.reshape(batch, frames, channels * mel_bins)

    def unpatchify(self, x: mx.array, latent_shape: AudioLatentShape) -> mx.array:
        """Inverse of ``patchify``: expand (B, T, C*F) back to (B, T, F, C).

        ``C`` and ``F`` are taken from ``latent_shape.channels`` and
        ``latent_shape.mel_bins``.
        """
        batch, frames, _ = x.shape
        grid = x.reshape(batch, frames, latent_shape.channels, latent_shape.mel_bins)
        # Back to channels-last.
        return mx.transpose(grid, (0, 1, 3, 2))

View File

@@ -0,0 +1,185 @@
"""ResNet blocks for audio VAE and vocoder."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .normalization import NormType, build_normalization_layer
# Default negative slope for the leaky ReLU used by the vocoder residual blocks.
LRELU_SLOPE = 0.1
def leaky_relu(x: mx.array, negative_slope: float = LRELU_SLOPE) -> mx.array:
    """Leaky ReLU, computed as the element-wise maximum of ``x`` and ``x * negative_slope``."""
    scaled = x * negative_slope
    return mx.maximum(x, scaled)
class ResBlock1(nn.Module):
    """1D residual block for the vocoder: a dilated conv followed by a dilation-1 conv per layer."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
    ):
        """
        Args:
            channels: Channel count (input and output are the same).
            kernel_size: Kernel size shared by all convolutions.
            dilation: One dilation rate per residual layer.
        """
        super().__init__()

        # Dilated convolutions; padding is chosen to preserve length for odd kernels.
        dilated = {}
        for idx, rate in enumerate(dilation):
            dilated[idx] = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=rate,
                padding=(kernel_size - 1) * rate // 2,
            )
        self.convs1 = dilated

        # Matching dilation-1 convolutions, one per dilated conv.
        plain = {}
        for idx in range(len(dilation)):
            plain[idx] = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=1,
                padding=(kernel_size - 1) // 2,
            )
        self.convs2 = plain

    def __call__(self, x: mx.array) -> mx.array:
        """Apply each (leaky ReLU -> conv1 -> leaky ReLU -> conv2) branch with a skip connection."""
        num_layers = len(self.convs1)
        for idx in range(num_layers):
            branch = self.convs1[idx](leaky_relu(x, LRELU_SLOPE))
            branch = self.convs2[idx](leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x
class ResBlock2(nn.Module):
    """Lighter 1D residual block for the vocoder: a single dilated conv per layer."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int] = (1, 3),
    ):
        """
        Args:
            channels: Channel count (input and output are the same).
            kernel_size: Kernel size shared by all convolutions.
            dilation: One dilation rate per residual layer.
        """
        super().__init__()
        layers = {}
        for idx, rate in enumerate(dilation):
            layers[idx] = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                dilation=rate,
                padding=(kernel_size - 1) * rate // 2,
            )
        self.convs = layers

    def __call__(self, x: mx.array) -> mx.array:
        """Apply each (leaky ReLU -> conv) branch with a skip connection."""
        num_layers = len(self.convs)
        for idx in range(num_layers):
            branch = self.convs[idx](leaky_relu(x, LRELU_SLOPE))
            x = branch + x
        return x
class ResnetBlock(nn.Module):
    """2D residual block (norm -> SiLU -> conv, twice) for the audio VAE encoder/decoder."""

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: NormType = NormType.GROUP,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels; defaults to ``in_channels``.
            conv_shortcut: When channel counts differ, use a 3x3 conv on the
                skip path instead of a 1x1 conv.
            dropout: Dropout probability applied before the second conv.
            temb_channels: Size of the optional time embedding; 0 disables
                the embedding projection.
            norm_type: Normalization layer type.
            causality_axis: Causality axis forwarded to every convolution.

        Raises:
            ValueError: If a causal axis is combined with ``NormType.GROUP``.
        """
        super().__init__()
        self.causality_axis = causality_axis
        if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.temb_channels = temb_channels

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        self.conv1 = make_conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
        )
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        self.dropout_rate = dropout
        if dropout > 0:
            # Registered once as a sub-module so that train()/eval() on the
            # parent toggles it. Building a fresh nn.Dropout on every forward
            # pass would always run in training mode and keep dropping
            # activations at inference time. nn.Dropout has no parameters,
            # so checkpoint loading is unaffected.
            self.dropout = nn.Dropout(dropout)
        self.conv2 = make_conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
        )

        # A skip-path projection is needed only when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )

    def __call__(
        self,
        x: mx.array,
        temb: mx.array | None = None,
    ) -> mx.array:
        """
        Run the residual block.

        Args:
            x: Input tensor of shape (N, H, W, C).
            temb: Optional time embedding of shape (B, temb_channels).

        Returns:
            ``shortcut(x) + residual(x)``.
        """
        h = self.norm1(x)
        h = nn.silu(h)
        h = self.conv1(h)

        if temb is not None and self.temb_channels > 0:
            # Project to (B, out_channels), then add two singleton spatial
            # axes so the result broadcasts over (B, H, W, out_channels).
            # NOTE(review): SiLU is applied *after* the projection here; common
            # reference implementations apply it to temb *before* projecting.
            # Confirm against the source model before relying on this path.
            emb = nn.silu(self.temb_proj(temb))
            h = h + mx.expand_dims(mx.expand_dims(emb, axis=1), axis=1)

        h = self.norm2(h)
        h = nn.silu(h)
        if self.dropout_rate > 0:
            h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h

View File

@@ -0,0 +1,135 @@
"""Upsampling layers for audio VAE."""
from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock
def nearest_neighbor_upsample(x: mx.array, scale_factor: int = 2) -> mx.array:
    """
    Nearest-neighbour upsampling of a 4D channels-last tensor.

    Args:
        x: Input tensor of shape (N, H, W, C).
        scale_factor: Integer repeat factor for both spatial axes.

    Returns:
        Tensor of shape (N, H * scale_factor, W * scale_factor, C).
    """
    # Unpacking doubles as a rank check: anything but a 4D input raises here.
    _n, _h, _w, _c = x.shape
    rows_repeated = mx.repeat(x, scale_factor, axis=1)
    return mx.repeat(rows_repeated, scale_factor, axis=2)
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a (causal) 3x3 convolution."""

    def __init__(
        self,
        in_channels: int,
        with_conv: bool,
        causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
    ) -> None:
        """
        Args:
            in_channels: Number of channels (unchanged by this layer).
            with_conv: Whether to apply a convolution after upsampling.
            causality_axis: Causality axis forwarded to the convolution.
        """
        super().__init__()
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
            )

    def __call__(self, x: mx.array) -> mx.array:
        """
        Upsample the input.

        Args:
            x: Input tensor of shape (N, H, W, C).

        Returns:
            The upsampled tensor. With a convolution on a HEIGHT or WIDTH
            causal axis, the first element along that axis is dropped.

        Raises:
            ValueError: If ``causality_axis`` is not a recognised value.
        """
        x = nearest_neighbor_upsample(x, scale_factor=2)
        if not self.with_conv:
            return x

        x = self.conv(x)

        # Why the FIRST element on the causal axis is dropped (keeps length 1 + 2n):
        # input [0, 1, 2] upsamples to [0, 0, 1, 1, 2, 2]; the causal conv pads
        # it to [-, -, 0, 0, 1, 1, 2, 2], giving output windows
        #   0: [-,-,0]  1: [-,0,0]  2: [0,0,1]  3: [0,1,1]  4: [1,1,2]  5: [1,2,2]
        # Outputs 0 and 1 both depend only on input 0, while every later output
        # mixes two inputs, so the redundant leading element is the one to
        # discard (not the last one).
        axis = self.causality_axis
        if axis == CausalityAxis.HEIGHT:
            return x[:, 1:, :, :]
        if axis == CausalityAxis.WIDTH:
            return x[:, :, 1:, :]
        if axis in (CausalityAxis.NONE, CausalityAxis.WIDTH_COMPATIBILITY):
            # Nothing is trimmed for these modes.
            return x
        raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
def build_upsampling_path(
    *,
    ch: int,
    ch_mult: Tuple[int, ...],
    num_resolutions: int,
    num_res_blocks: int,
    resolution: int,
    temb_channels: int,
    dropout: float,
    norm_type: NormType,
    causality_axis: CausalityAxis,
    attn_type: AttentionType,
    attn_resolutions: Set[int],
    resamp_with_conv: bool,
    initial_block_channels: int,
) -> tuple[dict, int]:
    """Assemble the decoder stages, built from the highest level down to level 0.

    Each stage is a dict with ``"block"`` (``num_res_blocks + 1`` residual
    blocks keyed by index), ``"attn"`` (attention layers keyed by the block
    index they follow, present only when the current resolution is in
    ``attn_resolutions``) and, for every level except 0, ``"upsample"``. The
    tracked resolution starts at ``resolution // 2 ** (num_resolutions - 1)``
    and doubles after each upsample.

    Returns:
        ``(stages, channels)``: the stages keyed by level, and the channel
        count leaving level 0.
    """
    up_modules = {}
    block_in = initial_block_channels
    curr_res = resolution // (2 ** (num_resolutions - 1))

    for level in range(num_resolutions - 1, -1, -1):
        block_out = ch * ch_mult[level]
        blocks = {}
        attns = {}
        for i_block in range(num_res_blocks + 1):
            blocks[i_block] = ResnetBlock(
                in_channels=block_in,
                out_channels=block_out,
                temb_channels=temb_channels,
                dropout=dropout,
                norm_type=norm_type,
                causality_axis=causality_axis,
            )
            # Only the first block of a stage changes the channel count.
            block_in = block_out
            if curr_res in attn_resolutions:
                attns[i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)

        stage = {"block": blocks, "attn": attns}
        if level != 0:
            stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
            curr_res *= 2
        up_modules[level] = stage

    return up_modules, block_in

View File

@@ -0,0 +1,689 @@
"""Vocoder for converting mel spectrograms to audio waveforms.
Supports:
- HiFi-GAN (LTX-2): ResBlock1 with LeakyReLU
- BigVGAN v2 (LTX-2.3): AMPBlock1 with Snake/SnakeBeta + anti-aliased resampling
- VocoderWithBWE (LTX-2.3): Base vocoder + bandwidth extension (16kHz -> 48kHz)
"""
import math
from typing import List, Tuple
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from ..config import VocoderModelConfig
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return ``dilation * (kernel_size - 1) / 2`` truncated to an int.

    For odd kernel sizes this is the padding that keeps a stride-1 dilated
    convolution length-preserving.
    """
    receptive_span = dilation * (kernel_size - 1)
    return int(receptive_span / 2)
# ---------------------------------------------------------------------------
# Snake / SnakeBeta activations (BigVGAN v2)
# ---------------------------------------------------------------------------
class Snake(nn.Module):
    """Snake activation: ``x + (1 / alpha) * sin(alpha * x) ** 2`` with a per-feature ``alpha``."""

    def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
        """
        Args:
            in_features: Length of the ``alpha`` vector.
            alpha_logscale: When True, ``alpha`` is stored in log space
                (initialised to 0) and exponentiated on use; otherwise it is
                stored directly (initialised to 1).
        """
        super().__init__()
        self.alpha_logscale = alpha_logscale
        make_init = mx.zeros if alpha_logscale else mx.ones
        self.alpha = make_init((in_features,))

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the activation; ``alpha`` broadcasts against the last axis of ``x``."""
        alpha = mx.exp(self.alpha) if self.alpha_logscale else self.alpha
        periodic = mx.power(mx.sin(x * alpha), 2)
        # The small constant guards the reciprocal against alpha == 0.
        return x + (1.0 / (alpha + 1e-9)) * periodic
class SnakeBeta(nn.Module):
    """SnakeBeta activation: ``x + (1 / beta) * sin(alpha * x) ** 2`` with per-feature ``alpha`` and ``beta``."""

    def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
        """
        Args:
            in_features: Length of the ``alpha`` and ``beta`` vectors.
            alpha_logscale: When True, both vectors are stored in log space
                (initialised to 0) and exponentiated on use; otherwise they
                are stored directly (initialised to 1).
        """
        super().__init__()
        self.alpha_logscale = alpha_logscale
        make_init = mx.zeros if alpha_logscale else mx.ones
        self.alpha = make_init((in_features,))
        self.beta = make_init((in_features,))

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the activation; the parameters broadcast against the last axis of ``x``."""
        if self.alpha_logscale:
            alpha, beta = mx.exp(self.alpha), mx.exp(self.beta)
        else:
            alpha, beta = self.alpha, self.beta
        periodic = mx.power(mx.sin(x * alpha), 2)
        # The small constant guards the reciprocal against beta == 0.
        return x + (1.0 / (beta + 1e-9)) * periodic
# ---------------------------------------------------------------------------
# Anti-aliased resampling (Kaiser-sinc filters)
# ---------------------------------------------------------------------------
def _sinc(x: mx.array) -> mx.array:
    """Normalized sinc, ``sin(pi * x) / (pi * x)``, with the value 1 substituted where ``x == 0``."""
    scaled = mx.array(math.pi) * x
    # The quotient is still evaluated at x == 0; mx.where discards it there.
    return mx.where(x == 0, mx.ones_like(x), mx.sin(scaled) / scaled)
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
    """Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: Cutoff used as ``2 * cutoff`` inside the sinc; 0 yields an
            all-zero kernel.
        half_width: Transition half-width; only used to choose the Kaiser beta.
        kernel_size: Number of taps.

    Returns:
        Kernel of shape (1, 1, kernel_size); normalized to sum to 1 when
        ``cutoff`` is non-zero.
    """
    import numpy as np

    half_size = kernel_size // 2
    is_even = kernel_size % 2 == 0

    # Choose the Kaiser beta from the estimated attenuation.
    delta_f = 4 * half_width
    amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if amplitude > 50.0:
        beta = 0.1102 * (amplitude - 8.7)
    elif amplitude >= 21.0:
        beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
    else:
        beta = 0.0

    window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))

    # Sample positions centred on zero; even kernels sit on half-integers.
    if is_even:
        time = mx.arange(-half_size, half_size).astype(mx.float32) + 0.5
    else:
        time = mx.arange(kernel_size).astype(mx.float32) - half_size

    if cutoff == 0:
        return mx.zeros_like(time).reshape(1, 1, kernel_size)

    taps = 2 * cutoff * window * _sinc(2 * cutoff * time)
    # Normalize to unit sum.
    taps = taps / mx.sum(taps)
    return taps.reshape(1, 1, kernel_size)
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
    """Build a Hann-windowed sinc kernel for upsampling by ``ratio``.

    Uses a fixed roll-off of 0.99 and a low-pass filter width of 6.

    Returns:
        ``(filter, pad, pad_left, pad_right)``: the kernel has shape
        (1, 1, kernel_size) with ``kernel_size = 2 * width * ratio + 1``,
        ``pad`` is ``width``, ``pad_left`` is ``2 * width * ratio`` and
        ``pad_right`` is ``kernel_size - ratio``.
    """
    import numpy as np

    rolloff = 0.99
    lowpass_filter_width = 6
    width = math.ceil(lowpass_filter_width / rolloff)
    kernel_size = 2 * width * ratio + 1

    time = (np.arange(kernel_size) / ratio - width) * rolloff
    # The window argument is clamped so the window is zero outside the filter width.
    clamped = np.clip(time, -lowpass_filter_width, lowpass_filter_width)
    window = np.cos(clamped * math.pi / lowpass_filter_width / 2) ** 2
    taps = np.sinc(time) * window * rolloff / ratio

    filter_ = mx.array(taps.astype(np.float32)).reshape(1, 1, kernel_size)
    return filter_, width, 2 * width * ratio, kernel_size - ratio
class LowPassFilter1d(nn.Module):
    """Per-channel low-pass filter using a single shared Kaiser-sinc kernel."""

    def __init__(
        self,
        cutoff: float = 0.5,
        half_width: float = 0.6,
        stride: int = 1,
        kernel_size: int = 12,
    ) -> None:
        """
        Args:
            cutoff: Cutoff passed to ``kaiser_sinc_filter1d``.
            half_width: Transition half-width passed to ``kaiser_sinc_filter1d``.
            stride: Convolution stride (values > 1 downsample).
            kernel_size: Number of filter taps.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Even kernels get one sample less padding on the left.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        # Kernel of shape (1, 1, K); may be overwritten when weights are loaded.
        self.filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)

    def __call__(self, x: mx.array) -> mx.array:
        """Filter ``x`` of shape (N, L, C) along L; returns (N, L', C)."""
        n, _, c = x.shape

        # Edge-replicate padding: repeat the first and last time steps.
        first = mx.repeat(x[:, :1, :], self.pad_left, axis=1)
        last = mx.repeat(x[:, -1:, :], self.pad_right, axis=1)
        x = mx.concatenate([first, x, last], axis=1)

        # Every channel uses the same kernel, so channels are folded into the
        # batch axis and one single-channel convolution is run. No per-channel
        # copies of the kernel are needed.
        kernel = mx.transpose(self.filter.astype(x.dtype), (0, 2, 1))  # (1, K, 1)
        x = mx.transpose(x, (0, 2, 1)).reshape(n * c, -1, 1)  # (N*C, L, 1)
        x = mx.conv1d(x, kernel, stride=self.stride, groups=1)  # (N*C, L', 1)

        x = x.reshape(n, c, -1)  # (N, C, L')
        return mx.transpose(x, (0, 2, 1))  # (N, L', C)
class UpSample1d(nn.Module):
    """Anti-aliased upsampling: transposed convolution with a windowed-sinc kernel."""

    def __init__(
        self,
        ratio: int = 2,
        kernel_size: int = None,
        window_type: str = "kaiser",
    ) -> None:
        """
        Args:
            ratio: Upsampling factor (also the transposed-conv stride).
            kernel_size: Number of taps for the Kaiser kernel; ``None`` picks
                ``int(6 * ratio // 2) * 2``. Ignored for ``window_type="hann"``.
            window_type: ``"hann"`` selects the Hann-sinc kernel; any other
                value selects the Kaiser-sinc kernel.
        """
        super().__init__()
        self.ratio = ratio
        self.stride = ratio
        if window_type == "hann":
            filt, self.pad, self.pad_left, self.pad_right = hann_sinc_filter1d(ratio)
            self.kernel_size = filt.shape[2]
            self.filter = filt
        else:
            self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
            self.pad = self.kernel_size // ratio - 1
            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
            self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
            self.filter = kaiser_sinc_filter1d(
                cutoff=0.5 / ratio,
                half_width=0.6 / ratio,
                kernel_size=self.kernel_size,
            )

    def __call__(self, x: mx.array) -> mx.array:
        """Upsample ``x`` of shape (N, L, C) along L; returns (N, L', C)."""
        n, _, c = x.shape

        # Edge-replicate padding on both sides.
        first = mx.repeat(x[:, :1, :], self.pad, axis=1)
        last = mx.repeat(x[:, -1:, :], self.pad, axis=1)
        x = mx.concatenate([first, x, last], axis=1)

        # Fold channels into the batch axis: (N, L, C) -> (N*C, L, 1).
        x = mx.transpose(x, (0, 2, 1)).reshape(n * c, -1, 1)

        # Stored kernel is (1, 1, K); conv_transpose1d expects (1, K, 1).
        kernel = mx.transpose(self.filter.astype(x.dtype), (0, 2, 1))
        # Scale by the ratio to compensate for the zeros inserted by the transposed conv.
        x = self.ratio * mx.conv_transpose1d(x, kernel, stride=self.stride)  # (N*C, L', 1)

        # Trim the padding contribution. An explicit end index is used because
        # a `-self.pad_right` slice bound would select nothing when
        # pad_right == 0 (e.g. kernel_size == ratio).
        stop = max(x.shape[1] - self.pad_right, 0)
        x = x[:, self.pad_left:stop, :]

        x = x.reshape(n, c, -1)  # (N, C, L')
        return mx.transpose(x, (0, 2, 1))  # (N, L', C)
class DownSample1d(nn.Module):
    """Anti-aliased downsampling: a strided ``LowPassFilter1d``."""

    def __init__(self, ratio: int = 2, kernel_size: int = None) -> None:
        """
        Args:
            ratio: Downsampling factor (used as the filter stride).
            kernel_size: Number of filter taps; ``None`` picks
                ``int(6 * ratio // 2) * 2``.
        """
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=kernel_size,
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Low-pass filter and decimate ``x``."""
        filtered = self.lowpass(x)
        return filtered
class Activation1d(nn.Module):
    """Wraps an activation between an upsampler and a downsampler to reduce aliasing."""

    def __init__(
        self,
        activation: nn.Module,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ) -> None:
        """
        Args:
            activation: The activation module to wrap.
            up_ratio: Upsampling factor applied before the activation.
            down_ratio: Downsampling factor applied after the activation.
            up_kernel_size: Kernel size of the upsampling filter.
            down_kernel_size: Kernel size of the downsampling filter.
        """
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def __call__(self, x: mx.array) -> mx.array:
        """Upsample, apply the activation, then downsample."""
        return self.downsample(self.act(self.upsample(x)))
# ---------------------------------------------------------------------------
# AMPBlock1 (BigVGAN v2 residual block)
# ---------------------------------------------------------------------------
class AMPBlock1(nn.Module):
    """BigVGAN v2 residual block: dilated convs with anti-aliased Snake/SnakeBeta activations."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        activation: str = "snakebeta",
    ) -> None:
        """
        Args:
            channels: Channel count (input and output are the same).
            kernel_size: Kernel size shared by all convolutions.
            dilation: One dilation rate per residual layer.
            activation: ``"snakebeta"`` selects ``SnakeBeta``; any other
                value selects ``Snake``.
        """
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake
        num_layers = len(dilation)

        def _conv(rate: int) -> nn.Conv1d:
            # Conv1d with the given dilation and get_padding-derived padding.
            return nn.Conv1d(
                channels, channels, kernel_size, stride=1,
                dilation=rate, padding=get_padding(kernel_size, rate),
            )

        def _activation() -> Activation1d:
            return Activation1d(act_cls(channels))

        # Dilated convs first, then the dilation-1 convs, then the activations.
        self.convs1 = {i: _conv(rate) for i, rate in enumerate(dilation)}
        self.convs2 = {i: _conv(1) for i in range(num_layers)}
        self.acts1 = {i: _activation() for i in range(num_layers)}
        self.acts2 = {i: _activation() for i in range(num_layers)}

    def __call__(self, x: mx.array) -> mx.array:
        """Apply each (act1 -> conv1 -> act2 -> conv2) branch with a skip connection."""
        num_layers = len(self.convs1)
        for idx in range(num_layers):
            branch = self.convs1[idx](self.acts1[idx](x))
            branch = self.convs2[idx](self.acts2[idx](branch))
            x = x + branch
        return x
# ---------------------------------------------------------------------------
# STFT and MelSTFT (for BWE)
# ---------------------------------------------------------------------------
class STFTFn(nn.Module):
    """STFT computed as a strided conv1d against precomputed basis buffers.

    The basis buffers are zero-initialised here and are expected to be filled
    from a checkpoint.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
        """
        Args:
            filter_length: Basis length (kernel size); sets
                ``n_freqs = filter_length // 2 + 1``.
            hop_length: Hop between frames (convolution stride).
            win_length: Window length; determines the left padding.
        """
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        n_freqs = filter_length // 2 + 1
        # Stored as (2 * n_freqs, 1, filter_length): real rows first, then imaginary rows.
        basis_shape = (n_freqs * 2, 1, filter_length)
        self.forward_basis = mx.zeros(basis_shape)
        self.inverse_basis = mx.zeros(basis_shape)

    def __call__(self, y: mx.array) -> Tuple[mx.array, mx.array]:
        """Compute magnitude and phase of a waveform.

        Args:
            y: Waveform of shape (B, T); a trailing channel axis is added when
                the input is 2D.

        Returns:
            ``(magnitude, phase)``, each of shape (B, T_frames, n_freqs)
            (channels-last).
        """
        if y.ndim == 2:
            y = mx.expand_dims(y, -1)  # (B, T, 1)

        # Left-only padding that replicates the first sample.
        left_pad = max(0, self.win_length - self.hop_length)
        if left_pad > 0:
            head = mx.repeat(y[:, :1, :], left_pad, axis=1)
            y = mx.concatenate([head, y], axis=1)

        # (2 * n_freqs, 1, K) -> (2 * n_freqs, K, 1), the layout mx.conv1d takes.
        kernel = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1))
        spec = mx.conv1d(y, kernel, stride=self.hop_length)  # (B, T_frames, 2 * n_freqs)

        # First half of the output channels is the real part, second half the imaginary part.
        n_freqs = spec.shape[-1] // 2
        real, imag = spec[..., :n_freqs], spec[..., n_freqs:]

        magnitude = mx.sqrt(real ** 2 + imag ** 2)
        # The angle is computed in float32, then cast back to the working dtype.
        angle = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32))
        phase = angle.astype(real.dtype)
        return magnitude, phase
class MelSTFT(nn.Module):
    """Log-mel spectrogram built on ``STFTFn`` and a mel filterbank buffer.

    The filterbank is zero-initialised here and is expected to be filled from
    a checkpoint.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
        """
        Args:
            filter_length: STFT basis length; sets ``n_freqs = filter_length // 2 + 1``.
            hop_length: STFT hop length.
            win_length: STFT window length.
            n_mel_channels: Number of mel bands.
        """
        super().__init__()
        self.stft_fn = STFTFn(filter_length, hop_length, win_length)
        n_freqs = filter_length // 2 + 1
        # Filterbank of shape (n_mels, n_freqs).
        self.mel_basis = mx.zeros((n_mel_channels, n_freqs))

    def mel_spectrogram(self, y: mx.array) -> mx.array:
        """Compute the log-mel spectrogram of a waveform.

        Args:
            y: Waveform of shape (B, T).

        Returns:
            Log-mel spectrogram of shape (B, n_mels, T_frames).
        """
        # Only the magnitude is needed; the phase is discarded.
        magnitude, _ = self.stft_fn(y)  # (B, T_frames, n_freqs)
        filterbank = self.mel_basis.astype(magnitude.dtype)
        mel = magnitude @ filterbank.T  # (B, T_frames, n_mels)
        # Floor at 1e-5 before the log to avoid -inf.
        log_mel = mx.log(mx.clip(mel, 1e-5, None))
        # Return channels-first: (B, n_mels, T_frames).
        return mx.transpose(log_mel, (0, 2, 1))
# ---------------------------------------------------------------------------
# Vocoder (supports both HiFi-GAN and BigVGAN v2)
# ---------------------------------------------------------------------------
class Vocoder(nn.Module):
    """Vocoder for mel-to-waveform synthesis.

    Supports resblock="1" (HiFi-GAN / LTX-2) and resblock="AMP1" (BigVGAN v2 / LTX-2.3).

    Sub-module attribute names (``conv_pre``, ``ups``, ``resblocks``,
    ``act_post``, ``conv_post``) and the integer keys of the ``ups`` /
    ``resblocks`` dicts define the parameter names used when loading weights,
    so they must not be renamed.
    """

    def __init__(self, config: VocoderModelConfig) -> None:
        super().__init__()
        self.output_sampling_rate = config.output_sample_rate
        self.num_kernels = len(config.resblock_kernel_sizes)
        self.num_upsamples = len(config.upsample_rates)
        self.upsample_rates = config.upsample_rates
        self.is_amp = config.resblock == "AMP1"
        self.use_tanh_at_final = config.use_tanh_at_final
        self.apply_final_activation = config.apply_final_activation

        # Stereo input stacks the mel bins of both channels: 2 * 64.
        in_channels = 128 if config.stereo else 64
        self.conv_pre = nn.Conv1d(
            in_channels, config.upsample_initial_channel,
            kernel_size=7, stride=1, padding=3,
        )

        # Upsampling layers: the channel count halves at every stage.
        self.ups = {}
        for i, (stride, kernel_size) in enumerate(
            zip(config.upsample_rates, config.upsample_kernel_sizes)
        ):
            in_ch = config.upsample_initial_channel // (2 ** i)
            out_ch = config.upsample_initial_channel // (2 ** (i + 1))
            self.ups[i] = nn.ConvTranspose1d(
                in_ch, out_ch,
                kernel_size=kernel_size, stride=stride,
                padding=(kernel_size - stride) // 2,
            )

        # Residual blocks: ``num_kernels`` blocks per upsampling stage, stored
        # flat so that stage ``i`` owns indices [i * num_kernels, (i + 1) * num_kernels).
        if not self.is_amp:
            resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
        self.resblocks = {}
        block_idx = 0
        for i in range(len(self.ups)):
            ch = config.upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilations in zip(
                config.resblock_kernel_sizes, config.resblock_dilation_sizes
            ):
                if self.is_amp:
                    block = AMPBlock1(
                        ch, kernel_size, tuple(dilations),
                        activation=config.activation,
                    )
                else:
                    block = resblock_class(ch, kernel_size, tuple(dilations))
                self.resblocks[block_idx] = block
                block_idx += 1

        final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))

        # Post-activation (AMP variant only; the plain variant uses leaky ReLU
        # in ``__call__``).
        if self.is_amp:
            act_cls = SnakeBeta if config.activation == "snakebeta" else Snake
            self.act_post = Activation1d(act_cls(final_channels))

        # Final conv
        out_channels = 2 if config.stereo else 1
        self.conv_post = nn.Conv1d(
            final_channels, out_channels,
            kernel_size=7, stride=1, padding=3,
            bias=config.use_bias_at_final,
        )
        self.upsample_factor = math.prod(config.upsample_rates)

    def __call__(self, x: mx.array) -> mx.array:
        """Forward pass.

        Args:
            x: Mel spectrogram (B, C, T, mel_bins) for stereo or (B, T, mel_bins) mono.

        Returns:
            Waveform (B, out_channels, T_audio) in channels-first format.

        Raises:
            ValueError: If ``x`` is neither 3-D nor 4-D.
        """
        if x.ndim == 4:
            # (B, C, T, mel) -> (B, C, mel, T) -> (B, C*mel, T)
            x = mx.transpose(x, (0, 1, 3, 2))
            b, s, c, t = x.shape
            x = x.reshape(b, s * c, t)
            # Channels-first (B, C*mel, T) -> channels-last (B, T, C*mel) for MLX conv
            x = mx.transpose(x, (0, 2, 1))
        elif x.ndim != 3:
            raise ValueError(
                f"Expected a 3-D (B, T, mel_bins) or 4-D (B, C, T, mel_bins) "
                f"mel spectrogram, got {x.ndim} dimensions"
            )
        # A 3-D mono input (B, T, mel_bins) is already channels-last, so it is
        # fed to conv_pre unchanged. (The previous unconditional 4-axis
        # transpose made this documented input format fail.)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            if not self.is_amp:
                x = leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            start = i * self.num_kernels
            end = start + self.num_kernels
            # Average the outputs of this stage's parallel residual blocks.
            block_outputs = mx.stack(
                [self.resblocks[idx](x) for idx in range(start, end)],
                axis=0,
            )
            x = mx.mean(block_outputs, axis=0)

        if self.is_amp:
            x = self.act_post(x)
        else:
            # Note: uses nn.leaky_relu's default slope, not LRELU_SLOPE.
            x = nn.leaky_relu(x)
        x = self.conv_post(x)
        if self.apply_final_activation:
            x = mx.tanh(x) if self.use_tanh_at_final else mx.clip(x, -1, 1)

        # Back to channels-first (B, T, C) -> (B, C, T)
        x = mx.transpose(x, (0, 2, 1))
        return x
# ---------------------------------------------------------------------------
# VocoderWithBWE (Bandwidth Extension)
# ---------------------------------------------------------------------------
class VocoderWithBWE(nn.Module):
    """Base vocoder followed by a bandwidth-extension (BWE) stage.

    The base vocoder produces a waveform at ``input_sampling_rate``. A second
    generator predicts a residual at ``output_sampling_rate`` from the log-mel
    of that waveform, and the residual is added to a sinc-resampled copy of the
    low-rate waveform (skip connection).
    """

    def __init__(
        self,
        vocoder: Vocoder,
        bwe_generator: Vocoder,
        mel_stft: MelSTFT,
        input_sampling_rate: int = 16000,
        output_sampling_rate: int = 48000,
        hop_length: int = 80,
    ) -> None:
        super().__init__()
        self.vocoder = vocoder
        self.bwe_generator = bwe_generator
        self.mel_stft = mel_stft
        self.input_sampling_rate = input_sampling_rate
        self.output_sampling_rate = output_sampling_rate
        self.hop_length = hop_length
        # Hann-windowed sinc resampler; it is not part of the checkpoint.
        upsample_ratio = output_sampling_rate // input_sampling_rate
        self.resampler = UpSample1d(ratio=upsample_ratio, window_type="hann")

    @property
    def output_sample_rate(self) -> int:
        """Alias of ``output_sampling_rate``."""
        return self.output_sampling_rate

    def _compute_mel(self, audio: mx.array) -> mx.array:
        """Compute a per-channel log-mel spectrogram.

        Args:
            audio: Channels-first waveform (B, C, T).

        Returns:
            Log-mel of shape (B, C, n_mels, T_frames).
        """
        batch, n_channels, _ = audio.shape
        # Fold channels into the batch so each channel is processed as mono.
        mono = audio.reshape(batch * n_channels, -1)
        mel = self.mel_stft.mel_spectrogram(mono)  # (B*C, n_mels, T_frames)
        n_mels, n_frames = mel.shape[1], mel.shape[2]
        return mel.reshape(batch, n_channels, n_mels, n_frames)

    def __call__(self, mel_spec: mx.array) -> mx.array:
        """Run the base vocoder and the bandwidth extension.

        Args:
            mel_spec: Mel spectrogram in the format accepted by the base vocoder.

        Returns:
            Waveform (B, out_channels, T_audio) at ``output_sampling_rate``,
            clipped to [-1, 1].
        """
        low_rate = self.vocoder(mel_spec)  # (B, C, T) at input_sampling_rate
        _, _, n_low = low_rate.shape
        # Target length is derived from the *unpadded* low-rate length.
        output_length = n_low * self.output_sampling_rate // self.input_sampling_rate

        # Right-pad with zeros up to the next multiple of hop_length.
        pad_amount = -n_low % self.hop_length
        if pad_amount:
            low_rate = mx.pad(low_rate, [(0, 0), (0, 0), (0, pad_amount)])

        # (B, C, n_mels, T_frames) -> (B, C, T_frames, n_mels) for the BWE generator.
        bwe_input = mx.transpose(self._compute_mel(low_rate), (0, 1, 3, 2))
        residual = self.bwe_generator(bwe_input)  # (B, C, T_high)

        # The resampler works channels-last, so transpose in and back out.
        skip = self.resampler(mx.transpose(low_rate, (0, 2, 1)))
        skip = mx.transpose(skip, (0, 2, 1))

        combined = mx.clip(residual + skip, -1, 1)
        return combined[..., :output_length]
# ---------------------------------------------------------------------------
# Factory / from_pretrained
# ---------------------------------------------------------------------------
def load_vocoder(model_path: Path) -> nn.Module:
    """Build a vocoder from a pretrained model directory.

    The directory must hold ``config.json`` and ``model.safetensors``. When the
    config sets ``has_bwe_generator`` a :class:`VocoderWithBWE` is returned,
    otherwise a plain :class:`Vocoder` whose weights are loaded strictly.

    Raises:
        FileNotFoundError: If ``config.json`` is missing from ``model_path``.
    """
    import json

    config_path = model_path / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"No config.json found in {model_path}")
    config_dict = json.loads(config_path.read_text())

    weights = mx.load(str(model_path / "model.safetensors"))

    if config_dict.get("has_bwe_generator", False):
        return _load_vocoder_with_bwe(config_dict, weights)

    model = Vocoder(VocoderModelConfig.from_dict(config_dict))
    model.load_weights(list(weights.items()), strict=True)
    return model
def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
    """Assemble a :class:`VocoderWithBWE` from a parsed config and weights.

    Args:
        config_dict: Parsed config with optional ``"vocoder"`` and ``"bwe"``
            sub-dicts.
        weights: Mapping of parameter name to array.

    Returns:
        The assembled model with weights loaded non-strictly.
    """
    # Base vocoder.
    vocoder = Vocoder(VocoderModelConfig.from_dict(config_dict.get("vocoder", {})))

    # BWE generator: same architecture family, but it predicts a residual, so
    # its final activation is disabled.
    bwe_cfg = config_dict.get("bwe", {})
    bwe_config = VocoderModelConfig.from_dict(bwe_cfg)
    bwe_config.apply_final_activation = False
    bwe_generator = Vocoder(bwe_config)

    # MelSTFT dimensions are inferred from the stored buffers, with fallbacks
    # when a buffer is absent.
    stft_basis = weights.get("mel_stft.stft_fn.forward_basis")
    filter_length = 512 if stft_basis is None else stft_basis.shape[2]
    mel_basis = weights.get("mel_stft.mel_basis")
    n_mel_channels = 64 if mel_basis is None else mel_basis.shape[0]

    hop_length = bwe_cfg.get("hop_length", 80)
    mel_stft = MelSTFT(
        filter_length=filter_length,
        hop_length=hop_length,
        win_length=filter_length,
        n_mel_channels=n_mel_channels,
    )

    model = VocoderWithBWE(
        vocoder=vocoder,
        bwe_generator=bwe_generator,
        mel_stft=mel_stft,
        input_sampling_rate=bwe_cfg.get("input_sampling_rate", 16000),
        output_sampling_rate=bwe_cfg.get("output_sampling_rate", 48000),
        hop_length=hop_length,
    )
    # Non-strict: not every module parameter is present in the checkpoint.
    model.load_weights(list(weights.items()), strict=False)
    return model

View File

@@ -0,0 +1,3 @@
"""Conditioning modules for LTX-2 video generation."""
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning

View File

@@ -0,0 +1,199 @@
"""Latent-based conditioning for I2V (Image-to-Video) generation.
This module provides conditioning that injects encoded image latents into
the video generation process at specific frame positions.
"""
from dataclasses import dataclass
from typing import Optional, List, Tuple
import mlx.core as mx
@dataclass
class VideoConditionByLatentIndex:
    """Condition video generation by injecting latents at a specific frame index.

    This replaces the latent at the specified frame index with the conditioned
    latent and controls how much denoising is applied via the strength parameter.

    Args:
        latent: Encoded image latent of shape (B, C, F_cond, H, W); typically
            a single frame (F_cond == 1).
        frame_idx: Frame index to condition (0 = first frame)
        strength: Conditioning strength. ``apply_conditioning`` sets the
            denoise mask of the conditioned frames to ``1.0 - strength``, so
            1.0 keeps the conditioning latent untouched (mask 0.0) and 0.0
            leaves the frames fully denoised (mask 1.0).
    """

    latent: mx.array
    frame_idx: int = 0
    strength: float = 1.0

    def get_num_latent_frames(self) -> int:
        """Get number of latent frames in the conditioning."""
        # Axis 2 is the frame axis of the (B, C, F, H, W) latent.
        return self.latent.shape[2]
@dataclass
class LatentState:
    """Bundle of arrays tracked during conditioned latent diffusion.

    Attributes:
        latent: Current noisy latent (B, C, F, H, W)
        clean_latent: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
            1.0 = full denoise, 0.0 = keep clean
    """

    latent: mx.array
    clean_latent: mx.array
    denoise_mask: mx.array

    def clone(self) -> "LatentState":
        """Return a new state object referencing the same three arrays."""
        # Shallow copy: the arrays themselves are shared, not duplicated.
        return LatentState(self.latent, self.clean_latent, self.denoise_mask)
def create_initial_state(
    shape: Tuple[int, ...],
    seed: Optional[int] = None,
    noise_scale: float = 1.0,
) -> LatentState:
    """Build the starting state for sampling: pure scaled Gaussian noise.

    Args:
        shape: Latent shape (B, C, F, H, W)
        seed: When given, seeds MLX's global random state before sampling
        noise_scale: Multiplier applied to the sampled noise (sigma)

    Returns:
        A LatentState whose clean latent is all zeros and whose denoise mask
        is all ones (every frame fully denoised).
    """
    if seed is not None:
        mx.random.seed(seed)

    scaled_noise = mx.random.normal(shape) * noise_scale
    batch, frames = shape[0], shape[2]
    full_denoise = mx.ones((batch, 1, frames, 1, 1))
    return LatentState(
        latent=scaled_noise,
        clean_latent=mx.zeros(shape),
        denoise_mask=full_denoise,
    )
def _splice_frames(base: mx.array, insert: mx.array, start: int, end: int) -> mx.array:
    """Return ``base`` with frames ``[start, end)`` on axis 2 replaced by ``insert``."""
    pieces = []
    # Empty leading/trailing slices are skipped so they cannot influence the
    # result (e.g. through dtype promotion) when every frame is replaced.
    if start > 0:
        pieces.append(base[:, :, :start])
    pieces.append(insert)
    if end < base.shape[2]:
        pieces.append(base[:, :, end:])
    if len(pieces) == 1:
        return insert
    return mx.concatenate(pieces, axis=2)


def apply_conditioning(
    state: LatentState,
    conditionings: List[VideoConditionByLatentIndex],
) -> LatentState:
    """Apply conditioning items to a latent state.

    Each item overwrites ``latent`` and ``clean_latent`` at frames
    ``[frame_idx, frame_idx + F_cond)`` (clamped to the number of frames) with
    its conditioning latent and sets the denoise mask of those frames to
    ``1.0 - strength``. Items are applied in order, so later items win on
    overlapping frames.

    Args:
        state: Current latent state (not modified; a clone is updated)
        conditionings: List of conditioning items to apply

    Returns:
        Updated LatentState with conditioning applied

    Raises:
        ValueError: If a conditioning latent's (C, H, W) differs from the
            state's, or if ``frame_idx`` is negative or not smaller than the
            number of frames.
    """
    state = state.clone()
    dtype = state.latent.dtype
    b, c, f, h, w = state.latent.shape

    for cond in conditionings:
        cond_latent = cond.latent
        frame_idx = cond.frame_idx

        # Validate shapes
        _, cond_c, cond_h_frames, cond_h, cond_w = cond_latent.shape
        if (cond_c, cond_h, cond_w) != (c, h, w):
            raise ValueError(
                f"Conditioning latent spatial shape ({cond_c}, {cond_h}, {cond_w}) "
                f"does not match target shape ({c}, {h}, {w})"
            )
        # Negative indices are rejected as well: they used to slip through and
        # either did nothing or picked the wrong conditioning frames.
        if not 0 <= frame_idx < f:
            raise ValueError(
                f"Frame index {frame_idx} is out of bounds for latent with {f} frames"
            )

        # Conditioning frames that would run past the last frame are dropped.
        end_idx = min(frame_idx + cond_h_frames, f)
        num_frames = end_idx - frame_idx
        if num_frames <= 0:
            # Conditioning latent without frames: nothing to inject.
            continue

        cond_frames = cond_latent[:, :, :num_frames]
        # 1.0 - strength: a stronger conditioning means less denoising.
        cond_mask = mx.full((b, 1, num_frames, 1, 1), 1.0 - cond.strength, dtype=dtype)

        state.latent = _splice_frames(state.latent, cond_frames, frame_idx, end_idx)
        state.clean_latent = _splice_frames(state.clean_latent, cond_frames, frame_idx, end_idx)
        state.denoise_mask = _splice_frames(state.denoise_mask, cond_mask, frame_idx, end_idx)

    return state
def apply_denoise_mask(
    denoised: mx.array,
    clean: mx.array,
    denoise_mask: mx.array,
) -> mx.array:
    """Linearly interpolate between the denoised and the clean latent.

    Args:
        denoised: Denoised latent (B, C, F, H, W)
        clean: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Blend weight; 1.0 selects ``denoised``, 0.0 selects ``clean``

    Returns:
        ``denoised * denoise_mask + clean * (1 - denoise_mask)``
    """
    # The constant is created in the dtype of ``denoised``.
    keep_clean = mx.array(1.0, dtype=denoised.dtype) - denoise_mask
    return denoised * denoise_mask + clean * keep_clean
def add_noise_with_state(
    state: LatentState,
    noise_scale: float,
) -> LatentState:
    """Mix fresh Gaussian noise into the latent, scaled per frame by the mask.

    The effective sigma of a frame is ``noise_scale * denoise_mask``: frames
    with mask 0 receive no noise, frames with mask 1 receive the full
    ``noise_scale``.

    Args:
        state: Current latent state (not modified; a clone is updated)
        noise_scale: Scale for noise (sigma)

    Returns:
        New state whose latent is
        ``noise * sigma_eff + latent * (1 - sigma_eff)``.
    """
    noised = state.clone()
    noise = mx.random.normal(noised.latent.shape)

    sigma_eff = noise_scale * noised.denoise_mask
    keep = mx.array(1.0, dtype=noised.latent.dtype) - sigma_eff
    noised.latent = noise * sigma_eff + noised.latent * keep
    return noised

View File

@@ -0,0 +1,380 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional, Tuple
class LTXModelType(Enum):
    """Which modalities an LTX model handles."""

    AudioVideo = "ltx av model"
    VideoOnly = "ltx video only model"
    AudioOnly = "ltx audio only model"

    def is_video_enabled(self) -> bool:
        """Return True for the variants that include the video branch."""
        return self is LTXModelType.AudioVideo or self is LTXModelType.VideoOnly

    def is_audio_enabled(self) -> bool:
        """Return True for the variants that include the audio branch."""
        return self is LTXModelType.AudioVideo or self is LTXModelType.AudioOnly
class LTXRopeType(Enum):
    """Rotary positional embedding variants selectable via ``LTXModelConfig.rope_type``."""

    INTERLEAVED = "interleaved"
    SPLIT = "split"
    TWO_D = "2d"
class AttentionType(Enum):
    """Attention implementations selectable via ``LTXModelConfig.attention_type``."""

    # Only one implementation is declared at the moment.
    DEFAULT = "default"
@dataclass
class BaseModelConfig:
    """Common (de)serialization helpers shared by all config dataclasses."""

    @classmethod
    def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig":
        """Create config from dictionary, filtering only valid parameters.

        Keys that are not constructor parameters of ``cls`` are silently
        ignored.
        """
        # Resolve the constructor signature once instead of once per key.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})

    def to_dict(self) -> dict[str, Any]:
        """Export config to dictionary.

        ``None`` values are omitted, enum members are replaced by their
        ``.value`` and objects exposing ``to_dict`` (nested configs) are
        exported recursively.
        """
        result = {}
        for k, v in self.__dict__.items():
            if v is None:
                continue
            if isinstance(v, Enum):
                result[k] = v.value
            elif hasattr(v, 'to_dict'):
                result[k] = v.to_dict()
            else:
                result[k] = v
        return result
@dataclass
class TransformerConfig(BaseModelConfig):
    """Dimensions of a single transformer stream.

    Instances are built by ``LTXModelConfig.get_video_config`` and
    ``LTXModelConfig.get_audio_config``.
    """

    # Inner width of the stream (heads * d_head in both builders).
    dim: int
    # Number of attention heads.
    heads: int
    # Per-head dimension.
    d_head: int
    # Cross-attention context width.
    context_dim: int
@dataclass
class VideoVAEConfig(BaseModelConfig):
    """Architecture description of the video VAE (encoder and decoder stacks)."""

    convolution_dimensions: int = 3
    in_channels: int = 3
    out_channels: int = 128
    latent_channels: int = 128
    patch_size: int = 4
    # Each entry is a (block_name, block_kwargs) pair, listed in order.
    encoder_blocks: List[tuple] = field(default_factory=lambda: [
        ("res_x", {"num_layers": 4}),
        ("compress_space_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 6}),
        ("compress_time_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 6}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 2}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 2}),
    ])
    # Same (block_name, block_kwargs) convention as ``encoder_blocks``.
    decoder_blocks: List[tuple] = field(default_factory=lambda: [
        ("res_x", {"num_layers": 5, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 5, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 5, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 5, "inject_noise": False}),
    ])
@dataclass
class LTXModelConfig(BaseModelConfig):
    """Top-level configuration of the LTX audio/video transformer.

    String values for the enum fields and a plain dict for ``vae_config`` are
    accepted and converted in ``__post_init__``, so a dictionary produced by
    ``to_dict()`` can be fed back through ``from_dict()``.
    """

    # Model type
    model_type: LTXModelType = LTXModelType.AudioVideo
    # Video transformer config
    num_attention_heads: int = 32
    attention_head_dim: int = 128
    in_channels: int = 128
    out_channels: int = 128
    num_layers: int = 48
    cross_attention_dim: int = 4096
    caption_channels: int = 3840
    # Audio transformer config
    audio_num_attention_heads: int = 32
    audio_attention_head_dim: int = 64
    audio_in_channels: int = 128
    audio_out_channels: int = 128
    audio_cross_attention_dim: int = 2048
    audio_caption_channels: int = 3840  # Input dim for audio text embeddings (same as video)
    # Positional embedding config
    positional_embedding_theta: float = 10000.0
    positional_embedding_max_pos: Optional[List[int]] = None
    audio_positional_embedding_max_pos: Optional[List[int]] = None
    use_middle_indices_grid: bool = True
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED
    double_precision_rope: bool = False
    # Timestep config
    timestep_scale_multiplier: int = 1000
    av_ca_timestep_scale_multiplier: int = 1000
    # Normalization
    norm_eps: float = 1e-6
    # Attention type
    attention_type: AttentionType = AttentionType.DEFAULT
    # LTX-2.3: prompt-conditioned adaptive layer norm
    # Controls: gate_logits in attention, 9-param scale_shift_table,
    # prompt_adaln_single, per-block prompt_scale_shift_table,
    # removal of caption_projection
    has_prompt_adaln: bool = False
    # VAE config
    vae_config: Optional[VideoVAEConfig] = None

    def __post_init__(self):
        """Set default values after initialization."""
        if self.positional_embedding_max_pos is None:
            self.positional_embedding_max_pos = [20, 2048, 2048]
        if self.audio_positional_embedding_max_pos is None:
            self.audio_positional_embedding_max_pos = [20]
        # PyTorch LTX-2 configurator reads "frequencies_precision" (not
        # "double_precision_rope") from the config. For LTX-2 (no prompt adaln)
        # the key is absent, so double_precision_rope = False. For LTX-2.3
        # (has_prompt_adaln=True) the safetensors config has
        # frequencies_precision="float64", so double_precision_rope = True.
        # Only the LTX-2 case is enforced here; with has_prompt_adaln=True the
        # value passed by the caller is kept as is.
        if not self.has_prompt_adaln:
            self.double_precision_rope = False
        # Convert string enum values if loading from dict
        if isinstance(self.model_type, str):
            self.model_type = LTXModelType(self.model_type)
        if isinstance(self.rope_type, str):
            self.rope_type = LTXRopeType(self.rope_type)
        if isinstance(self.attention_type, str):
            self.attention_type = AttentionType(self.attention_type)
        # to_dict() exports the nested VAE config as a plain dict; rebuild the
        # dataclass here so that from_dict(to_dict()) round-trips.
        if isinstance(self.vae_config, dict):
            self.vae_config = VideoVAEConfig.from_dict(self.vae_config)

    @property
    def inner_dim(self) -> int:
        """Video inner dimension."""
        return self.num_attention_heads * self.attention_head_dim

    @property
    def audio_inner_dim(self) -> int:
        """Audio inner dimension."""
        return self.audio_num_attention_heads * self.audio_attention_head_dim

    def get_video_config(self) -> Optional[TransformerConfig]:
        """Get video transformer configuration.

        Returns:
            ``None`` when the model type has no video branch.
        """
        if not self.model_type.is_video_enabled():
            return None
        return TransformerConfig(
            dim=self.inner_dim,
            heads=self.num_attention_heads,
            d_head=self.attention_head_dim,
            context_dim=self.cross_attention_dim,
        )

    def get_audio_config(self) -> Optional[TransformerConfig]:
        """Get audio transformer configuration.

        Returns:
            ``None`` when the model type has no audio branch.
        """
        if not self.model_type.is_audio_enabled():
            return None
        return TransformerConfig(
            dim=self.audio_inner_dim,
            heads=self.audio_num_attention_heads,
            d_head=self.audio_attention_head_dim,
            context_dim=self.audio_cross_attention_dim,
        )
class CausalityAxis(Enum):
    """Enum for specifying the causality axis in causal convolutions."""

    # Member value is None (no string form).
    NONE = None
    WIDTH = "width"
    HEIGHT = "height"
    WIDTH_COMPATIBILITY = "width-compatibility"
@dataclass
class AudioDecoderModelConfig(BaseModelConfig):
    """Configuration of the audio VAE decoder.

    ``norm_type``, ``causality_axis`` and ``attn_type`` may be supplied as
    strings; they are converted to their enum types after initialization.
    """

    ch: int = 128
    out_ch: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    norm_type: Enum = None
    causality_axis: Enum = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: int | None = None
    resamp_with_conv: bool = True
    attn_type: str = None
    give_pre_end: bool = False
    tanh_out: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Export the config, forcing ``attn_resolutions`` to a plain list."""
        exported = super().to_dict()
        resolutions = self.attn_resolutions
        if resolutions is not None:
            exported["attn_resolutions"] = list(resolutions)
        return exported

    def __post_init__(self):
        """Turn string-valued enum fields into enum members."""
        # Imported lazily to avoid circular imports. Note that this
        # AttentionType is the audio VAE one and shadows the module-level
        # AttentionType inside this method.
        from .audio_vae.normalization import NormType
        from .audio_vae.attention import AttentionType

        axis, norm, attn = self.causality_axis, self.norm_type, self.attn_type
        if isinstance(axis, str):
            self.causality_axis = CausalityAxis(axis)
        if isinstance(norm, str):
            self.norm_type = NormType(norm)
        if isinstance(attn, str):
            self.attn_type = AttentionType(attn)
@dataclass
class AudioEncoderModelConfig(BaseModelConfig):
    """Configuration of the audio VAE encoder.

    ``norm_type``, ``causality_axis`` and ``attn_type`` may be supplied as
    strings; they are converted to their enum types after initialization.
    """

    ch: int = 128
    in_channels: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    double_z: bool = True
    n_fft: int = 1024
    norm_type: Enum = None
    causality_axis: Enum = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: int = 64
    resamp_with_conv: bool = True
    attn_type: str = None

    def to_dict(self) -> dict[str, Any]:
        """Export the config, forcing ``attn_resolutions`` to a plain list."""
        exported = super().to_dict()
        resolutions = self.attn_resolutions
        if resolutions is not None:
            exported["attn_resolutions"] = list(resolutions)
        return exported

    def __post_init__(self):
        """Turn string-valued enum fields into enum members."""
        # Local imports; this AttentionType is the audio VAE one and shadows
        # the module-level AttentionType inside this method.
        from .audio_vae.normalization import NormType
        from .audio_vae.attention import AttentionType

        axis, norm, attn = self.causality_axis, self.norm_type, self.attn_type
        if isinstance(axis, str):
            self.causality_axis = CausalityAxis(axis)
        if isinstance(norm, str):
            self.norm_type = NormType(norm)
        if isinstance(attn, str):
            self.attn_type = AttentionType(attn)
@dataclass
class VocoderModelConfig(BaseModelConfig):
    """Configuration of the mel-to-waveform vocoder.

    The four list-valued fields default to ``None`` and receive their actual
    defaults in ``__post_init__`` so that instances never share list objects.
    """

    resblock_kernel_sizes: Optional[List[int]] = None
    upsample_rates: Optional[List[int]] = None
    upsample_kernel_sizes: Optional[List[int]] = None
    resblock_dilation_sizes: Optional[List[List[int]]] = None
    upsample_initial_channel: int = 1024
    stereo: bool = True
    resblock: str = "1"
    output_sample_rate: int = 24000
    activation: str = "snake"
    use_tanh_at_final: bool = True
    apply_final_activation: bool = True
    use_bias_at_final: bool = True

    def __post_init__(self):
        """Fill in the list-valued defaults that were left as ``None``."""
        # Built on every call, so each instance gets fresh list objects.
        fallbacks = {
            "resblock_kernel_sizes": [3, 7, 11],
            "upsample_rates": [6, 5, 2, 2, 2],
            "upsample_kernel_sizes": [16, 15, 8, 4, 4],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        }
        for field_name, fallback in fallbacks.items():
            if getattr(self, field_name) is None:
                setattr(self, field_name, fallback)
@dataclass
class VideoDecoderModelConfig(BaseModelConfig):
    """Configuration of the video VAE decoder.

    Unlike the audio configs, no ``__post_init__`` conversion is defined here,
    so ``norm_type`` and ``causality_axis`` are stored exactly as given.
    """

    ch: int = 128
    out_ch: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    norm_type: Enum = None
    causality_axis: Enum = None
    dropout: float = 0.0
    timestep_conditioning: bool = False
@dataclass
class VideoEncoderModelConfig(BaseModelConfig):
    """Configuration of the video VAE encoder.

    ``norm_layer``, ``latent_log_var`` and ``encoder_spatial_padding_mode``
    accept an enum member, its string value, or ``None`` (replaced by a
    default in ``__post_init__``).
    """

    convolution_dimensions: int = 3
    in_channels: int = 3
    out_channels: int = 128
    patch_size: int = 4
    norm_layer: Enum = None
    latent_log_var: Enum = None
    encoder_spatial_padding_mode: Enum = None
    # Each entry is a (block_name, block_kwargs) pair, listed in order.
    encoder_blocks: List[tuple] = field(
        default_factory=lambda: [
            ("res_x", {"num_layers": 4}),
            ("compress_space_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 6}),
            ("compress_time_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 6}),
            ("compress_all_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 2}),
            ("compress_all_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 2}),
        ]
    )

    def __post_init__(self):
        """Resolve the three enum-typed fields to enum members."""
        from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
        from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
        from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType

        # None -> default member; string -> member looked up by value.
        if self.norm_layer is None:
            self.norm_layer = NormLayerType.PIXEL_NORM
        elif isinstance(self.norm_layer, str):
            self.norm_layer = NormLayerType(self.norm_layer)

        if self.latent_log_var is None:
            self.latent_log_var = LogVarianceType.UNIFORM
        elif isinstance(self.latent_log_var, str):
            self.latent_log_var = LogVarianceType(self.latent_log_var)

        if self.encoder_spatial_padding_mode is None:
            self.encoder_spatial_padding_mode = PaddingModeType.ZEROS
        elif isinstance(self.encoder_spatial_padding_mode, str):
            self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)

    def to_dict(self) -> dict[str, Any]:
        """Export the config with every encoder block converted to a list."""
        exported = super().to_dict()
        blocks = self.encoder_blocks
        if blocks is not None:
            exported["encoder_blocks"] = list(map(list, blocks))
        return exported

View File

@@ -0,0 +1,785 @@
"""Convert LTX-2/2.3 safetensors to MLX directory layout.
Converts from the single-file format (e.g. Lightricks/LTX-2/ltx-2-19b-distilled.safetensors
or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular directory structure:
output/
├── transformer/ # DiT transformer weights (sharded)
│ ├── config.json
│ ├── model-00001-of-N.safetensors
│ └── model.safetensors.index.json
├── vae/
│ ├── decoder/ # Video VAE decoder
│ │ ├── config.json
│ │ └── model.safetensors
│ └── encoder/ # Video VAE encoder
│ ├── config.json
│ └── model.safetensors
├── audio_vae/ # Audio VAE decoder
│ ├── config.json
│ └── model.safetensors
├── vocoder/ # Audio vocoder
│ ├── config.json
│ └── model.safetensors
└── text_projections/ # Text projection connectors
└── model.safetensors
Usage:
# From HF repo ID
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
# From local folder containing the monolithic safetensors
python -m mlx_video.models.ltx_2.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
# From a direct safetensors file path
python -m mlx_video.models.ltx_2.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
"""
import argparse
import json
import re
import shutil
from pathlib import Path
from typing import Dict
import mlx.core as mx
# ─── Key prefix routing ──────────────────────────────────────────────────────
# Key prefixes used to route tensors of the monolithic checkpoint to the
# individual sanitize/extract functions below.
TRANSFORMER_PREFIX = "model.diffusion_model."
VAE_DECODER_PREFIX = "vae.decoder."
VAE_ENCODER_PREFIX = "vae.encoder."
VAE_STATS_PREFIX = "vae.per_channel_statistics."
AUDIO_DECODER_PREFIX = "audio_vae.decoder."
AUDIO_ENCODER_PREFIX = "audio_vae.encoder."
AUDIO_STATS_PREFIX = "audio_vae.per_channel_statistics."
VOCODER_PREFIX = "vocoder."
TEXT_PROJ_PREFIX = "text_embedding_projection."
# The two connector prefixes live underneath TRANSFORMER_PREFIX;
# sanitize_transformer skips them and extract_text_projections picks them up.
VIDEO_CONNECTOR_PREFIX = "model.diffusion_model.video_embeddings_connector."
AUDIO_CONNECTOR_PREFIX = "model.diffusion_model.audio_embeddings_connector."
# ─── Sanitization functions ──────────────────────────────────────────────────
def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Select and rename the transformer tensors of a monolithic checkpoint.

    Keeps the keys below ``TRANSFORMER_PREFIX`` (minus the text-embedding
    connectors), strips the prefix, applies the layer renames and casts every
    tensor to bfloat16.
    """
    # Applied in this order to every kept key.
    renames = (
        (".to_out.0.", ".to_out."),
        (".ff.net.0.proj.", ".ff.proj_in."),
        (".ff.net.2.", ".ff.proj_out."),
        (".audio_ff.net.0.proj.", ".audio_ff.proj_in."),
        (".audio_ff.net.2.", ".audio_ff.proj_out."),
        (".linear_1.", ".linear1."),
        (".linear_2.", ".linear2."),
    )
    connector_markers = ("audio_embeddings_connector", "video_embeddings_connector")
    prefix_len = len(TRANSFORMER_PREFIX)

    sanitized = {}
    for key, value in weights.items():
        if not key.startswith(TRANSFORMER_PREFIX):
            continue
        # Connector weights are exported with the text projections instead.
        if any(marker in key for marker in connector_markers):
            continue

        new_key = key[prefix_len:]
        for old, new in renames:
            new_key = new_key.replace(old, new)

        # Everything is stored as bfloat16 (matches MLX model loading behavior).
        if value.dtype != mx.bfloat16:
            value = value.astype(mx.bfloat16)
        sanitized[new_key] = value
    return sanitized
def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Select and convert the video VAE decoder tensors.

    Keeps the decoder weights and the two per-channel statistics, transposes
    5-D ``.conv.weight`` tensors to the MLX layout and renames ``.conv.*`` to
    ``.conv.conv.*`` (CausalConv3d wrapper).
    """
    stat_names = {
        "vae.per_channel_statistics.mean-of-means": "per_channel_statistics.mean",
        "vae.per_channel_statistics.std-of-means": "per_channel_statistics.std",
    }

    sanitized = {}
    for key, value in weights.items():
        if key.startswith(VAE_STATS_PREFIX):
            new_key = stat_names.get(key)
            if new_key is None:
                # Other statistics entries are not exported.
                continue
        elif key.startswith(VAE_DECODER_PREFIX):
            new_key = key[len(VAE_DECODER_PREFIX):]
        else:
            continue

        # Conv3d weight transpose: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
        if value.ndim == 5 and ".conv.weight" in key:
            value = mx.transpose(value, (0, 2, 3, 4, 1))

        # Insert the extra ".conv" level unless the key already has it.
        needs_wrap = ".conv.weight" in new_key or ".conv.bias" in new_key
        already_wrapped = ".conv.conv.weight" in new_key or ".conv.conv.bias" in new_key
        if needs_wrap and not already_wrapped:
            new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
            new_key = new_key.replace(".conv.bias", ".conv.conv.bias")

        sanitized[new_key] = value
    return sanitized
def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Select and convert the video VAE encoder tensors.

    Keeps the encoder weights and the two per-channel statistics (cast to
    float32), drops any key containing ``position_ids`` and transposes 5-D /
    4-D convolution weights to the MLX channels-last layout.
    """
    stat_names = {
        "vae.per_channel_statistics.mean-of-means": "per_channel_statistics.mean",
        "vae.per_channel_statistics.std-of-means": "per_channel_statistics.std",
    }

    sanitized = {}
    for key, value in weights.items():
        if "position_ids" in key:
            continue

        if key.startswith(VAE_STATS_PREFIX):
            new_key = stat_names.get(key)
            if new_key is None:
                continue
            # Per-channel statistics must stay float32 for precision
            if value.dtype != mx.float32:
                value = value.astype(mx.float32)
        elif key.startswith(VAE_ENCODER_PREFIX):
            new_key = key[len(VAE_ENCODER_PREFIX):]
        else:
            continue

        is_conv_weight = "conv" in new_key.lower() and "weight" in new_key
        if is_conv_weight and value.ndim == 5:
            # Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
            value = mx.transpose(value, (0, 2, 3, 4, 1))
        elif is_conv_weight and value.ndim == 4:
            # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
            value = mx.transpose(value, (0, 2, 3, 1))

        sanitized[new_key] = value
    return sanitized
def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Select and convert the audio VAE decoder tensors.

    Keeps the decoder weights and the mean/std per-channel statistics, and
    transposes 4-D convolution weights to the MLX channels-last layout.
    """
    sanitized = {}
    for key, value in weights.items():
        if key.startswith(AUDIO_DECODER_PREFIX):
            new_key = key[len(AUDIO_DECODER_PREFIX):]
        elif not key.startswith(AUDIO_STATS_PREFIX):
            continue
        elif "mean-of-means" in key:
            new_key = "per_channel_statistics.mean_of_means"
        elif "std-of-means" in key:
            new_key = "per_channel_statistics.std_of_means"
        else:
            # Other statistics entries are not exported.
            continue

        # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
        is_conv_weight = "conv" in new_key.lower() and "weight" in new_key
        if is_conv_weight and value.ndim == 4:
            value = mx.transpose(value, (0, 2, 3, 1))

        sanitized[new_key] = value
    return sanitized
def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Select the vocoder tensors and convert 1-D convolution weights.

    Keys below ``VOCODER_PREFIX`` are kept with the prefix stripped. 3-D
    weights are transposed to the MLX layout; keys containing ``"ups"`` are
    treated as ConvTranspose1d weights, all others as Conv1d weights.
    """
    # PyTorch ConvTranspose1d (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
    conv_transpose_axes = (1, 2, 0)
    # PyTorch Conv1d (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
    conv_axes = (0, 2, 1)
    prefix_len = len(VOCODER_PREFIX)

    sanitized = {}
    for key, value in weights.items():
        if not key.startswith(VOCODER_PREFIX):
            continue
        new_key = key[prefix_len:]

        if "weight" in new_key and value.ndim == 3:
            axes = conv_transpose_axes if "ups" in new_key else conv_axes
            value = mx.transpose(value, axes)

        sanitized[new_key] = value
    return sanitized
def sanitize_connector_key(key: str) -> str:
    """Rename connector sub-keys to the module names used by this package.

    Applies three substring substitutions, in order:
    ``.ff.net.0.proj.`` -> ``.ff.proj_in.``, ``.ff.net.2.`` -> ``.ff.proj_out.``
    and ``.to_out.0.`` -> ``.to_out.``. Keys without these fragments are
    returned unchanged.
    """
    substitutions = (
        (".ff.net.0.proj.", ".ff.proj_in."),
        (".ff.net.2.", ".ff.proj_out."),
        (".to_out.0.", ".to_out."),
    )
    for old, new in substitutions:
        key = key.replace(old, new)
    return key
def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Collect the text projection weights (aggregate_embed + connectors).

    Handles both LTX-2 (aggregate_embed.weight) and LTX-2.3
    (video_aggregate_embed.*, audio_aggregate_embed.*) formats.

    Keys under ``TEXT_PROJ_PREFIX`` are kept with the prefix removed. Keys
    under the video / audio connector prefixes are re-rooted at
    ``video_embeddings_connector.`` / ``audio_embeddings_connector.`` and
    their remainder is passed through ``sanitize_connector_key``.

    Args:
        weights: Full checkpoint mapping (key -> tensor).

    Returns:
        New mapping with the projection entries first, then the video
        connector entries, then the audio connector entries.
    """

    def collect(prefix, rename):
        # One pass over the checkpoint for a single source prefix.
        return {
            rename(name[len(prefix):]): tensor
            for name, tensor in weights.items()
            if name.startswith(prefix)
        }

    extracted = {}
    # aggregate_embed weights (text_embedding_projection.*)
    extracted.update(collect(TEXT_PROJ_PREFIX, lambda rest: rest))
    # video_embeddings_connector
    extracted.update(
        collect(
            VIDEO_CONNECTOR_PREFIX,
            lambda rest: "video_embeddings_connector." + sanitize_connector_key(rest),
        )
    )
    # audio_embeddings_connector
    extracted.update(
        collect(
            AUDIO_CONNECTOR_PREFIX,
            lambda rest: "audio_embeddings_connector." + sanitize_connector_key(rest),
        )
    )
    return extracted
# ─── Saving utilities ─────────────────────────────────────────────────────────
def save_sharded(
    weights: Dict[str, mx.array],
    output_dir: Path,
    max_shard_size_bytes: int = 5 * 1024 * 1024 * 1024,  # 5GB per shard
):
    """Write ``weights`` as one or more safetensors shards plus an index file.

    Keys are processed in sorted order and packed greedily: a new shard is
    started when adding the next tensor would exceed ``max_shard_size_bytes``
    (a single tensor larger than the limit still gets a shard of its own).
    A lone shard is named ``model.safetensors``; otherwise shards are named
    ``model-XXXXX-of-YYYYY.safetensors``.

    Args:
        weights: Mapping of key -> tensor to save.
        output_dir: Destination directory (created if missing).
        max_shard_size_bytes: Soft upper bound on the bytes in each shard.

    Returns:
        The number of shard files written.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted keys keep the shard contents deterministic.
    ordered_names = sorted(weights)
    total_size = sum(weights[name].nbytes for name in ordered_names)

    shards = []
    pending = {}
    pending_bytes = 0
    for name in ordered_names:
        tensor = weights[name]
        size = tensor.nbytes
        if pending and pending_bytes + size > max_shard_size_bytes:
            shards.append(pending)
            pending, pending_bytes = {}, 0
        pending[name] = tensor
        pending_bytes += size
    if pending:
        shards.append(pending)

    shard_count = len(shards)
    weight_map = {}
    for number, shard in enumerate(shards, start=1):
        if shard_count == 1:
            filename = "model.safetensors"
        else:
            filename = f"model-{number:05d}-of-{shard_count:05d}.safetensors"
        mx.save_safetensors(str(output_dir / filename), shard)
        weight_map.update(dict.fromkeys(shard, filename))

    # Index file: total byte size and which file holds each key.
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": weight_map,
    }
    with open(output_dir / "model.safetensors.index.json", "w") as f:
        json.dump(index, f, indent=2, sort_keys=True)

    return shard_count
def save_single(weights: Dict[str, mx.array], output_dir: Path):
    """Write ``weights`` to ``model.safetensors`` together with an index file.

    The index (``model.safetensors.index.json``) records the total byte size
    and maps every key to the single weights file.

    Args:
        weights: Mapping of key -> tensor to save.
        output_dir: Destination directory (created if missing).
    """
    weights_filename = "model.safetensors"

    output_dir.mkdir(parents=True, exist_ok=True)
    mx.save_safetensors(str(output_dir / weights_filename), weights)

    index = {
        "metadata": {"total_size": sum(tensor.nbytes for tensor in weights.values())},
        "weight_map": dict.fromkeys(sorted(weights), weights_filename),
    }
    index_path = output_dir / "model.safetensors.index.json"
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2, sort_keys=True)
def save_config(config: dict, output_dir: Path):
    """Write ``config`` as ``config.json`` inside ``output_dir``.

    The directory is created if needed. The JSON is indented by four spaces
    and terminated with a newline.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / "config.json"
    with target.open("w") as handle:
        json.dump(config, handle, indent=4)
        handle.write("\n")
# ─── Source resolution ─────────────────────────────────────────────────────────
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
# The named group "variant" captures "distilled" or "dev". The pattern is anchored at
# both ends, so a name with anything between the variant and ".safetensors" does not match.
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$")
# Matches upscaler files (spatial or temporal) like ltx-2-spatial-upscaler-x2-1.0.safetensors,
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$")
def resolve_source(source: str, variant: str) -> Path:
    """Locate the monolithic safetensors checkpoint for ``variant``.

    ``source`` is tried, in order, as an existing file (returned as is), an
    existing directory (searched for a file matching ``MONOLITHIC_PATTERN``
    with the requested variant), and finally as a Hugging Face repo ID
    (the matching file is downloaded with ``hf_hub_download``).

    Args:
        source: HF repo ID (e.g. "Lightricks/LTX-2"), local directory, or direct file path.
        variant: Model variant ("distilled" or "dev") to select the right file.

    Returns:
        Path to the monolithic safetensors file.

    Raises:
        FileNotFoundError: If no matching file exists in the directory or
            repo, or if ``source`` is none of the supported forms.
    """
    source_path = Path(source)

    # 1. Direct file path.
    if source_path.is_file():
        return source_path

    # 2. Local directory: pick the first (sorted) file of the requested variant.
    if source_path.is_dir():
        for candidate in sorted(source_path.glob("ltx-*b-*.safetensors")):
            found = MONOLITHIC_PATTERN.match(candidate.name)
            if found and found.group("variant") == variant:
                return candidate

        # Broader fallback over every ltx-*.safetensors file.
        all_mono = sorted(source_path.glob("ltx-*.safetensors"))
        fallback = next(
            (
                candidate
                for candidate in all_mono
                if variant in candidate.name and MONOLITHIC_PATTERN.match(candidate.name)
            ),
            None,
        )
        if fallback is not None:
            return fallback

        raise FileNotFoundError(
            f"No monolithic *-{variant}.safetensors found in {source_path}. "
            f"Files found: {[f.name for f in all_mono]}"
        )

    # 3. Anything that is not a repo-style ID is unresolvable.
    if "/" not in source or source_path.exists():
        raise FileNotFoundError(
            f"Source not found: {source}. Provide an HF repo ID, local directory, or file path."
        )

    # 4. HF repo ID: find the file in the repo listing and download it.
    from huggingface_hub import hf_hub_download, list_repo_files

    repo_files = list_repo_files(source)
    target = None
    for repo_file in repo_files:
        found = MONOLITHIC_PATTERN.match(repo_file)
        if found and found.group("variant") == variant:
            target = repo_file
            break

    if not target:
        raise FileNotFoundError(
            f"No *-{variant}.safetensors found in {source}. "
            f"Available: {[f for f in repo_files if f.endswith('.safetensors')]}"
        )

    print(f"Downloading {target} from {source}...")
    return Path(hf_hub_download(repo_id=source, filename=target))
# ─── Config inference ─────────────────────────────────────────────────────────
def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
    """Build the transformer config, reading what it can from the weights.

    Three values come from ``weights``; everything else is a fixed default:

    * ``num_layers``: highest ``transformer_blocks.<N>`` index + 1
      (48 when no such key exists).
    * ``cross_attention_dim``: last dimension of the first key containing
      ``transformer_blocks.0.attn2.to_k.weight`` (4096 when absent).
    * ``has_prompt_adaln``: added (as ``True``) only when some key contains
      ``prompt_adaln_single`` (LTX-2.3 feature).

    Args:
        weights: Sanitized transformer weights (key -> tensor).

    Returns:
        Config dict in the key order written to ``config.json``.
    """
    # Collect every numeric block index that follows a "transformer_blocks" segment.
    block_indices = []
    for name in weights:
        if "transformer_blocks." not in name:
            continue
        segments = name.split(".")
        try:
            position = segments.index("transformer_blocks") + 1
            if position < len(segments) and segments[position].isdigit():
                block_indices.append(int(segments[position]))
        except ValueError:
            # No exact "transformer_blocks" segment (or an unparsable index).
            continue
    num_layers = max(block_indices) + 1 if block_indices else 48

    # Cross-attention input dim comes from the first block's attn2.to_k weight.
    cross_attention_dim = next(
        (
            tensor.shape[-1]
            for name, tensor in weights.items()
            if "transformer_blocks.0.attn2.to_k.weight" in name
        ),
        4096,
    )

    config = {
        "attention_head_dim": 128,
        "attention_type": "default",
        "audio_attention_head_dim": 64,
        "audio_caption_channels": 3840,
        "audio_cross_attention_dim": 2048,
        "audio_in_channels": 128,
        "audio_num_attention_heads": 32,
        "audio_out_channels": 128,
        "audio_positional_embedding_max_pos": [20],
        "av_ca_timestep_scale_multiplier": 1000,
        "caption_channels": 3840,
        "cross_attention_dim": cross_attention_dim,
        "double_precision_rope": True,
        "in_channels": 128,
        "model_type": "ltx av model",
        "norm_eps": 1e-06,
        "num_attention_heads": 32,
        "num_layers": num_layers,
        "out_channels": 128,
        "positional_embedding_max_pos": [20, 2048, 2048],
        "positional_embedding_theta": 10000.0,
        "rope_type": "split",
        "timestep_scale_multiplier": 1000,
        "use_middle_indices_grid": True,
    }

    # prompt_adaln_single only exists in LTX-2.3 checkpoints.
    if any("prompt_adaln_single" in name for name in weights):
        config["has_prompt_adaln"] = True

    return config
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
    """Build the VAE decoder config.

    Only ``timestep_conditioning`` is derived from ``weights``: it is ``True``
    when any key contains ``last_time_embedder`` or ``last_scale_shift_table``.
    All other values are fixed defaults.

    Args:
        weights: Sanitized VAE decoder weights (only the key names are read).
        variant: Model variant. Accepted for interface compatibility; it is
            not consulted when building the config.

    Returns:
        Config dict in the key order written to ``config.json``.
    """
    has_timestep = any(
        "last_time_embedder" in name or "last_scale_shift_table" in name
        for name in weights
    )

    # NOTE: an earlier version also scanned "up_blocks.<N>" keys for the
    # highest block index, but the result was never used; that dead loop has
    # been removed without changing the returned config.
    return {
        "ch": 128,
        "ch_mult": [1, 2, 4],
        "dropout": 0.0,
        "num_res_blocks": 2,
        "out_ch": 2,
        "resolution": 256,
        "timestep_conditioning": has_timestep,
        "z_channels": 8,
    }
def infer_vae_encoder_config(weights: Dict[str, mx.array]) -> dict:
    """Return VAE encoder config (architecture is consistent across versions).

    ``weights`` is not inspected; the returned values are constants.
    """
    # (block type, option name, option value) for each encoder stage, in order.
    stage_spec = (
        ("res_x", "num_layers", 4),
        ("compress_space_res", "multiplier", 2),
        ("res_x", "num_layers", 6),
        ("compress_time_res", "multiplier", 2),
        ("res_x", "num_layers", 6),
        ("compress_all_res", "multiplier", 2),
        ("res_x", "num_layers", 2),
        ("compress_all_res", "multiplier", 2),
        ("res_x", "num_layers", 2),
    )
    encoder_blocks = [[kind, {option: amount}] for kind, option, amount in stage_spec]

    config = {
        "convolution_dimensions": 3,
        "encoder_blocks": encoder_blocks,
        "encoder_spatial_padding_mode": "zeros",
        "in_channels": 3,
        "latent_log_var": "uniform",
        "norm_layer": "pixel_norm",
        "out_channels": 128,
        "patch_size": 4,
    }
    return config
def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict:
    """Return the audio VAE config.

    ``weights`` is not inspected; the returned values are constants.
    """
    config = dict(
        attn_resolutions=[],
        attn_type="vanilla",
        causality_axis="height",
        ch=128,
        ch_mult=[1, 2, 4],
        dropout=0.0,
        give_pre_end=False,
        is_causal=True,
        mel_bins=64,
        mel_hop_length=160,
        mid_block_add_attention=False,
        norm_type="pixel",
        num_res_blocks=2,
        out_ch=2,
        resamp_with_conv=True,
        resolution=256,
        sample_rate=16000,
        tanh_out=False,
        z_channels=8,
    )
    return config
def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
    """Choose the vocoder config from the weight key names.

    When any key starts with ``bwe_generator`` (LTX-2.3 BigVGAN vocoder) a
    short marker config is returned; otherwise the fixed default vocoder
    config is returned.

    Args:
        weights: Sanitized vocoder weights (only the key names are read).
    """
    for name in weights:
        if name.startswith("bwe_generator"):
            return {
                "type": "bigvgan",
                "has_bwe_generator": True,
            }

    default_config = {
        "output_sample_rate": 24000,
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "resblock_kernel_sizes": [3, 7, 11],
        "stereo": True,
        "upsample_initial_channel": 1024,
        "upsample_kernel_sizes": [16, 15, 8, 4, 4],
        "upsample_rates": [6, 5, 2, 2, 2],
    }
    return default_config
# ─── Main ─────────────────────────────────────────────────────────────────────
def _export_component(header, all_weights, sanitize, infer_config, output_dir, sharded=False):
    """Extract, save and report one model component.

    Args:
        header: Progress line printed before any work starts.
        all_weights: Full monolithic checkpoint mapping.
        sanitize: Callable that filters/renames ``all_weights`` for this component.
        infer_config: Callable mapping the sanitized weights to a config dict.
        output_dir: Directory receiving the safetensors file(s) and ``config.json``.
        sharded: Save with ``save_sharded`` instead of ``save_single``.

    Returns:
        The sanitized weight mapping of the component.
    """
    print(header)
    component = sanitize(all_weights)
    if sharded:
        num_shards = save_sharded(component, output_dir)
    else:
        save_single(component, output_dir)
    save_config(infer_config(component), output_dir)

    params = sum(v.size for v in component.values())
    stats = f" {len(component)} keys, {params:,} params"
    if sharded:
        stats += f", {num_shards} shards"
    print(stats)
    return component


def _copy_upscalers(source, source_dir, output_path, is_hf_repo):
    """Copy (or download) every upscaler checkpoint into ``output_path``."""
    print("\nCopying upscaler files...")
    if is_hf_repo:
        from huggingface_hub import list_repo_files

        upscaler_files = [
            f for f in list_repo_files(source) if UPSCALER_PATTERN.match(f)
        ]
    else:
        upscaler_files = [
            f.name for f in source_dir.iterdir()
            if f.is_file() and UPSCALER_PATTERN.match(f.name)
        ]

    if not upscaler_files:
        print(" No upscaler files found")

    for upscaler_file in sorted(upscaler_files):
        dest = output_path / upscaler_file
        if dest.exists():
            print(f" {upscaler_file}: already exists, skipping")
            continue

        # Prefer a copy sitting next to the monolithic file; download otherwise.
        local_candidate = source_dir / upscaler_file
        if local_candidate.is_file():
            shutil.copy2(str(local_candidate), str(dest))
            print(f" {upscaler_file}: copied")
        elif is_hf_repo:
            from huggingface_hub import hf_hub_download

            print(f" {upscaler_file}: downloading from {source}...")
            downloaded = hf_hub_download(repo_id=source, filename=upscaler_file)
            shutil.copy2(downloaded, str(dest))
            print(f" {upscaler_file}: done")
        else:
            print(f" {upscaler_file}: not found, skipping")


def _link_text_assets(source, source_dir, output_path, is_hf_repo):
    """Symlink (or download) the ``text_encoder`` and ``tokenizer`` directories."""
    print("\nLinking text encoder & tokenizer...")
    repo_files = None  # fetched at most once, and only if a download is needed
    for subdir in ["text_encoder", "tokenizer"]:
        dest = output_path / subdir
        if dest.exists():
            print(f" {subdir}/: already exists, skipping")
            continue

        local_candidate = source_dir / subdir
        if local_candidate.is_dir():
            # Resolve through symlinks to get the real directory
            real_path = local_candidate.resolve()
            dest.symlink_to(real_path)
            print(f" {subdir}/: symlinked to {real_path}")
        elif is_hf_repo:
            from huggingface_hub import list_repo_files, snapshot_download

            if repo_files is None:
                repo_files = list_repo_files(source)
            # Only download if the subdir exists in the repo
            if any(f.startswith(f"{subdir}/") for f in repo_files):
                print(f" {subdir}/: downloading from {source}...")
                snapshot_download(
                    repo_id=source,
                    allow_patterns=f"{subdir}/*",
                    local_dir=str(output_path),
                )
                print(f" {subdir}/: done")
            else:
                print(f" {subdir}/: not in repo, skipping")
        else:
            print(f" {subdir}/: not found in source, skipping")


def _report_skipped(all_weights, limit=20):
    """Print the checkpoint keys that match none of the known prefixes.

    At most ``limit`` keys are listed; a trailing line states how many more
    were omitted.
    """
    # AUDIO_ENCODER_PREFIX is listed on purpose: those keys are recognised
    # and therefore not reported, even though no step above exports them.
    known_prefixes = (
        TRANSFORMER_PREFIX,
        VAE_DECODER_PREFIX,
        VAE_STATS_PREFIX,
        VAE_ENCODER_PREFIX,
        AUDIO_DECODER_PREFIX,
        AUDIO_STATS_PREFIX,
        AUDIO_ENCODER_PREFIX,
        VOCODER_PREFIX,
        TEXT_PROJ_PREFIX,
        VIDEO_CONNECTOR_PREFIX,
        AUDIO_CONNECTOR_PREFIX,
    )
    skipped = sorted(k for k in all_weights if not k.startswith(known_prefixes))
    if not skipped:
        return
    print(f" Skipped {len(skipped)} keys:")
    for k in skipped[:limit]:
        print(f" {k}")
    if len(skipped) > limit:
        print(f" ... and {len(skipped) - limit} more")


def convert(source: str, output_path: Path, variant: str = "distilled"):
    """Convert monolithic safetensors to modular directory layout.

    Steps, in order: resolve and load the monolithic checkpoint; export the
    transformer (sharded), VAE decoder, VAE encoder, audio VAE decoder and
    vocoder, each with a ``config.json``; save the text projection weights;
    copy upscaler checkpoints; link or download the ``text_encoder`` and
    ``tokenizer`` directories; print a summary including unrecognised keys.

    Args:
        source: HF repo ID (e.g. "Lightricks/LTX-2"), local directory, or file path.
        output_path: Output directory for the modular layout.
        variant: "distilled" or "dev".

    Raises:
        FileNotFoundError: Propagated from ``resolve_source`` when the
            checkpoint cannot be located.
    """
    source_path = resolve_source(source, variant)
    print(f"Loading monolithic weights from {source_path.name}...")
    all_weights = mx.load(str(source_path))
    total_keys = len(all_weights)
    print(f" Loaded {total_keys} keys")

    # Route keys to components
    print("\nExtracting components...")
    transformer_weights = _export_component(
        " [1/6] Transformer...",
        all_weights,
        sanitize_transformer,
        infer_transformer_config,
        output_path / "transformer",
        sharded=True,
    )
    vae_decoder_weights = _export_component(
        " [2/6] VAE Decoder...",
        all_weights,
        sanitize_vae_decoder,
        lambda w: infer_vae_decoder_config(w, variant),
        output_path / "vae" / "decoder",
    )
    vae_encoder_weights = _export_component(
        " [3/6] VAE Encoder...",
        all_weights,
        sanitize_vae_encoder,
        infer_vae_encoder_config,
        output_path / "vae" / "encoder",
    )
    audio_decoder_weights = _export_component(
        " [4/6] Audio VAE Decoder...",
        all_weights,
        sanitize_audio_decoder,
        infer_audio_vae_config,
        output_path / "audio_vae",
    )
    vocoder_weights = _export_component(
        " [5/6] Vocoder...",
        all_weights,
        sanitize_vocoder,
        infer_vocoder_config,
        output_path / "vocoder",
    )

    # Text projections have no config and no index file.
    print(" [6/6] Text Projections...")
    text_proj_weights = extract_text_projections(all_weights)
    tp_dir = output_path / "text_projections"
    tp_dir.mkdir(parents=True, exist_ok=True)
    mx.save_safetensors(str(tp_dir / "model.safetensors"), text_proj_weights)
    tp_params = sum(v.size for v in text_proj_weights.values())
    print(f" {len(text_proj_weights)} keys, {tp_params:,} params")

    # Auxiliary assets live next to the monolithic file (or in the HF repo).
    source_dir = source_path.parent
    is_hf_repo = "/" in source and not Path(source).exists()
    _copy_upscalers(source, source_dir, output_path, is_hf_repo)
    _link_text_assets(source, source_dir, output_path, is_hf_repo)

    # Summary
    all_converted = (
        len(transformer_weights)
        + len(vae_decoder_weights)
        + len(vae_encoder_weights)
        + len(audio_decoder_weights)
        + len(vocoder_weights)
        + len(text_proj_weights)
    )
    print(f"\nDone! Converted {all_converted}/{total_keys} keys")
    if all_converted < total_keys:
        _report_skipped(all_weights)
if __name__ == "__main__":
    # Command-line entry point: parse the three options and run the conversion.
    cli = argparse.ArgumentParser(
        description="Convert monolithic LTX-2/2.3 safetensors to modular MLX layout"
    )
    cli.add_argument(
        "--source",
        required=True,
        type=str,
        help="HF repo ID (e.g. Lightricks/LTX-2, Lightricks/LTX-2.3), local directory, or direct safetensors file path",
    )
    cli.add_argument(
        "--output",
        required=True,
        type=str,
        help="Output directory for modular layout",
    )
    cli.add_argument(
        "--variant",
        default="distilled",
        choices=["distilled", "dev"],
        type=str,
        help="Model variant (affects VAE decoder config and which file to download)",
    )
    options = cli.parse_args()

    destination = Path(options.output)
    convert(options.source, destination, variant=options.variant)

View File

@@ -0,0 +1,40 @@
import mlx.core as mx
import mlx.nn as nn
class GELU(nn.Module):
    """GELU activation with a selectable implementation.

    Args:
        approximate: ``"tanh"`` selects ``nn.gelu_approx``; any other value
            selects ``nn.gelu``.
    """

    def __init__(self, approximate: str = "tanh"):
        super().__init__()
        self.approximate = approximate

    def __call__(self, x: mx.array) -> mx.array:
        activation = nn.gelu_approx if self.approximate == "tanh" else nn.gelu
        return activation(x)
class FeedForward(nn.Module):
    """Two-layer MLP: linear -> tanh-approximated GELU -> linear.

    Args:
        dim: Input feature size.
        dim_out: Output feature size; falls back to ``dim`` when falsy.
        mult: Hidden width multiplier (hidden size is ``int(dim * mult)``).
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int | None = None,
        mult: int = 4,
        bias: bool = True,
    ):
        super().__init__()
        hidden_features = int(dim * mult)
        out_features = dim_out or dim

        self.proj_in = nn.Linear(dim, hidden_features, bias=bias)
        self.act = GELU(approximate="tanh")
        self.proj_out = nn.Linear(hidden_features, out_features, bias=bias)

    def __call__(self, x: mx.array) -> mx.array:
        return self.proj_out(self.act(self.proj_in(x)))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,651 @@
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
LTXModelType,
LTXRopeType,
TransformerConfig,
)
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx_2.transformer import (
BasicAVTransformerBlock,
Modality,
TransformerArgs,
)
from mlx_video.utils import to_denoised
class TransformerArgsPreprocessor:
    """Turns a ``Modality`` into the ``TransformerArgs`` used by the transformer blocks.

    The instance only stores references to the modules and settings it is
    given; it owns no parameters itself.

    Args:
        patchify_proj: Linear layer applied to ``modality.latent``.
        adaln: AdaLN module used for the main timestep embedding.
        caption_projection: Optional projection applied to the text context.
        inner_dim: Embedding width passed to the RoPE computation.
        max_pos: Per-axis maximum positions passed to the RoPE computation.
        num_attention_heads: Head count passed to the RoPE computation.
        use_middle_indices_grid: Forwarded to ``precompute_freqs_cis``.
        timestep_scale_multiplier: Factor applied to timesteps before embedding.
        positional_embedding_theta: RoPE theta.
        rope_type: RoPE variant.
        double_precision_rope: Forwarded to ``precompute_freqs_cis``.
        prompt_adaln: Optional AdaLN module for the prompt-conditioned
            timestep (LTX-2.3).
    """

    def __init__(
        self,
        patchify_proj: nn.Linear,
        adaln: AdaLayerNormSingle,
        caption_projection: Optional[PixArtAlphaTextProjection],
        inner_dim: int,
        max_pos: List[int],
        num_attention_heads: int,
        use_middle_indices_grid: bool,
        timestep_scale_multiplier: int,
        positional_embedding_theta: float,
        rope_type: LTXRopeType,
        double_precision_rope: bool = False,
        prompt_adaln: Optional[AdaLayerNormSingle] = None,
    ):
        self.patchify_proj = patchify_proj
        self.adaln = adaln
        self.caption_projection = caption_projection
        self.prompt_adaln = prompt_adaln
        self.inner_dim = inner_dim
        self.max_pos = max_pos
        self.num_attention_heads = num_attention_heads
        self.use_middle_indices_grid = use_middle_indices_grid
        self.timestep_scale_multiplier = timestep_scale_multiplier
        self.positional_embedding_theta = positional_embedding_theta
        self.rope_type = rope_type
        self.double_precision_rope = double_precision_rope

    def _prepare_timestep(
        self,
        timestep: mx.array,
        batch_size: int,
        hidden_dtype: Optional[mx.Dtype] = None,
    ) -> Tuple[mx.array, mx.array]:
        """Embed ``timestep`` with the main AdaLN module (``self.adaln``)."""
        return self._prepare_timestep_with_adaln(
            self.adaln, timestep, batch_size, hidden_dtype=hidden_dtype
        )

    def _prepare_timestep_with_adaln(
        self,
        adaln: AdaLayerNormSingle,
        timestep: mx.array,
        batch_size: int,
        hidden_dtype: Optional[mx.Dtype] = None,
    ) -> Tuple[mx.array, mx.array]:
        """Scale, flatten and embed ``timestep`` with the given AdaLN module.

        Returns:
            ``(timestep_emb, embedded_timestep)``, both reshaped to
            ``(batch_size, -1, last_dim)``.
        """
        timestep = timestep * self.timestep_scale_multiplier
        timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
        # Reshape to (batch, tokens, dim)
        timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
        embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
        return timestep_emb, embedded_timestep

    def _prepare_context(
        self,
        context: mx.array,
        x: mx.array,
        attention_mask: Optional[mx.array] = None,
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """Project the context (if a projection is set) and reshape it to
        ``(batch, -1, x.shape[-1])``. The mask is returned untouched."""
        batch_size = x.shape[0]
        if self.caption_projection is not None:
            context = self.caption_projection(context)
        context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
        return context, attention_mask

    def _prepare_attention_mask(
        self,
        attention_mask: Optional[mx.array],
        x_dtype: mx.Dtype,
    ) -> Optional[mx.array]:
        """Convert a boolean/int keep-mask into an additive float mask.

        ``None`` and masks that are already float16/float32/bfloat16 are
        returned unchanged. Otherwise 1 maps to 0 and 0 maps to a large
        negative bias, and the result is reshaped to
        ``(batch, 1, -1, last_dim)``.
        """
        if attention_mask is None:
            return None
        # Check if already float
        if attention_mask.dtype in [mx.float16, mx.float32, mx.bfloat16]:
            return attention_mask
        # 0 -> -1e9 (masked), 1 -> 0 (not masked).
        # The arithmetic is done in float32 and cast afterwards: 1e9 is not
        # representable in float16, so multiplying in that dtype would turn the
        # "keep" entries into 0 * inf = NaN. After the cast, float16 gets
        # 0 / -inf; float32 and bfloat16 results are unchanged.
        mask = ((attention_mask.astype(mx.float32) - 1) * 1e9).astype(x_dtype)
        mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
        return mask

    def _prepare_positional_embeddings(
        self,
        positions: mx.array,
        inner_dim: int,
        max_pos: List[int],
        use_middle_indices_grid: bool,
        num_attention_heads: int,
    ) -> Tuple[mx.array, mx.array]:
        """Compute RoPE frequencies for ``positions`` via ``precompute_freqs_cis``."""
        pe = precompute_freqs_cis(
            positions,
            dim=inner_dim,
            theta=self.positional_embedding_theta,
            max_pos=max_pos,
            use_middle_indices_grid=use_middle_indices_grid,
            num_attention_heads=num_attention_heads,
            rope_type=self.rope_type,
            double_precision=self.double_precision_rope,
        )
        return pe

    def prepare(self, modality: Modality) -> TransformerArgs:
        """Build the ``TransformerArgs`` for one modality.

        The cross-modal fields (``cross_positional_embeddings``,
        ``cross_scale_shift_timestep``, ``cross_gate_timestep``) are left as
        ``None``.
        """
        x = self.patchify_proj(modality.latent)
        timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
        context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
        attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)

        # Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
        if modality.positional_embeddings is not None:
            pe = modality.positional_embeddings
        else:
            pe = self._prepare_positional_embeddings(
                positions=modality.positions,
                inner_dim=self.inner_dim,
                max_pos=self.max_pos,
                use_middle_indices_grid=self.use_middle_indices_grid,
                num_attention_heads=self.num_attention_heads,
            )

        # Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps
        prompt_timestep = None
        prompt_embedded_timestep = None
        if self.prompt_adaln is not None and modality.sigma is not None:
            prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
                self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
            )

        return TransformerArgs(
            x=x,
            context=context,
            context_mask=attention_mask,
            timesteps=timestep,
            embedded_timestep=embedded_timestep,
            positional_embeddings=pe,
            cross_positional_embeddings=None,
            cross_scale_shift_timestep=None,
            cross_gate_timestep=None,
            enabled=modality.enabled,
            prompt_timesteps=prompt_timestep,
            prompt_embedded_timestep=prompt_embedded_timestep,
        )
class MultiModalTransformerArgsPreprocessor:
    """Preprocessor for one modality of the joint audio-video model.

    Wraps a ``TransformerArgsPreprocessor`` for the single-modality fields and
    fills in the cross-modal ones: cross positional embeddings plus the
    scale/shift and gate timestep embeddings for audio-video cross-attention.
    """

    def __init__(
        self,
        patchify_proj: nn.Linear,
        adaln: AdaLayerNormSingle,
        caption_projection: Optional[PixArtAlphaTextProjection],
        cross_scale_shift_adaln: AdaLayerNormSingle,
        cross_gate_adaln: AdaLayerNormSingle,
        inner_dim: int,
        max_pos: List[int],
        num_attention_heads: int,
        cross_pe_max_pos: int,
        use_middle_indices_grid: bool,
        audio_cross_attention_dim: int,
        timestep_scale_multiplier: int,
        positional_embedding_theta: float,
        rope_type: LTXRopeType,
        av_ca_timestep_scale_multiplier: int,
        double_precision_rope: bool = False,
        prompt_adaln: Optional[AdaLayerNormSingle] = None,
    ):
        # Single-modality work is delegated to the simple preprocessor.
        self.simple_preprocessor = TransformerArgsPreprocessor(
            patchify_proj=patchify_proj,
            adaln=adaln,
            caption_projection=caption_projection,
            inner_dim=inner_dim,
            max_pos=max_pos,
            num_attention_heads=num_attention_heads,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            positional_embedding_theta=positional_embedding_theta,
            rope_type=rope_type,
            double_precision_rope=double_precision_rope,
            prompt_adaln=prompt_adaln,
        )
        self.cross_scale_shift_adaln = cross_scale_shift_adaln
        self.cross_gate_adaln = cross_gate_adaln
        self.cross_pe_max_pos = cross_pe_max_pos
        self.audio_cross_attention_dim = audio_cross_attention_dim
        self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier

    def prepare(self, modality: Modality) -> TransformerArgs:
        """Build ``TransformerArgs`` including the cross-modal fields."""
        from dataclasses import replace

        base = self.simple_preprocessor
        single_modality_args = base.prepare(modality)

        # Cross-modal RoPE uses only the first positional axis.
        first_axis_positions = modality.positions[:, 0:1, :]
        cross_pe = base._prepare_positional_embeddings(
            positions=first_axis_positions,
            inner_dim=self.audio_cross_attention_dim,
            max_pos=[self.cross_pe_max_pos],
            use_middle_indices_grid=True,
            num_attention_heads=base.num_attention_heads,
        )

        # Timestep embeddings for the cross-attention scale/shift and gate.
        scale_shift, gate = self._prepare_cross_attention_timestep(
            timestep=modality.timesteps,
            timestep_scale_multiplier=base.timestep_scale_multiplier,
            batch_size=single_modality_args.x.shape[0],
            hidden_dtype=single_modality_args.x.dtype,
        )

        return replace(
            single_modality_args,
            cross_positional_embeddings=cross_pe,
            cross_scale_shift_timestep=scale_shift,
            cross_gate_timestep=gate,
        )

    def _prepare_cross_attention_timestep(
        self,
        timestep: mx.array,
        timestep_scale_multiplier: int,
        batch_size: int,
        hidden_dtype: mx.Dtype = None,
    ) -> Tuple[mx.array, mx.array]:
        """Embed the timestep for cross-attention scale/shift and gating.

        The gate embedding sees the scaled timestep multiplied by
        ``av_ca_timestep_scale_multiplier / timestep_scale_multiplier``.

        Returns:
            ``(scale_shift_timestep, gate_timestep)``, each reshaped to
            ``(batch_size, -1, last_dim)``.
        """
        flat_timestep = (timestep * timestep_scale_multiplier).reshape(-1)
        av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier

        scale_shift_timestep, _ = self.cross_scale_shift_adaln(flat_timestep, hidden_dtype=hidden_dtype)
        scale_shift_timestep = mx.reshape(
            scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
        )

        gate_timestep, _ = self.cross_gate_adaln(flat_timestep * av_ca_factor, hidden_dtype=hidden_dtype)
        gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))

        return scale_shift_timestep, gate_timestep
class LTXModel(nn.Module):
def __init__(self, config: LTXModelConfig):
    """Build the model for the modalities enabled in ``config.model_type``.

    Video and audio components are created only when the respective modality
    is enabled; the cross-modal components only when both are. Preprocessors
    and transformer blocks are created last.
    """
    super().__init__()
    self.config = config
    self.model_type = config.model_type
    self.use_middle_indices_grid = config.use_middle_indices_grid
    self.rope_type = config.rope_type
    self.timestep_scale_multiplier = config.timestep_scale_multiplier
    self.positional_embedding_theta = config.positional_embedding_theta

    has_video = config.model_type.is_video_enabled()
    has_audio = config.model_type.is_audio_enabled()

    if has_video:
        self.positional_embedding_max_pos = config.positional_embedding_max_pos
        self.num_attention_heads = config.num_attention_heads
        self.inner_dim = config.inner_dim
        self._init_video(config)

    if has_audio:
        self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
        self.audio_num_attention_heads = config.audio_num_attention_heads
        self.audio_inner_dim = config.audio_inner_dim
        self._init_audio(config)

    # Cross-modal components need both modalities.
    cross_pe_max_pos = None
    if has_video and has_audio:
        video_first_axis = config.positional_embedding_max_pos[0]
        audio_first_axis = config.audio_positional_embedding_max_pos[0]
        cross_pe_max_pos = max(video_first_axis, audio_first_axis)
        self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
        self.audio_cross_attention_dim = config.audio_cross_attention_dim
        self._init_audio_video(config)

    self._init_preprocessors(config, cross_pe_max_pos)
    self._init_transformer_blocks(config)
def _init_video(self, config: LTXModelConfig) -> None:
    """Create the video input projection, AdaLN, text and output modules.

    With ``config.has_prompt_adaln`` the main AdaLN uses 9 coefficients and a
    separate prompt AdaLN is created; otherwise the main AdaLN uses 6
    coefficients and a caption projection is created instead.
    """
    width = self.inner_dim
    uses_prompt_adaln = config.has_prompt_adaln

    self.patchify_proj = nn.Linear(config.in_channels, width, bias=True)

    self.adaln_single = AdaLayerNormSingle(
        width,
        embedding_coefficient=9 if uses_prompt_adaln else 6,
    )

    if uses_prompt_adaln:
        self.prompt_adaln_single = AdaLayerNormSingle(width, embedding_coefficient=2)
    else:
        self.caption_projection = PixArtAlphaTextProjection(
            in_features=config.caption_channels,
            hidden_size=width,
        )

    # Output components
    self.scale_shift_table = mx.zeros((2, width))
    self.norm_out = nn.LayerNorm(width, eps=config.norm_eps, affine=False)
    self.proj_out = nn.Linear(width, config.out_channels)
def _init_audio(self, config: LTXModelConfig) -> None:
    """Create the audio input projection, AdaLN, text and output modules.

    With ``config.has_prompt_adaln`` the audio AdaLN uses 9 coefficients and a
    separate audio prompt AdaLN is created; otherwise it uses 6 coefficients
    and an audio caption projection is created instead.
    """
    width = self.audio_inner_dim
    uses_prompt_adaln = config.has_prompt_adaln

    self.audio_patchify_proj = nn.Linear(config.audio_in_channels, width, bias=True)

    self.audio_adaln_single = AdaLayerNormSingle(
        width,
        embedding_coefficient=9 if uses_prompt_adaln else 6,
    )

    if uses_prompt_adaln:
        self.audio_prompt_adaln_single = AdaLayerNormSingle(width, embedding_coefficient=2)
    else:
        self.audio_caption_projection = PixArtAlphaTextProjection(
            in_features=config.audio_caption_channels,
            hidden_size=width,
        )

    # Output components
    self.audio_scale_shift_table = mx.zeros((2, width))
    self.audio_norm_out = nn.LayerNorm(width, eps=config.norm_eps, affine=False)
    self.audio_proj_out = nn.Linear(width, config.audio_out_channels)
def _init_audio_video(self, config: LTXModelConfig) -> None:
    """Create the AdaLN modules for audio-video cross-attention.

    Each modality gets a scale/shift AdaLN with 4 coefficients and a gate
    AdaLN with 1 coefficient (a2v gate on the video width, v2a gate on the
    audio width).
    """
    scale_shift_coefficients = 4
    gate_coefficients = 1
    video_width = self.inner_dim
    audio_width = self.audio_inner_dim

    self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
        video_width, embedding_coefficient=scale_shift_coefficients
    )
    self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
        audio_width, embedding_coefficient=scale_shift_coefficients
    )
    self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
        video_width, embedding_coefficient=gate_coefficients
    )
    self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
        audio_width, embedding_coefficient=gate_coefficients
    )
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
    """Create the argument preprocessors for the enabled modalities.

    With both modalities enabled, each gets a
    ``MultiModalTransformerArgsPreprocessor``; with only one enabled, that one
    gets a plain ``TransformerArgsPreprocessor``. Optional modules
    (caption projection, prompt AdaLN) are passed as ``None`` when the model
    does not define them.
    """

    def shared_settings() -> dict:
        # Settings that are identical for every preprocessor.
        return dict(
            use_middle_indices_grid=config.use_middle_indices_grid,
            timestep_scale_multiplier=config.timestep_scale_multiplier,
            positional_embedding_theta=config.positional_embedding_theta,
            rope_type=config.rope_type,
            double_precision_rope=config.double_precision_rope,
        )

    def video_settings() -> dict:
        return dict(
            patchify_proj=self.patchify_proj,
            adaln=self.adaln_single,
            caption_projection=getattr(self, "caption_projection", None),
            inner_dim=self.inner_dim,
            max_pos=config.positional_embedding_max_pos,
            num_attention_heads=self.num_attention_heads,
            prompt_adaln=getattr(self, "prompt_adaln_single", None),
            **shared_settings(),
        )

    def audio_settings() -> dict:
        return dict(
            patchify_proj=self.audio_patchify_proj,
            adaln=self.audio_adaln_single,
            caption_projection=getattr(self, "audio_caption_projection", None),
            inner_dim=self.audio_inner_dim,
            max_pos=config.audio_positional_embedding_max_pos,
            num_attention_heads=self.audio_num_attention_heads,
            prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
            **shared_settings(),
        )

    def cross_modal_settings() -> dict:
        # Extra settings only the multi-modal preprocessor accepts.
        return dict(
            cross_pe_max_pos=cross_pe_max_pos,
            audio_cross_attention_dim=config.audio_cross_attention_dim,
            av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
        )

    has_video = config.model_type.is_video_enabled()
    has_audio = config.model_type.is_audio_enabled()

    if has_video and has_audio:
        # Multi-modal preprocessors
        self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
            cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
            cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
            **video_settings(),
            **cross_modal_settings(),
        )
        self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
            cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
            cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
            **audio_settings(),
            **cross_modal_settings(),
        )
    elif has_video:
        self.video_args_preprocessor = TransformerArgsPreprocessor(**video_settings())
    elif has_audio:
        self.audio_args_preprocessor = TransformerArgsPreprocessor(**audio_settings())
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
    """Create one ``BasicAVTransformerBlock`` per layer, keyed by layer index.

    The video and audio sub-configurations are fetched once from ``config``
    and handed to every block. The blocks are stored on
    ``self.transformer_blocks`` as an ``{index: block}`` dict.

    Args:
        config: Model configuration; ``num_layers``, ``rope_type``,
            ``norm_eps`` and ``has_prompt_adaln`` are read from it.
    """
    video_cfg = config.get_video_config()
    audio_cfg = config.get_audio_config()

    blocks = {}
    for layer_idx in range(config.num_layers):
        blocks[layer_idx] = BasicAVTransformerBlock(
            idx=layer_idx,
            video=video_cfg,
            audio=audio_cfg,
            rope_type=config.rope_type,
            norm_eps=config.norm_eps,
            has_prompt_adaln=config.has_prompt_adaln,
        )
    self.transformer_blocks = blocks
def _process_transformer_blocks(
    self,
    video: Optional[TransformerArgs],
    audio: Optional[TransformerArgs],
    stg_video_blocks: Optional[List[int]] = None,
    stg_audio_blocks: Optional[List[int]] = None,
    skip_cross_modal: bool = False,
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
    """Run both modalities through every transformer block in turn.

    Each block receives the outputs of the previous one, plus per-block
    flags telling it whether to skip self-attention for a modality.

    Args:
        video: Prepared video arguments, or None.
        audio: Prepared audio arguments, or None.
        stg_video_blocks: Block indices where video self-attention is skipped (STG).
        stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
        skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).

    Returns:
        The ``(video, audio)`` pair returned by the last block.
    """
    # Sets give O(1) membership tests inside the per-block loop.
    video_skip_layers = set(stg_video_blocks or ())
    audio_skip_layers = set(stg_audio_blocks or ())

    for layer_idx, layer in self.transformer_blocks.items():
        skip_video = layer_idx in video_skip_layers
        skip_audio = layer_idx in audio_skip_layers
        video, audio = layer(
            video=video,
            audio=audio,
            skip_video_self_attn=skip_video,
            skip_audio_self_attn=skip_audio,
            skip_cross_modal=skip_cross_modal,
        )
    return video, audio
def _process_output(
    self,
    scale_shift_table: mx.array,
    norm_out: nn.LayerNorm,
    proj_out: nn.Linear,
    x: mx.array,
    embedded_timestep: mx.array,
) -> mx.array:
    """Apply the final timestep-modulated norm and output projection.

    The learned ``scale_shift_table`` is added to the embedded timestep to
    obtain a shift (row 0) and a scale (row 1); ``x`` is normalised,
    modulated as ``x * (1 + scale) + shift`` and then projected.

    Args:
        scale_shift_table: Indexed as a 2-D table whose first axis holds
            [shift, scale].
        norm_out: Normalisation applied to ``x`` before modulation.
        proj_out: Projection applied after modulation.
        x: Hidden states to finalise.
        embedded_timestep: Indexed as a 3-D array; original comments give
            its shape as (B, 1, dim).

    Returns:
        The projected output.
    """
    # (1, 1, 2, dim) + (B, 1, 1, dim) broadcasts to (B, 1, 2, dim).
    modulation = scale_shift_table[None, None, :, :] + embedded_timestep[:, :, None, :]
    shift = modulation[:, :, 0, :]
    scale = modulation[:, :, 1, :]

    normed = norm_out(x)
    # shift/scale broadcast over the sequence axis of `normed`.
    modulated = normed * (1 + scale) + shift
    return proj_out(modulated)
def __call__(
    self,
    video: Optional[Modality] = None,
    audio: Optional[Modality] = None,
    stg_video_blocks: Optional[List[int]] = None,
    stg_audio_blocks: Optional[List[int]] = None,
    skip_cross_modal: bool = False,
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
    """Run the model on the supplied modalities.

    Args:
        video: Video modality input.
        audio: Audio modality input.
        stg_video_blocks: Block indices where video self-attention is skipped (STG).
        stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
        skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).

    Returns:
        ``(video_output, audio_output)``; an entry is None when the
        corresponding modality was not supplied.

    Raises:
        ValueError: If a modality is passed that this model type does not
            have enabled.
    """
    # Reject inputs for modalities this model was not built with.
    if video is not None and not self.model_type.is_video_enabled():
        raise ValueError("Video is not enabled for this model")
    if audio is not None and not self.model_type.is_audio_enabled():
        raise ValueError("Audio is not enabled for this model")

    # Turn each provided modality into transformer arguments.
    video_args = None
    if video is not None:
        video_args = self.video_args_preprocessor.prepare(video)
    audio_args = None
    if audio is not None:
        audio_args = self.audio_args_preprocessor.prepare(audio)

    video_out, audio_out = self._process_transformer_blocks(
        video=video_args,
        audio=audio_args,
        stg_video_blocks=stg_video_blocks,
        stg_audio_blocks=stg_audio_blocks,
        skip_cross_modal=skip_cross_modal,
    )

    # Final modulation + projection, per modality.
    vx = None
    if video_out is not None:
        vx = self._process_output(
            self.scale_shift_table,
            self.norm_out,
            self.proj_out,
            video_out.x,
            video_out.embedded_timestep,
        )

    ax = None
    if audio_out is not None:
        ax = self._process_output(
            self.audio_scale_shift_table,
            self.audio_norm_out,
            self.audio_proj_out,
            audio_out.x,
            audio_out.embedded_timestep,
        )

    return vx, ax
def sanitize(self, weights: dict) -> dict:
    """Translate raw checkpoint keys into this model's parameter names.

    If no key carries the ``model.diffusion_model.`` prefix the mapping is
    returned untouched (same object). Otherwise only prefixed keys are
    kept, embeddings-connector entries are dropped, the prefix is removed
    and a fixed list of substring renames is applied in order.

    Args:
        weights: Mapping of checkpoint key to tensor.

    Returns:
        The original mapping, or a new dict with renamed keys.
    """
    prefix = "model.diffusion_model."
    if not any(name.startswith(prefix) for name in weights):
        return weights

    dropped_markers = ("audio_embeddings_connector", "video_embeddings_connector")
    # Applied in order; the first entry strips the checkpoint prefix.
    renames = (
        (prefix, ""),
        (".to_out.0.", ".to_out."),
        (".ff.net.0.proj.", ".ff.proj_in."),
        (".ff.net.2.", ".ff.proj_out."),
        (".audio_ff.net.0.proj.", ".audio_ff.proj_in."),
        (".audio_ff.net.2.", ".audio_ff.proj_out."),
        (".linear_1.", ".linear1."),
        (".linear_2.", ".linear2."),
    )

    remapped = {}
    for name, tensor in weights.items():
        if not name.startswith(prefix):
            continue
        if any(marker in name for marker in dropped_markers):
            continue
        for old, new in renames:
            name = name.replace(old, new)
        remapped[name] = tensor
    return remapped
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTXModel":
    """Build a model from a directory holding ``config.json`` and weights.

    Every ``*.safetensors`` file directly inside the directory is loaded
    and merged into one weight dict, passed through ``sanitize``, and
    float32 tensors are cast to bfloat16 before being loaded into the model.

    Args:
        model_path: Directory containing ``config.json`` and the
            ``*.safetensors`` shard(s). A ``str`` is accepted as well.
        strict: Forwarded to ``load_weights``.

    Returns:
        The model with weights loaded and evaluated, switched to eval mode.
    """
    import json
    from pathlib import Path

    # Accept plain strings too; `/` and `.glob` below need a Path.
    model_dir = Path(model_path)

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(model_dir / "config.json", "r", encoding="utf-8") as f:
        config = LTXModelConfig(**json.load(f))

    model = cls(config)

    # Sort the shards so that, should two files define the same key, the
    # winner does not depend on directory iteration order.
    weights = {}
    for weight_file in sorted(model_dir.glob("*.safetensors")):
        weights.update(mx.load(str(weight_file)))

    sanitized = model.sanitize(weights)
    # Down-cast float32 tensors only; other dtypes are kept as stored.
    sanitized = {
        k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
        for k, v in sanitized.items()
    }

    model.load_weights(list(sanitized.items()), strict=strict)
    mx.eval(model.parameters())
    model.eval()
    return model
class X0Model(nn.Module):
    """Wrapper that converts a velocity model's outputs with ``to_denoised``.

    The wrapped ``velocity_model`` is called with the same arguments, and
    each non-None output is combined with the corresponding modality's
    ``latent`` and ``timesteps`` via ``to_denoised``.
    """

    def __init__(self, velocity_model: LTXModel):
        super().__init__()
        self.velocity_model = velocity_model

    def __call__(
        self,
        video: Optional[Modality] = None,
        audio: Optional[Modality] = None,
        stg_video_blocks: Optional[List[int]] = None,
        stg_audio_blocks: Optional[List[int]] = None,
        skip_cross_modal: bool = False,
    ) -> Tuple[Optional[mx.array], Optional[mx.array]]:
        """Return ``(denoised_video, denoised_audio)``.

        Args:
            video: Video modality input, or None.
            audio: Audio modality input, or None.
            stg_video_blocks: Forwarded to the velocity model.
            stg_audio_blocks: Forwarded to the velocity model.
            skip_cross_modal: Forwarded to the velocity model.

        Returns:
            A pair whose entries are None wherever the velocity model
            returned None for that modality.
        """
        video_velocity, audio_velocity = self.velocity_model(
            video,
            audio,
            stg_video_blocks=stg_video_blocks,
            stg_audio_blocks=stg_audio_blocks,
            skip_cross_modal=skip_cross_modal,
        )

        denoised_video = None
        if video_velocity is not None:
            denoised_video = to_denoised(video.latent, video_velocity, video.timesteps)

        denoised_audio = None
        if audio_velocity is not None:
            denoised_audio = to_denoised(audio.latent, audio_velocity, audio.timesteps)

        return denoised_video, denoised_audio

View File

@@ -0,0 +1,165 @@
import numpy as np
from typing import Optional
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
    """Smooth an image with OpenCV's bilateral filter.

    When ``cv2`` cannot be imported, the call degrades to
    ``gaussian_blur(image, kernel_size=3)``; ``d`` and both sigmas are
    ignored in that case.

    Args:
        image: Input image as uint8 numpy array (H, W, C)
        d: Diameter of each pixel neighborhood
        sigma_color: Filter sigma in the color space
        sigma_space: Filter sigma in the coordinate space

    Returns:
        Filtered image
    """
    try:
        import cv2

        filtered = cv2.bilateralFilter(image, d, sigma_color, sigma_space)
    except ImportError:
        # OpenCV is optional: fall back to a fixed 3x3 blur.
        filtered = gaussian_blur(image, kernel_size=3)
    return filtered
def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    """Blur an image with a Gaussian kernel (box blur if OpenCV is missing).

    Args:
        image: Input image as uint8 numpy array, (H, W, C) or (H, W)
        kernel_size: Size of the Gaussian kernel (must be odd)

    Returns:
        Blurred image. With OpenCV this is ``cv2.GaussianBlur``'s result;
        without it, a ``scipy.ndimage.uniform_filter`` box blur cast to uint8.
    """
    try:
        import cv2

        return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
    except ImportError:
        # Simple box blur fallback
        from scipy.ndimage import uniform_filter

        # Filter only the two leading (spatial) axes. Trailing axes such as
        # channels get a window of 1, and a plain 2-D grayscale image gets a
        # 2-element size instead of failing on a rank mismatch.
        window = (kernel_size, kernel_size) + (1,) * (image.ndim - 2)
        return uniform_filter(image, size=window).astype(np.uint8)
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
    """Sharpen an image by subtracting a blurred copy (unsharp masking).

    Computes ``(1 + amount) * image - amount * blurred`` with OpenCV. If
    ``cv2`` cannot be imported the input is returned unchanged.

    Args:
        image: Input image as uint8 numpy array
        kernel_size: Size of the Gaussian kernel
        sigma: Gaussian sigma
        amount: Strength of sharpening

    Returns:
        Sharpened image
    """
    try:
        import cv2

        soft = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
        boosted = cv2.addWeighted(image, 1 + amount, soft, -amount, 0)
        result = np.clip(boosted, 0, 255).astype(np.uint8)
    except ImportError:
        # No OpenCV available: sharpening is skipped entirely.
        result = image
    return result
def reduce_grid_artifacts(
    video: np.ndarray,
    method: str = "bilateral",
    strength: float = 1.0,
) -> np.ndarray:
    """Filter every frame of a video to suppress grid artifacts.

    Args:
        video: Video as numpy array (F, H, W, C) uint8
        method: "bilateral", "gaussian", or "frequency"
        strength: How strong to apply the filter (0-1). It scales the filter
            parameters and, when below 1.0, is also the blend weight of the
            filtered result against the original video.

    Returns:
        Processed video

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Pick the per-frame filter first; frames are processed in one place below.
    if method == "bilateral":
        diameter = max(3, int(5 * strength))
        sigma = 50 + 50 * strength

        def filter_frame(frame):
            return bilateral_filter(frame, d=diameter, sigma_color=sigma, sigma_space=sigma)

    elif method == "gaussian":
        kernel_size = max(3, int(3 + 4 * strength))
        # The kernel size must be odd.
        kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size

        def filter_frame(frame):
            return gaussian_blur(frame, kernel_size=kernel_size)

    elif method == "frequency":

        def filter_frame(frame):
            return remove_grid_frequency(frame, grid_size=8)

    else:
        raise ValueError(f"Unknown method: {method}")

    processed = np.stack([filter_frame(frame) for frame in video])

    if strength < 1.0:
        # Partial strength: linearly blend the filtered frames with the
        # originals, weighting the filtered result by `strength`.
        keep = 1 - strength
        processed = (strength * processed + keep * video).astype(np.uint8)
    return processed
def _grid_notch_mask(height: int, width: int, grid_size: int) -> np.ndarray:
    """Build the centred (fftshift layout) notch mask for a height x width spectrum."""
    cy, cx = height // 2, width // 2
    freq_y = height // grid_size
    freq_x = width // grid_size
    offsets = range(-2, 3)

    # Notch centres at the grid frequency and its second harmonic along both
    # axes. Using a set means coinciding centres (which happens when a frame
    # is smaller than the grid period) are applied once, and the DC bin
    # itself is never a notch centre.
    centres = {
        (cy + fy * freq_y, cx + fx * freq_x)
        for fy in offsets
        for fx in offsets
    }
    centres.discard((cy, cx))

    mask = np.ones((height, width), dtype=np.float32)
    for y_pos, x_pos in sorted(centres):
        if not (0 <= y_pos < height and 0 <= x_pos < width):
            continue
        # Attenuation grows linearly with distance from the notch centre:
        # 0 at the centre, capped at 1 from distance 3 outwards.
        for dy in offsets:
            for dx in offsets:
                yy, xx = y_pos + dy, x_pos + dx
                if 0 <= yy < height and 0 <= xx < width:
                    dist = np.sqrt(dy**2 + dx**2)
                    mask[yy, xx] *= min(1.0, dist / 3.0)

    # The DC bin carries the mean brightness; a notch neighbourhood may
    # overlap it on small frames, so restore it explicitly.
    if height > 0 and width > 0:
        mask[cy, cx] = 1.0
    return mask


def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
    """Remove grid-frequency components using FFT.

    Each channel is transformed with a 2-D FFT, multiplied by a notch mask
    placed at multiples of ``size // grid_size`` around the spectrum centre,
    and transformed back. The DC component is always preserved, so the mean
    brightness of the frame is not altered even for frames that are small
    relative to ``grid_size``.

    Args:
        frame: Input frame (H, W, C) uint8
        grid_size: Expected grid periodicity in pixels

    Returns:
        Filtered frame, same shape and dtype as ``frame``
    """
    height, width = frame.shape[:2]
    # The mask depends only on the frame size, so build it once for all channels.
    mask = _grid_notch_mask(height, width, grid_size)

    result = np.zeros_like(frame)
    for c in range(frame.shape[2]):
        channel = frame[:, :, c].astype(np.float32)
        spectrum = np.fft.fftshift(np.fft.fft2(channel))
        filtered = np.fft.ifft2(np.fft.ifftshift(spectrum * mask)).real
        # Round to nearest before the integer cast: plain truncation would
        # turn FFT round-off such as 99.9999 into 99.
        result[:, :, c] = np.clip(np.rint(filtered), 0, 255).astype(np.uint8)
    return result

View File

@@ -0,0 +1,30 @@
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation is mentioned without text, generate contextual quoted dialogue. (e.g., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or request clarifications.
#### Example output:
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.

View File

@@ -0,0 +1,40 @@
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
#### Guidelines
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
- Speech (only when requested):
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
- Specify language if not English and accent if relevant.
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
- Facial features: Use delicate modifiers for subtle features (e.g., "subtle freckles").
#### Important notes:
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
- Camera motion: DO NOT invent camera motion unless requested by the user.
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
- Format: DO NOT start your response with special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
#### Output Format (Strict):
- Single continuous paragraph in natural language (English).
- NO titles, headings, prefaces, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or request clarifications.
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
#### Example
Input: "A woman at a coffee shop talking on the phone"
Output:
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.

View File

@@ -0,0 +1,540 @@
import math
from typing import List, Optional, Tuple
import mlx.core as mx
from mlx_video.models.ltx_2.config import LTXRopeType
def apply_rotary_emb(
    input_tensor: mx.array,
    freqs_cis: Tuple[mx.array, mx.array],
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
) -> mx.array:
    """Dispatch to the RoPE implementation selected by ``rope_type``.

    Args:
        input_tensor: Input tensor to apply RoPE to
        freqs_cis: Tuple of (cos_freqs, sin_freqs)
        rope_type: Type of RoPE to apply (INTERLEAVED or SPLIT)

    Returns:
        Tensor with rotary embeddings applied

    Raises:
        ValueError: If ``rope_type`` is neither INTERLEAVED nor SPLIT.
    """
    if rope_type == LTXRopeType.INTERLEAVED:
        rotate = apply_interleaved_rotary_emb
    elif rope_type == LTXRopeType.SPLIT:
        rotate = apply_split_rotary_emb
    else:
        raise ValueError(f"Invalid rope type: {rope_type}")
    return rotate(input_tensor, freqs_cis[0], freqs_cis[1])
def apply_interleaved_rotary_emb(
    input_tensor: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
) -> mx.array:
    """Apply RoPE with adjacent dimensions rotated as pairs.

    For every pair ``(x0, x1), (x2, x3), ...`` along the last axis the
    result is ``x * cos + rot(x) * sin`` where ``rot`` maps each pair
    ``(a, b)`` to ``(-b, a)``. The arithmetic is done in float32 and the
    result is cast back to the input dtype.

    Args:
        input_tensor: Input tensor of shape (..., dim), dim even
        cos_freqs: Cosine frequencies
        sin_freqs: Sine frequencies

    Returns:
        Tensor with interleaved rotary embeddings applied
    """
    out_dtype = input_tensor.dtype

    # Work in float32 for precision.
    x = input_tensor.astype(mx.float32)
    cos = cos_freqs.astype(mx.float32)
    sin = sin_freqs.astype(mx.float32)

    # (..., dim) -> (..., dim/2, 2): last axis holds (even, odd) members.
    full_shape = x.shape
    pairs = mx.reshape(x, full_shape[:-1] + (full_shape[-1] // 2, 2))
    even = pairs[..., 0]
    odd = pairs[..., 1]

    # Each pair (a, b) becomes (-b, a); flatten back to (..., dim).
    rotated = mx.reshape(mx.stack([-odd, even], axis=-1), full_shape)

    return (x * cos + rotated * sin).astype(out_dtype)
def rotate_half_interleaved(x: mx.array) -> mx.array:
    """Swap-and-negate adjacent pairs: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].

    PyTorch equivalent:
        t_dup = rearrange(x, "... (d r) -> ... d r", r=2)
        t1, t2 = t_dup.unbind(dim=-1)
        t_dup = torch.stack((-t2, t1), dim=-1)
        return rearrange(t_dup, "... d r -> ... (d r)")

    Args:
        x: Tensor of shape (..., dim) with dim even.

    Returns:
        Tensor of the same shape as ``x``.
    """
    negated_odd = -x[..., 1::2]  # [-x1, -x3, -x5, ...]
    even = x[..., 0::2]  # [x0, x2, x4, ...]
    # Pair them up as [[-x1, x0], [-x3, x2], ...] and flatten the pairs.
    paired = mx.stack([negated_odd, even], axis=-1)
    return mx.reshape(paired, x.shape)
def apply_rotary_emb_1d(
    q: mx.array,
    k: mx.array,
    freqs_cis: mx.array,
) -> Tuple[mx.array, mx.array]:
    """Rotate queries and keys with precomputed interleaved RoPE tables.

    Args:
        q: Query tensor; expected (batch, seq_len, num_heads, head_dim).
        k: Key tensor, same layout as ``q``.
        freqs_cis: Table whose last axis holds [cos, sin]; expected
            (1, seq_len, num_heads, head_dim, 2).

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    cos, sin = freqs_cis[..., 0], freqs_cis[..., 1]

    def _rotate(tensor: mx.array) -> mx.array:
        # Interleaved RoPE: adjacent dimension pairs rotate together.
        return tensor * cos + rotate_half_interleaved(tensor) * sin

    return _rotate(q), _rotate(k)
def apply_split_rotary_emb(
    input_tensor: mx.array,
    cos_freqs: mx.array,
    sin_freqs: mx.array,
) -> mx.array:
    """Apply RoPE where the last axis is split into two contiguous halves.

    With ``a`` the first half and ``b`` the second half of the last axis,
    the output halves are ``a * cos - b * sin`` and ``b * cos + a * sin``.
    If the input is not 4-D while the frequency tables are, the input is
    viewed as (B, H, T, D) for the computation and restored to
    (B, T, H*D) afterwards. Arithmetic is float32; the result is cast back
    to the input dtype.

    Args:
        input_tensor: Input tensor
        cos_freqs: Cosine frequencies of shape (B, H, T, D//2)
        sin_freqs: Sine frequencies of shape (B, H, T, D//2)

    Returns:
        Tensor with split rotary embeddings applied
    """
    out_dtype = input_tensor.dtype

    # Heads are folded into the feature axis when ranks disagree.
    heads_folded = input_tensor.ndim != 4 and cos_freqs.ndim == 4
    if heads_folded:
        batch, heads, seq_len, _ = cos_freqs.shape
        # (B, T, H*D) -> (B, T, H, D) -> (B, H, T, D)
        input_tensor = mx.swapaxes(
            mx.reshape(input_tensor, (batch, seq_len, heads, -1)), 1, 2
        )

    x = input_tensor.astype(mx.float32)
    cos = cos_freqs.astype(mx.float32)
    sin = sin_freqs.astype(mx.float32)

    # (..., dim) -> (..., 2, dim//2): index 0 is the first half, 1 the second.
    halves = mx.reshape(x, x.shape[:-1] + (2, x.shape[-1] // 2))
    first = halves[..., 0, :]
    second = halves[..., 1, :]

    rotated_first = first * cos - sin * second
    rotated_second = second * cos + sin * first

    # (..., 2, dim//2) -> (..., dim)
    output = mx.reshape(mx.stack([rotated_first, rotated_second], axis=-2), x.shape)

    if heads_folded:
        # (B, H, T, D) -> (B, T, H, D) -> (B, T, H*D)
        batch, heads, seq_len, head_dim = output.shape
        output = mx.reshape(mx.swapaxes(output, 1, 2), (batch, seq_len, heads * head_dim))

    return output.astype(out_dtype)
def generate_freq_grid(
    positional_embedding_theta: float,
    positional_embedding_max_pos_count: int,
    inner_dim: int,
) -> mx.array:
    """Return log-spaced RoPE frequency multipliers.

    Produces ``theta ** linspace(0, 1, n) * pi / 2`` where
    ``n = inner_dim // (2 * positional_embedding_max_pos_count)``, with
    ``n`` bumped to 1 when that division yields 0.

    Args:
        positional_embedding_theta: Base theta value
        positional_embedding_max_pos_count: Number of position dimensions
        inner_dim: Inner dimension of the model

    Returns:
        Frequency indices tensor
    """
    theta = positional_embedding_theta
    log_theta = math.log(theta)

    # Exponent range expressed in base-theta logs: from log_theta(1) to log_theta(theta).
    exponent_start = math.log(1.0) / log_theta
    exponent_end = math.log(theta) / log_theta

    # Always emit at least one frequency.
    num_indices = inner_dim // (2 * positional_embedding_max_pos_count) or 1

    exponents = mx.linspace(exponent_start, exponent_end, num_indices)
    return mx.power(theta, exponents) * (math.pi / 2)
def get_fractional_positions(
    indices_grid: mx.array,
    max_pos: List[int],
) -> mx.array:
    """Divide each positional axis of the grid by its maximum position.

    Args:
        indices_grid: Grid of position indices of shape (B, n_pos_dims, ...)
        max_pos: Maximum position for each dimension

    Returns:
        Positions as fractions of ``max_pos``, with the positional axis
        moved to the end. No further rescaling is applied here.
    """
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), (
        f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
    )

    per_axis = [indices_grid[:, axis] / max_pos[axis] for axis in range(n_pos_dims)]
    return mx.stack(per_axis, axis=-1)
def generate_freqs(
    indices: mx.array,
    indices_grid: mx.array,
    max_pos: List[int],
    use_middle_indices_grid: bool,
) -> mx.array:
    """Combine frequency multipliers with positions into RoPE angles.

    Positions are normalised by ``max_pos``, rescaled with ``p * 2 - 1``
    and multiplied with every entry of ``indices``; the per-dimension
    results are then flattened into the last axis.

    Args:
        indices: Frequency indices
        indices_grid: Position indices grid
        max_pos: Maximum positions per dimension
        use_middle_indices_grid: Whether to use middle of index ranges

    Returns:
        Frequency tensor, (B, T, n_indices * n_dims)
    """
    if use_middle_indices_grid:
        # Grid is (B, n_dims, T, 2) with [start, end] in the last axis;
        # use the midpoint of each range.
        assert len(indices_grid.shape) == 4
        assert indices_grid.shape[-1] == 2
        range_start = indices_grid[..., 0]
        range_end = indices_grid[..., 1]
        indices_grid = (range_start + range_end) / 2.0
    elif len(indices_grid.shape) == 4:
        # Only the first entry of the last axis is used.
        indices_grid = indices_grid[..., 0]

    # (B, T, n_dims) fractions, rescaled with p * 2 - 1.
    positions = get_fractional_positions(indices_grid, max_pos) * 2 - 1

    # Outer product: (B, T, n_dims, 1) * (1, 1, 1, n_indices).
    freqs = positions[..., None] * indices[None, None, None]

    # (B, T, n_dims, n_indices) -> (B, T, n_indices, n_dims) -> (B, T, n_indices * n_dims)
    freqs = mx.swapaxes(freqs, -1, -2)
    return mx.reshape(freqs, freqs.shape[:-2] + (-1,))
def split_freqs_cis(
    freqs: mx.array,
    pad_size: int,
    num_attention_heads: int,
) -> Tuple[mx.array, mx.array]:
    """Build per-head cos/sin tables for split RoPE.

    Args:
        freqs: Frequency tensor
        pad_size: Padding size for dimension alignment
        num_attention_heads: Number of attention heads

    Returns:
        Tuple of (cos_freq, sin_freq) with shape (B, H, T, D//2)
    """
    cos_freq, sin_freq = mx.cos(freqs), mx.sin(freqs)

    if pad_size != 0:
        # Left-pad with cos=1 / sin=0, i.e. an identity rotation.
        identity_cos = mx.ones_like(cos_freq[:, :, :pad_size])
        identity_sin = mx.zeros_like(sin_freq[:, :, :pad_size])
        cos_freq = mx.concatenate([identity_cos, cos_freq], axis=-1)
        sin_freq = mx.concatenate([identity_sin, sin_freq], axis=-1)

    batch, seq_len = cos_freq.shape[0], cos_freq.shape[1]

    def _to_heads(table: mx.array) -> mx.array:
        # (B, T, H * D//2) -> (B, T, H, D//2) -> (B, H, T, D//2)
        per_head = mx.reshape(table, (batch, seq_len, num_attention_heads, -1))
        return mx.swapaxes(per_head, 1, 2)

    return _to_heads(cos_freq), _to_heads(sin_freq)
def interleaved_freqs_cis(
    freqs: mx.array,
    pad_size: int,
) -> Tuple[mx.array, mx.array]:
    """Build cos/sin tables for interleaved RoPE.

    Every angle is duplicated so that both members of a rotated pair share
    it: ``[c0, c0, c1, c1, ...]``.

    Args:
        freqs: Frequency tensor of shape (B, T, dim//2)
        pad_size: Padding size for dimension alignment

    Returns:
        Tuple of (cos_freq, sin_freq) with shape (B, T, dim)
    """
    cos_freq = mx.repeat(mx.cos(freqs), 2, axis=-1)
    sin_freq = mx.repeat(mx.sin(freqs), 2, axis=-1)

    if pad_size == 0:
        return cos_freq, sin_freq

    # Left-pad with cos=1 / sin=0, i.e. an identity rotation.
    identity_cos = mx.ones_like(cos_freq[:, :, :pad_size])
    identity_sin = mx.zeros_like(sin_freq[:, :, :pad_size])
    return (
        mx.concatenate([identity_cos, cos_freq], axis=-1),
        mx.concatenate([identity_sin, sin_freq], axis=-1),
    )
def precompute_freqs_cis(
    indices_grid: mx.array,
    dim: int,
    theta: float = 10000.0,
    max_pos: Optional[List[int]] = None,
    use_middle_indices_grid: bool = False,
    num_attention_heads: int = 32,
    rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
    double_precision: bool = False,
) -> Tuple[mx.array, mx.array]:
    """Compute the RoPE cos/sin tables for a grid of positions.

    Args:
        indices_grid: Position indices grid
        dim: Dimension for RoPE
        theta: Base theta value for frequency computation
        max_pos: Maximum position per dimension; defaults to [20, 2048, 2048]
        use_middle_indices_grid: Whether to use middle indices
        num_attention_heads: Number of attention heads
        rope_type: Type of RoPE (INTERLEAVED or SPLIT)
        double_precision: If True, delegate to the float64 frequency-grid
            implementation

    Returns:
        Tuple of (cos_freq, sin_freq) tensors
    """
    if max_pos is None:
        max_pos = [20, 2048, 2048]

    if double_precision:
        return _precompute_freqs_cis_double_precision(
            indices_grid,
            dim,
            theta,
            max_pos,
            use_middle_indices_grid,
            num_attention_heads,
            rope_type,
        )

    # Positions are deliberately kept in float32. The original author's
    # comparison against PyTorch found float32 positions reproduce its RoPE
    # values, whereas bfloat16 positions lose precision that the large
    # frequency multipliers (up to ~15708) amplify into cos/sin sign flips.
    grid_f32 = indices_grid.astype(mx.float32)
    n_pos_dims = grid_f32.shape[1]

    freq_indices = generate_freq_grid(theta, n_pos_dims, dim)
    freqs = generate_freqs(freq_indices, grid_f32, max_pos, use_middle_indices_grid)

    if rope_type == LTXRopeType.SPLIT:
        # Split RoPE needs dim // 2 angles in total; pad whatever is missing.
        pad_size = dim // 2 - freqs.shape[-1]
        return split_freqs_cis(freqs, pad_size, num_attention_heads)

    # Interleaved: pad the remainder left over by the per-dimension split.
    return interleaved_freqs_cis(freqs, dim % (2 * n_pos_dims))
def _precompute_freqs_cis_double_precision(
    indices_grid: mx.array,
    dim: int,
    theta: float,
    max_pos: List[int],
    use_middle_indices_grid: bool,
    num_attention_heads: int,
    rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
    """Compute RoPE cos/sin tables using a float64 frequency grid.

    The log-spaced frequency grid is built with NumPy in float64 and only
    then converted to a float32 MLX array (the original port notes say this
    mirrors PyTorch's ``generate_freq_grid_np``). The position grid is cast to
    float32 before any arithmetic; it is NOT kept in the model dtype.

    Args:
        indices_grid: Position grid. Indexed as (B, n_dims, T) after the
            optional trailing axis is reduced; with ``use_middle_indices_grid``
            it must be 4-D with a trailing axis of size 2 (start, end).
        dim: Rotary dimension; ``dim // (2 * n_dims)`` frequencies are
            generated per positional axis (at least one).
        theta: Base of the log-spaced frequency grid.
        max_pos: Per-axis divisor used to turn positions into fractions.
        use_middle_indices_grid: Use the midpoint of (start, end) positions.
        num_attention_heads: Head count used to split the SPLIT-mode tables.
        rope_type: ``LTXRopeType.SPLIT`` or the interleaved layout.

    Returns:
        Tuple of (cos_freq, sin_freq). SPLIT mode yields per-head tables with
        the head axis at position 1; otherwise each value is repeated twice
        along the last axis and left-padded to ``dim``.
    """
    import numpy as np
    # Keep positions in float32 — same reasoning as the non-double-precision path.
    indices_grid_f32 = indices_grid.astype(mx.float32)
    n_pos_dims = indices_grid_f32.shape[1]
    n_elem = 2 * n_pos_dims
    # Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
    # This is the critical precision step - PyTorch uses np.float64 here
    log_start = np.log(1.0) / np.log(theta)
    log_end = np.log(theta) / np.log(theta)  # = 1.0
    num_indices = dim // n_elem
    if num_indices == 0:
        # Always emit at least one frequency so the later reshapes stay valid.
        num_indices = 1
    # Use numpy float64 for the linspace computation (matches PyTorch)
    pow_indices = np.power(
        theta,
        np.linspace(log_start, log_end, num_indices, dtype=np.float64),
    )
    # Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
    freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
    # Handle middle indices grid
    # Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
    if use_middle_indices_grid:
        assert len(indices_grid_f32.shape) == 4
        assert indices_grid_f32.shape[-1] == 2
        indices_grid_start = indices_grid_f32[..., 0]
        indices_grid_end = indices_grid_f32[..., 1]
        indices_grid_f32 = (indices_grid_start + indices_grid_end) / 2.0
    elif len(indices_grid_f32.shape) == 4:
        # Only the first entry of the trailing axis is used in this mode.
        indices_grid_f32 = indices_grid_f32[..., 0]
    # After handling: indices_grid_f32 shape is (B, n_dims, T)
    # Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
    # Compute fractional positions for each dimension
    fractional_list = []
    for i in range(n_pos_dims):
        frac = indices_grid_f32[:, i, :] / max_pos[i]  # (B, T)
        fractional_list.append(frac)
    # Stack: (B, T, n_dims)
    fractional_positions = mx.stack(fractional_list, axis=-1)
    # Scale to [-1, 1]
    scaled_positions = fractional_positions * 2 - 1
    # Compute frequencies: outer product
    # scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
    # freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
    freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
    # freqs: (B, T, n_dims, num_indices)
    # Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
    freqs = mx.swapaxes(freqs, -1, -2)
    freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1))
    # Compute cos/sin
    cos_freq = mx.cos(freqs)
    sin_freq = mx.sin(freqs)
    # Prepare based on rope type
    if rope_type == LTXRopeType.SPLIT:
        expected_freqs = dim // 2
        current_freqs = cos_freq.shape[-1]
        pad_size = expected_freqs - current_freqs
        # Add padding
        # cos=1 / sin=0 padding is the identity rotation; it is prepended, not appended.
        if pad_size > 0:
            cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
            sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
            cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
            sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
        # Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
        b, t = cos_freq.shape[0], cos_freq.shape[1]
        cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
        sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
        cos_freq = mx.swapaxes(cos_freq, 1, 2)
        sin_freq = mx.swapaxes(sin_freq, 1, 2)
    else:
        # Interleaved
        # Each frequency is duplicated so adjacent pairs share one rotation angle.
        cos_freq = mx.repeat(cos_freq, 2, axis=-1)
        sin_freq = mx.repeat(sin_freq, 2, axis=-1)
        pad_size = dim % n_elem
        if pad_size > 0:
            cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
            sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
            cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
            sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
    return cos_freq, sin_freq

View File

@@ -0,0 +1,181 @@
"""Second-order res_2s sampler for diffusion models.
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
noise injection, ported from the LTX-2 PyTorch implementation.
"""
import math
from typing import Optional
import mlx.core as mx
# ---------------------------------------------------------------------------
# Phi functions and RK coefficients (pure Python math, no MLX needed)
# ---------------------------------------------------------------------------
def phi(j: int, neg_h: float) -> float:
    """Compute phi_j(z) where z = -h (negative step size in log-space).

    phi_1(z) = (e^z - 1) / z
    phi_2(z) = (e^z - 1 - z) / z^2
    phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j

    For small ``|z|`` the closed form subtracts nearly equal numbers and loses
    most of its significant digits (for j=2 and z=1e-8 the numerator is below
    float64 resolution), so the equivalent Taylor series

        phi_j(z) = sum_{k>=0} z^k / (k + j)!

    is used there instead. It also yields the exact limit ``1 / j!`` at z = 0.

    Args:
        j: Order of the phi function (non-negative integer).
        neg_h: The argument z.

    Returns:
        The value of phi_j(z) as a float.
    """
    z = neg_h
    if abs(z) < 0.5:
        # Series branch: terms shrink at least like 0.5**k / k!, so the sum
        # stops changing in float64 well before the iteration cap.
        term = 1.0 / math.factorial(j)
        total = term
        for k in range(1, 40):
            term *= z / (j + k)
            previous = total
            total += term
            if total == previous:
                break
        return total
    # Away from zero the closed form is well conditioned.
    remainder = sum(z**k / math.factorial(k) for k in range(j))
    return (math.exp(z) - remainder) / (z**j)
def get_res2s_coefficients(
    h: float,
    phi_cache: dict,
    c2: float = 0.5,
) -> tuple[float, float, float]:
    """Return the res_2s Runge-Kutta coefficients for one step.

    Args:
        h: Step size in log-space = log(sigma / sigma_next).
        phi_cache: Mutable mapping keyed by ``(j, z)``; missing phi values are
            computed and stored in it, existing ones are reused.
        c2: Substep position (default 0.5 = midpoint).

    Returns:
        (a21, b1, b2): RK coefficients, where ``a21 = c2 * phi_1(-h * c2)``,
        ``b2 = phi_2(-h) / c2`` and ``b1 = phi_1(-h) - b2``.
    """

    def cached_phi(order: int, z: float) -> float:
        # Fill the cache on a miss, then always answer from the cache.
        key = (order, z)
        if key not in phi_cache:
            phi_cache[key] = phi(order, z)
        return phi_cache[key]

    substep_z = -h * c2
    full_z = -h

    # Lookups are issued in a fixed order so cache insertion order is stable.
    a21 = c2 * cached_phi(1, substep_z)
    b2 = cached_phi(2, full_z) / c2
    b1 = cached_phi(1, full_z) - b2
    return a21, b1, b2
# ---------------------------------------------------------------------------
# SDE noise injection
# ---------------------------------------------------------------------------
def get_sde_coeff(
    sigma_next: float,
) -> tuple[float, float, float]:
    """Compute the coefficients used to re-inject noise at ``sigma_next``.

    The injected-noise level is fixed at half of ``sigma_next`` (the original
    note says this is hardcoded in the PyTorch Res2sDiffusionStep), and the
    signal level is taken as ``1 - sigma_next`` (sigma_max = 1).

    Returns:
        (alpha_ratio, sigma_down, sigma_up)
    """
    # Half of sigma_next, additionally capped at 0.9999 * sigma_next so the
    # radicand below cannot go negative through this term.
    sigma_up = min(sigma_next * 0.5, sigma_next * 0.9999)

    residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0))
    alpha_ratio = (1.0 - sigma_next) + residual
    sigma_down = sigma_next if alpha_ratio == 0 else residual / alpha_ratio

    # NaN inputs propagate through the arithmetic; replace them with fallbacks.
    if math.isnan(sigma_up):
        sigma_up = 0.0
    if math.isnan(sigma_down):
        sigma_down = sigma_next
    if math.isnan(alpha_ratio):
        alpha_ratio = 1.0

    return alpha_ratio, sigma_down, sigma_up
def sde_noise_step(
    sample: mx.array,
    denoised_sample: mx.array,
    sigma: float,
    sigma_next: float,
    noise: mx.array,
) -> mx.array:
    """Move ``sample`` from ``sigma`` to ``sigma_next`` with noise re-injection.

    Args:
        sample: Current sample (anchor point).
        denoised_sample: Denoised prediction at this step.
        sigma: Current noise level.
        sigma_next: Next noise level.
        noise: Pre-generated noise tensor (channel-wise normalized).

    Returns:
        The noised sample at ``sigma_next`` in float32, or ``denoised_sample``
        unchanged when there is no noise to inject.
    """
    alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)

    # Nothing stochastic to add: hand back the prediction untouched.
    if sigma_up == 0 or sigma_next == 0:
        return denoised_sample

    # All mixing is done in float32 regardless of the input dtypes.
    sample_f32, denoised_f32, noise_f32 = (
        tensor.astype(mx.float32) for tensor in (sample, denoised_sample, noise)
    )

    # Epsilon estimate from the two samples, and the prediction it implies.
    eps_next = (sample_f32 - denoised_f32) / (sigma - sigma_next)
    denoised_next = sample_f32 - sigma * eps_next

    deterministic = alpha_ratio * (denoised_next + sigma_down * eps_next)
    stochastic = sigma_up * noise_f32
    return deterministic + stochastic
# ---------------------------------------------------------------------------
# Noise generation
# ---------------------------------------------------------------------------
def channelwise_normalize(x: mx.array) -> mx.array:
    """Standardize ``x`` over its last two axes.

    The mean over the last two dimensions is removed and the result is divided
    by the root of its mean square plus 1e-8 (which keeps the division finite
    for constant inputs). All leading axes are normalized independently.
    """
    last_two = (-2, -1)
    centered = x - mx.mean(x, axis=last_two, keepdims=True)
    spread = mx.sqrt(mx.mean(centered * centered, axis=last_two, keepdims=True) + 1e-8)
    return centered / spread
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
    """Draw Gaussian noise and normalize it globally, then per channel.

    PyTorch uses float64; float32 is used here (the original note states MLX
    does not support float64).

    Args:
        shape: Shape of the noise tensor.
        key: MLX random key for deterministic generation.

    Returns:
        Channel-wise normalized noise in float32.
    """
    raw = mx.random.normal(shape, dtype=mx.float32, key=key)

    # Global pass: subtract the overall mean and divide by the root-mean-square
    # of the *uncentred* draw (plus a small epsilon).
    overall_mean = mx.mean(raw)
    overall_rms = mx.sqrt(mx.mean(raw * raw))
    globally_scaled = (raw - overall_mean) / (overall_rms + 1e-8)

    # Per-channel pass over the last two axes.
    return channelwise_normalize(globally_scaled)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,26 @@
import mlx.core as mx
import mlx.nn as nn
class PixArtAlphaTextProjection(nn.Module):
    """Two-layer projection: Linear -> tanh-approximated GELU -> Linear."""

    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        out_features: int | None = None,
        bias: bool = True,
    ):
        """Create the projection layers.

        Args:
            in_features: Width of the incoming features.
            hidden_size: Width after the first linear layer.
            out_features: Output width; a falsy value falls back to
                ``hidden_size``.
            bias: Whether both linear layers carry a bias term.
        """
        super().__init__()
        output_width = out_features or hidden_size
        self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
        # Must match PyTorch's approximate="tanh"
        self.act = nn.GELU(approx="tanh")
        self.linear2 = nn.Linear(hidden_size, output_width, bias=bias)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply linear1, the activation and linear2 in sequence."""
        return self.linear2(self.act(self.linear1(x)))

View File

@@ -0,0 +1,403 @@
from dataclasses import dataclass, replace
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx_2.attention import Attention
from mlx_video.models.ltx_2.feed_forward import FeedForward
from mlx_video.utils import rms_norm
@dataclass(frozen=True)
class Modality:
    """Immutable bundle of the inputs describing one modality.

    NOTE(review): the field meanings below are inferred from names and from
    how the similarly named ``TransformerArgs`` fields are consumed; confirm
    against the code that builds ``Modality`` objects.
    """

    # Latent tensor for this modality.
    latent: mx.array
    # Timestep values associated with the latent.
    timesteps: mx.array
    # Position grid; presumably the input to the RoPE computation — TODO confirm.
    positions: mx.array
    # Conditioning context (presumably text embeddings — TODO confirm).
    context: mx.array
    # Whether this modality takes part in the forward pass.
    enabled: bool = True
    # Optional mask for the context.
    context_mask: Optional[mx.array] = None
    # Optional precomputed positional embeddings (RoPE) to avoid recomputation
    positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
    # Raw sigma value (scalar per batch) for prompt adaln (LTX-2.3)
    sigma: Optional[mx.array] = None
@dataclass(frozen=True)
class TransformerArgs:
    """Immutable per-modality state threaded through the transformer blocks.

    A block reads these fields and returns a copy with ``x`` replaced.
    """

    # Hidden states of this modality; the only field a block updates.
    x: mx.array
    # Context used as key/value source of the text cross-attention.
    context: mx.array
    # Mask passed to the text cross-attention (may be None).
    context_mask: Optional[mx.array]
    # AdaLN modulation input; reshaped by the block to (B, seq, num_params, dim).
    timesteps: mx.array
    # Embedded timestep; not read by BasicAVTransformerBlock itself.
    embedded_timestep: mx.array
    # (cos, sin) RoPE tables passed to self-attention.
    positional_embeddings: Tuple[mx.array, mx.array]
    # RoPE tables passed to the audio<->video cross-attention.
    cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
    # Modulation input for the cross-modal scale/shift values.
    cross_scale_shift_timestep: Optional[mx.array]
    # Modulation input for the cross-modal gate value.
    cross_gate_timestep: Optional[mx.array]
    # When False the block leaves this modality untouched.
    enabled: bool
    # LTX-2.3: prompt-conditioned timestep embeddings for cross-attention
    prompt_timesteps: Optional[mx.array] = None
    prompt_embedded_timestep: Optional[mx.array] = None
class BasicAVTransformerBlock(nn.Module):
    """Audio-Video transformer block with cross-modal attention.

    Supports video-only, audio-only, or combined audio-video processing
    with bidirectional cross-attention between modalities. For every active
    stream the block applies, as residual updates and in this order:
    modulated self-attention, text cross-attention, optional audio<->video
    cross-attention, and a modulated feed-forward.
    """

    def __init__(
        self,
        idx: int,
        video: Optional[TransformerConfig] = None,
        audio: Optional[TransformerConfig] = None,
        rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
        norm_eps: float = 1e-6,
        has_prompt_adaln: bool = False,
    ):
        """Create the sub-modules for the configured modalities.

        Args:
            idx: Index of this block; stored as ``self.idx``.
            video: Video stream configuration, or None to omit video layers.
            audio: Audio stream configuration, or None to omit audio layers.
            rope_type: RoPE layout forwarded to every attention module.
            norm_eps: Epsilon for RMS normalization.
            has_prompt_adaln: Enables the LTX-2.3 variant: 9 instead of 6
                modulation rows, an extra prompt scale/shift table, and
                ``has_gate_logits`` on the attention modules.
        """
        super().__init__()
        self.idx = idx
        self.norm_eps = norm_eps
        self.has_prompt_adaln = has_prompt_adaln

        # Video components
        if video is not None:
            self.attn1 = Attention(
                query_dim=video.dim,
                heads=video.heads,
                dim_head=video.d_head,
                context_dim=None,  # Self-attention
                rope_type=rope_type,
                norm_eps=norm_eps,
                has_gate_logits=has_prompt_adaln,
            )
            self.attn2 = Attention(
                query_dim=video.dim,
                context_dim=video.context_dim,
                heads=video.heads,
                dim_head=video.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                has_gate_logits=has_prompt_adaln,
            )
            self.ff = FeedForward(video.dim, dim_out=video.dim)
            # 9 params for LTX-2.3 (self-attn + cross-attn + FFN), 6 for LTX-2
            num_ada_params = 9 if has_prompt_adaln else 6
            self.scale_shift_table = mx.zeros((num_ada_params, video.dim))
            if has_prompt_adaln:
                self.prompt_scale_shift_table = mx.zeros((2, video.dim))

        # Audio components
        if audio is not None:
            self.audio_attn1 = Attention(
                query_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                context_dim=None,
                rope_type=rope_type,
                norm_eps=norm_eps,
                has_gate_logits=has_prompt_adaln,
            )
            self.audio_attn2 = Attention(
                query_dim=audio.dim,
                context_dim=audio.context_dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                has_gate_logits=has_prompt_adaln,
            )
            self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
            num_audio_ada_params = 9 if has_prompt_adaln else 6
            self.audio_scale_shift_table = mx.zeros((num_audio_ada_params, audio.dim))
            if has_prompt_adaln:
                self.audio_prompt_scale_shift_table = mx.zeros((2, audio.dim))

        # Cross-modal attention (when both video and audio are enabled)
        if audio is not None and video is not None:
            # Audio-to-Video: Q from video, K/V from audio
            self.audio_to_video_attn = Attention(
                query_dim=video.dim,
                context_dim=audio.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                has_gate_logits=has_prompt_adaln,
            )
            # Video-to-Audio: Q from audio, K/V from video
            self.video_to_audio_attn = Attention(
                query_dim=audio.dim,
                context_dim=video.dim,
                heads=audio.heads,
                dim_head=audio.d_head,
                rope_type=rope_type,
                norm_eps=norm_eps,
                has_gate_logits=has_prompt_adaln,
            )
            # Scale-shift tables for cross-attention
            self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
            self.scale_shift_table_a2v_ca_video = mx.zeros((5, video.dim))

    def get_ada_values(
        self,
        scale_shift_table: mx.array,
        batch_size: int,
        timestep: mx.array,
        indices: slice,
    ) -> Tuple[mx.array, ...]:
        """Get adaptive normalization values from scale-shift table.

        Args:
            scale_shift_table: Table of shape (num_params, dim)
            batch_size: Batch size
            timestep: Timestep embeddings reshapeable to
                (batch_size, timestep.shape[1], num_params, dim)
            indices: Slice selecting which parameter rows to extract

        Returns:
            One tensor of shape (B, seq, dim) per selected row, each being the
            table row plus the matching timestep chunk.
        """
        num_ada_params = scale_shift_table.shape[0]

        # Selected rows, with leading batch and sequence axes: (1, 1, k, dim)
        table_slice = scale_shift_table[indices]
        table_expanded = mx.expand_dims(mx.expand_dims(table_slice, axis=0), axis=0)

        # timestep: (B, seq, num_params * dim) -> (B, seq, num_params, dim)
        timestep_reshaped = mx.reshape(
            timestep,
            (batch_size, timestep.shape[1], num_ada_params, -1),
        )
        timestep_slice = timestep_reshaped[:, :, indices, :]

        ada_values = table_expanded + timestep_slice

        # Unbind along the parameter axis.
        num_sliced = ada_values.shape[2]
        return tuple(ada_values[:, :, i, :] for i in range(num_sliced))

    def get_av_ca_ada_values(
        self,
        scale_shift_table: mx.array,
        batch_size: int,
        scale_shift_timestep: mx.array,
        gate_timestep: mx.array,
        num_scale_shift_values: int = 4,
    ) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
        """Get adaptive values for cross-modal attention.

        Args:
            scale_shift_table: Table with 5 parameters (4 scale-shift + 1 gate)
            batch_size: Batch size
            scale_shift_timestep: Timestep for scale-shift
            gate_timestep: Timestep for gating
            num_scale_shift_values: Number of scale-shift values (default 4)

        Returns:
            Tuple of 5 tensors: (scale1, shift1, scale2, shift2, gate), each
            of shape (B, seq, dim).
        """
        scale_shift_ada = self.get_ada_values(
            scale_shift_table[:num_scale_shift_values, :],
            batch_size,
            scale_shift_timestep,
            slice(None, None),
        )
        gate_ada = self.get_ada_values(
            scale_shift_table[num_scale_shift_values:, :],
            batch_size,
            gate_timestep,
            slice(None, None),
        )
        # The sequence axis is deliberately kept even when its length is 1.
        # (B, 1, dim) broadcasts against (B, T, dim) activations for any batch
        # size, whereas a squeezed (B, dim) tensor would line its batch axis
        # up with the token axis and only work when B == 1.
        return (*scale_shift_ada, *gate_ada)

    def __call__(
        self,
        video: Optional[TransformerArgs] = None,
        audio: Optional[TransformerArgs] = None,
        skip_video_self_attn: bool = False,
        skip_audio_self_attn: bool = False,
        skip_cross_modal: bool = False,
    ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
        """Forward pass through transformer block.

        Args:
            video: Video modality arguments
            audio: Audio modality arguments
            skip_video_self_attn: Skip video self-attention (for STG perturbation)
            skip_audio_self_attn: Skip audio self-attention (for STG perturbation)
            skip_cross_modal: Skip all cross-modal attention (for modality isolation)

        Returns:
            Tuple of (updated_video, updated_audio) TransformerArgs; an entry
            is None when the corresponding input was None.
        """
        vx = video.x if video is not None else None
        ax = audio.x if audio is not None else None

        # A stream runs only if it was supplied, is enabled and is non-empty.
        run_vx = video is not None and video.enabled and vx.size > 0
        run_ax = audio is not None and audio.enabled and ax.size > 0
        # Both cross-modal directions need both streams, so one flag covers them.
        run_cross_modal = run_vx and run_ax and not skip_cross_modal

        # Process video self-attention and cross-attention with text
        if run_vx:
            vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
            )
            # Self-attention with RoPE (skip_attention=True for STG perturbation)
            norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
            vx = vx + self.attn1(
                norm_vx,
                pe=video.positional_embeddings,
                skip_attention=skip_video_self_attn,
            ) * vgate_msa

            # Cross-attention with text context
            if self.has_prompt_adaln:
                # LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
                vshift_q, vscale_q, vgate_q = self.get_ada_values(
                    self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
                )
                vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
                    self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
                )
                attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
                encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
                vx = vx + self.attn2(
                    attn_input,
                    context=encoder_hidden_states,
                    mask=video.context_mask,
                ) * vgate_q
            else:
                vx = vx + self.attn2(
                    rms_norm(vx, eps=self.norm_eps),
                    context=video.context,
                    mask=video.context_mask,
                )

        # Process audio self-attention and cross-attention with text
        if run_ax:
            ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
            )
            # Self-attention with RoPE (skip_attention=True for STG perturbation)
            norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
            ax = ax + self.audio_attn1(
                norm_ax,
                pe=audio.positional_embeddings,
                skip_attention=skip_audio_self_attn,
            ) * agate_msa

            # Cross-attention with text context
            if self.has_prompt_adaln:
                # LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
                ashift_q, ascale_q, agate_q = self.get_ada_values(
                    self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
                )
                aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
                    self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
                )
                attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
                encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
                ax = ax + self.audio_attn2(
                    attn_input_a,
                    context=encoder_hidden_states_a,
                    mask=audio.context_mask,
                ) * agate_q
            else:
                ax = ax + self.audio_attn2(
                    rms_norm(ax, eps=self.norm_eps),
                    context=audio.context,
                    mask=audio.context_mask,
                )

        # Audio-Video cross-modal attention
        if run_cross_modal:
            # Both directions read these pre-update normalized states, so the
            # A2V residual below does not leak into the V2A keys/values.
            vx_norm3 = rms_norm(vx, eps=self.norm_eps)
            ax_norm3 = rms_norm(ax, eps=self.norm_eps)

            # Get adaptive values for audio cross-attention
            (
                scale_ca_audio_a2v,
                shift_ca_audio_a2v,
                scale_ca_audio_v2a,
                shift_ca_audio_v2a,
                gate_out_v2a,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_audio,
                ax.shape[0],
                audio.cross_scale_shift_timestep,
                audio.cross_gate_timestep,
            )
            # Get adaptive values for video cross-attention
            (
                scale_ca_video_a2v,
                shift_ca_video_a2v,
                scale_ca_video_v2a,
                shift_ca_video_v2a,
                gate_out_a2v,
            ) = self.get_av_ca_ada_values(
                self.scale_shift_table_a2v_ca_video,
                vx.shape[0],
                video.cross_scale_shift_timestep,
                video.cross_gate_timestep,
            )

            # Audio-to-Video cross-attention
            vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
            ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
            vx = vx + (
                self.audio_to_video_attn(
                    vx_scaled,
                    context=ax_scaled,
                    pe=video.cross_positional_embeddings,
                    k_pe=audio.cross_positional_embeddings,
                )
                * gate_out_a2v
            )

            # Video-to-Audio cross-attention
            ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
            vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
            ax = ax + (
                self.video_to_audio_attn(
                    ax_scaled,
                    context=vx_scaled,
                    pe=audio.cross_positional_embeddings,
                    k_pe=video.cross_positional_embeddings,
                )
                * gate_out_v2a
            )

        # Process video feed-forward
        if run_vx:
            vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
                self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
            )
            vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
            vx = vx + self.ff(vx_scaled) * vgate_mlp

        # Process audio feed-forward
        if run_ax:
            ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
                self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
            )
            ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
            ax = ax + self.audio_ff(ax_scaled) * agate_mlp

        # Return updated TransformerArgs (frozen dataclasses, so copy with new x)
        video_out = replace(video, x=vx) if video is not None else None
        audio_out = replace(audio, x=ax) if audio is not None else None
        return video_out, audio_out

View File

@@ -0,0 +1,371 @@
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
class Conv3d(nn.Module):
    """3D convolution over channels-last input, backed by ``mx.conv3d``."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        """Store the hyper-parameters and create weight (and optional bias).

        Integer ``kernel_size``, ``stride``, ``padding`` and ``dilation`` are
        expanded to 3-tuples. The weight is drawn uniformly from
        ``[-s, s]`` with ``s = 1 / sqrt(in_channels * KD * KH * KW)``.
        """
        super().__init__()

        def as_triple(value):
            # Broadcast a scalar setting to (depth, height, width).
            return (value, value, value) if isinstance(value, int) else value

        kernel_size = as_triple(kernel_size)
        stride = as_triple(stride)
        padding = as_triple(padding)
        dilation = as_triple(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Weight shape: (C_out, KD, KH, KW, C_in)
        fan_in = in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]
        bound = 1.0 / fan_in ** 0.5
        self.weight = mx.random.uniform(
            low=-bound,
            high=bound,
            shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
        )
        self.bias = mx.zeros((out_channels,)) if bias else None

    def __call__(self, x: mx.array) -> mx.array:
        """Forward pass.

        Args:
            x: Input tensor of shape (N, D, H, W, C_in)

        Returns:
            Output tensor of shape (N, D', H', W', C_out)
        """
        out = mx.conv3d(
            x,
            self.weight,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return out if self.bias is None else out + self.bias
class GroupNorm3d(nn.Module):
    """Group normalization for channels-last 5D tensors, computed in float32."""

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        """Create unit weight / zero bias of length ``num_channels``."""
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.weight = mx.ones((num_channels,))
        self.bias = mx.zeros((num_channels,))

    def __call__(self, x: mx.array) -> mx.array:
        """Normalize ``x`` of shape (N, D, H, W, C) and return it in its input dtype."""
        n, d, h, w, c = x.shape
        original_dtype = x.dtype

        # (N, D*H*W, num_groups, C // num_groups), promoted to float32.
        grouped = mx.reshape(
            x.astype(mx.float32),
            (n, d * h * w, self.num_groups, c // self.num_groups),
        )

        # Statistics span all positions and the channels inside each group.
        stat_axes = (1, 3)
        mean = mx.mean(grouped, axis=stat_axes, keepdims=True)
        var = mx.var(grouped, axis=stat_axes, keepdims=True)
        normalized = mx.reshape((grouped - mean) / mx.sqrt(var + self.eps), (n, d, h, w, c))

        # Per-channel affine transform, also in float32.
        scale = self.weight.astype(mx.float32)
        offset = self.bias.astype(mx.float32)
        return (normalized * scale + offset).astype(original_dtype)
class PixelShuffle2D(nn.Module):
    """Pixel shuffle for 2D spatial upsampling."""

    def __init__(self, upscale_factor: int = 2):
        super().__init__()
        self.upscale_factor = upscale_factor

    def __call__(self, x: mx.array) -> mx.array:
        """Rearrange (N, H, W, C) into (N, H*r, W*r, C // r^2)."""
        batch, height, width, channels = x.shape
        factor = self.upscale_factor
        out_channels = channels // (factor * factor)

        # Split channels into (out_c, r, r), then interleave the two r axes
        # with the spatial axes: (N, H, r, W, r, out_c).
        blocks = mx.reshape(x, (batch, height, width, out_channels, factor, factor))
        interleaved = mx.transpose(blocks, (0, 1, 4, 2, 5, 3))
        return mx.reshape(interleaved, (batch, height * factor, width * factor, out_channels))
class SpatialRationalResampler(nn.Module):
    """Frame-by-frame 2x spatial upsampler: Conv2d followed by pixel shuffle."""

    def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
        """Create the conv, the blur kernel attribute and the pixel shuffle.

        ``scale`` is stored but the forward pass always upsamples by 2.
        """
        super().__init__()
        self.scale = scale
        # 2D conv: mid_channels -> 4*mid_channels for pixel shuffle
        self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
        # Blur kernel for antialiasing (not used by __call__ in this class)
        self.blur_down_kernel = mx.ones((1, 1, 5, 5)) / 25.0
        self.pixel_shuffle = PixelShuffle2D(2)

    def __call__(self, x: mx.array) -> mx.array:
        """Upsample ``x`` of shape (N, D, H, W, C) to (N, D, 2H, 2W, C)."""
        batch, frames, height, width, channels = x.shape

        # Fold frames into the batch axis so the 2D ops see (N*D, H, W, C).
        per_frame = mx.reshape(x, (batch * frames, height, width, channels))
        shuffled = self.pixel_shuffle(self.conv(per_frame))

        # Unfold frames again; spatial size has doubled.
        return mx.reshape(shuffled, (batch, frames, height * 2, width * 2, channels))
class ResBlock3D(nn.Module):
    """Residual block: (conv -> norm -> SiLU) -> (conv -> norm), add skip, SiLU."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = Conv3d(channels, channels, kernel_size=3, padding=1)
        self.norm1 = GroupNorm3d(32, channels)
        self.conv2 = Conv3d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = GroupNorm3d(32, channels)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the block; the output has the same shape as ``x``."""
        hidden = nn.silu(self.norm1(self.conv1(x)))
        hidden = self.norm2(self.conv2(hidden))
        # Activation AFTER residual addition
        return nn.silu(hidden + x)
class LatentUpsampler(nn.Module):
    """Latent-space 2x spatial upsampler built from 3D residual blocks."""

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 1024,
        num_blocks_per_stage: int = 4,
    ):
        """Build projection, residual stages and the spatial resampler.

        Args:
            in_channels: Channel count of the input and output latents.
            mid_channels: Channel count used inside the network.
            num_blocks_per_stage: Residual blocks before and after upsampling.
        """
        super().__init__()
        self.in_channels = in_channels
        self.mid_channels = mid_channels

        # Initial projection
        self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = GroupNorm3d(32, mid_channels)

        # Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
        self.res_blocks = {
            block_idx: ResBlock3D(mid_channels) for block_idx in range(num_blocks_per_stage)
        }

        # Upsampler: 2D spatial upsampling (frame-by-frame)
        self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)

        # Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
        self.post_upsample_res_blocks = {
            block_idx: ResBlock3D(mid_channels) for block_idx in range(num_blocks_per_stage)
        }

        # Final projection
        self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)

    def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
        """Upsample latents by 2x spatially.

        Args:
            latent: Input tensor of shape (B, C, F, H, W) - channels first
            debug: If True, print intermediate values for debugging

        Returns:
            Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
        """

        def report(label, tensor):
            # No-op unless debugging; otherwise force evaluation and print stats.
            if not debug:
                return
            mx.eval(tensor)
            print(f" {label}: shape={tensor.shape}, min={tensor.min().item():.4f}, max={tensor.max().item():.4f}, mean={tensor.mean().item():.4f}")

        if debug:
            print(" [DEBUG] LatentUpsampler forward pass:")
        report("Input (channels first)", latent)

        # Channels first (B, C, F, H, W) -> channels last (B, F, H, W, C)
        x = mx.transpose(latent, (0, 2, 3, 4, 1))
        report("After transpose to channels-last", x)

        # Initial projection: conv -> norm -> SiLU
        x = self.initial_conv(x)
        report("After initial_conv", x)
        x = self.initial_norm(x)
        report("After initial_norm", x)
        x = nn.silu(x)
        report("After silu", x)

        # Pre-upsample blocks, in key order
        for block_idx in sorted(self.res_blocks):
            x = self.res_blocks[block_idx](x)
            report(f"After res_blocks[{block_idx}]", x)

        # Upsample (2D spatial, frame-by-frame)
        x = self.upsampler(x)
        report("After upsampler (spatial 2x)", x)

        # Post-upsample blocks, in key order
        for block_idx in sorted(self.post_upsample_res_blocks):
            x = self.post_upsample_res_blocks[block_idx](x)
            report(f"After post_upsample_res_blocks[{block_idx}]", x)

        # Final projection back to in_channels
        x = self.final_conv(x)
        report("After final_conv", x)

        # Back to channels first (B, C, F, H, W)
        x = mx.transpose(x, (0, 4, 1, 2, 3))
        report("Output (channels first)", x)
        return x
def upsample_latents(
    latent: mx.array,
    upsampler: LatentUpsampler,
    latent_mean: mx.array,
    latent_std: mx.array,
    debug: bool = False,
) -> mx.array:
    """Upsample normalized latents, working in un-normalized space.

    The latents are de-normalized with the given statistics, passed through
    ``upsampler``, and normalized again with the same statistics.

    Args:
        latent: Normalized latents; statistics are broadcast along axis 1.
        upsampler: Callable upsampler, invoked as ``upsampler(x, debug=debug)``.
        latent_mean: Per-channel mean, reshaped to (1, -1, 1, 1, 1).
        latent_std: Per-channel std, reshaped to (1, -1, 1, 1, 1).
        debug: Forwarded to the upsampler.

    Returns:
        The upsampled, re-normalized latents.
    """
    mean = latent_mean.reshape(1, -1, 1, 1, 1)
    std = latent_std.reshape(1, -1, 1, 1, 1)

    denormalized = latent * std + mean
    upsampled = upsampler(denormalized, debug=debug)
    return (upsampled - mean) / std
def load_upsampler(weights_path: str) -> LatentUpsampler:
    """Load upsampler from safetensors weights.

    The internal width is read from ``res_blocks.0.conv1.weight`` when that
    key exists (default 1024). Convolution weights are converted from the
    PyTorch channels-first layout to the MLX channels-last layout, and
    ``upsampler.0.*`` keys are renamed to ``upsampler.conv.*``. Weights are
    loaded non-strictly.

    Args:
        weights_path: Path to upsampler weights file

    Returns:
        Loaded LatentUpsampler model
    """
    print(f"Loading spatial upsampler from {weights_path}...")
    raw_weights = mx.load(weights_path)

    # res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
    probe_key = "res_blocks.0.conv1.weight"
    mid_channels = raw_weights[probe_key].shape[0] if probe_key in raw_weights else 1024
    print(f" Detected mid_channels: {mid_channels}")

    model = LatentUpsampler(
        in_channels=128,
        mid_channels=mid_channels,
        num_blocks_per_stage=4,
    )

    def to_mlx(name, tensor):
        # LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
        if name.startswith("upsampler.0."):
            name = name.replace("upsampler.0.", "upsampler.conv.")
        if "weight" in name:
            if tensor.ndim == 5:
                # Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
                tensor = mx.transpose(tensor, (0, 2, 3, 4, 1))
            elif tensor.ndim == 4:
                # Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
                tensor = mx.transpose(tensor, (0, 2, 3, 1))
        return name, tensor

    sanitized = dict(to_mlx(name, tensor) for name, tensor in raw_weights.items())

    # Non-strict load: keys without a matching parameter are ignored.
    model.load_weights(list(sanitized.items()), strict=False)
    print(f" Loaded {len(sanitized)} weights")
    return model

View File

@@ -0,0 +1,8 @@
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
)

View File

@@ -0,0 +1,294 @@
from enum import Enum
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
class PaddingModeType(Enum):
    """Spatial padding modes accepted by the convolution factories in this module."""

    # Zero padding.
    ZEROS = "zeros"
    # Reflect padding (presumably implemented via reflect_pad_2d — TODO confirm).
    REFLECT = "reflect"
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
    """Apply reflect padding to spatial dimensions of a 5D tensor.

    The reflection excludes the border element itself, so each side is padded
    with the ``pad`` elements next to the border, in reversed order.

    Args:
        x: Input tensor of shape (B, D, H, W, C) - channels last
        pad_h: Padding for height dimension
        pad_w: Padding for width dimension

    Returns:
        Padded tensor; ``x`` itself when both paddings are zero.
    """
    if not (pad_h or pad_w):
        return x

    # Height padding (axis 2)
    if pad_h > 0:
        top = x[:, :, 1 : pad_h + 1][:, :, ::-1]
        bottom = x[:, :, -pad_h - 1 : -1][:, :, ::-1]
        x = mx.concatenate([top, x, bottom], axis=2)

    # Width padding (axis 3), applied to the already height-padded tensor
    if pad_w > 0:
        left = x[:, :, :, 1 : pad_w + 1][:, :, :, ::-1]
        right = x[:, :, :, -pad_w - 1 : -1][:, :, :, ::-1]
        x = mx.concatenate([left, x, right], axis=3)

    return x
def make_conv_nd(
    dims: int,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: Union[int, Tuple[int, ...], str] = 0,
    causal: bool = False,
    spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
    """Build a causal convolution layer for 2D or 3D inputs.

    Args:
        dims: 2 selects ``CausalConv2d``; 3 selects ``CausalConv3d``.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, forwarded unchanged.
        stride: Stride, forwarded unchanged.
        padding: Padding specification, forwarded unchanged.
        causal: Default causal flag of the created layer.
        spatial_padding_mode: Spatial padding mode of the created layer.

    Returns:
        The constructed convolution module.

    Raises:
        ValueError: If ``dims`` is neither 2 nor 3.
    """
    conv_by_dims = {2: CausalConv2d, 3: CausalConv3d}
    if dims not in conv_by_dims:
        raise ValueError(f"Unsupported number of dimensions: {dims}")
    conv_cls = conv_by_dims[dims]
    return conv_cls(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        causal=causal,
        spatial_padding_mode=spatial_padding_mode,
    )
class CausalConv3d(nn.Module):
    """3D convolution with replicated temporal padding and manual spatial padding.

    Inputs and outputs are channels-first ``(B, C, D, H, W)``; the tensor is
    transposed to channels-last around the wrapped ``nn.Conv3d``, which is
    created with ``padding=0`` because all padding is applied here.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size as an int or a (D, H, W) tuple.
        stride: Stride as an int or a (D, H, W) tuple.
        padding: Accepted for interface compatibility only; this class never
            reads it. Spatial padding is always ``kernel // 2`` and temporal
            padding is done by frame replication.
        causal: Default temporal padding mode; can be overridden per call.
        spatial_padding_mode: ZEROS or REFLECT padding for the H and W axes.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int], str] = 0,
        causal: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        self.causal = causal
        self.spatial_padding_mode = spatial_padding_mode

        # Normalize kernel_size and stride to (D, H, W) tuples.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        self.kernel_size = kernel_size
        self.stride = stride
        self.time_kernel_size = kernel_size[0]

        # Spatial padding only; temporal padding is handled by frame replication.
        self.spatial_padding = (kernel_size[1] // 2, kernel_size[2] // 2)

        # Base convolution without padding; padding is applied manually in __call__.
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            bias=True,
        )

    def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
        """Pad and convolve ``x``.

        Args:
            x: Input tensor of shape (B, C, D, H, W).
            causal: Overrides the layer's ``causal`` default when not None.

        Returns:
            The convolved tensor in (B, C, D', H', W') layout.
        """
        use_causal = causal if causal is not None else self.causal

        # Temporal padding by frame replication, only needed when the temporal
        # kernel spans more than one frame.
        if self.time_kernel_size > 1:
            if use_causal:
                # Causal: repeat the first frame kernel_size - 1 times up front.
                first_frame_pad = mx.repeat(
                    x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
                )
                x = mx.concatenate([first_frame_pad, x], axis=2)
            else:
                # Non-causal: repeat the first frame at the start and the last
                # frame at the end.
                pad_size = (self.time_kernel_size - 1) // 2
                if pad_size > 0:
                    first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
                    last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
                    x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)

        # Channels last: (B, C, D, H, W) -> (B, D, H, W, C)
        x = mx.transpose(x, (0, 2, 3, 4, 1))

        pad_h, pad_w = self.spatial_padding
        if pad_h > 0 or pad_w > 0:
            if self.spatial_padding_mode == PaddingModeType.REFLECT:
                x = reflect_pad_2d(x, pad_h, pad_w)
            else:
                pad_width = [
                    (0, 0),  # Batch
                    (0, 0),  # D (temporal - already padded)
                    (pad_h, pad_h),  # H
                    (pad_w, pad_w),  # W
                    (0, 0),  # C
                ]
                x = mx.pad(x, pad_width)

        # Chunked because MLX conv3d fails around 33 frames at 192x192 spatial.
        x = self._chunked_conv3d(x)

        # Channels first again: (B, D, H, W, C) -> (B, C, D, H, W)
        x = mx.transpose(x, (0, 4, 1, 2, 3))
        return x

    def _chunked_conv3d(self, x: mx.array) -> mx.array:
        """Apply conv3d in temporal chunks to work around an MLX bug with large tensors.

        Small inputs go through the convolution in one call. Larger inputs are
        split along the temporal axis into overlapping windows sized so that
        each window yields a whole number of output frames. The temporal
        stride is taken into account when sizing and advancing the windows.

        Args:
            x: Input tensor of shape (B, D, H, W, C) in channels-last format.

        Returns:
            Output tensor after conv3d, channels-last.
        """
        _, d, h, w, c = x.shape
        total_elements = d * h * w * c
        max_safe_elements = 30 * 192 * 192 * 128  # ~140M elements per chunk
        if total_elements <= max_safe_elements:
            return self.conv(x)

        kernel_t = self.time_kernel_size
        stride_t = self.stride[0]
        # An unpadded convolution yields floor((d - k) / s) + 1 frames.
        expected_output_frames = (d - kernel_t) // stride_t + 1
        if expected_output_frames <= 0:
            # Nothing to chunk; let the convolution handle (or reject) the input.
            return self.conv(x)

        elements_per_frame = h * w * c
        max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
        chunk_size = min(max_frames_per_chunk, 24)  # Cap at 24 frames per chunk

        outputs = []
        out_idx = 0
        in_start = 0
        while out_idx < expected_output_frames:
            out_frames_this_chunk = min(chunk_size, expected_output_frames - out_idx)
            # Producing n frames needs (n - 1) * stride + kernel input frames.
            in_frames_needed = (out_frames_this_chunk - 1) * stride_t + kernel_t
            in_end = min(in_start + in_frames_needed, d)
            chunk_out = self.conv(x[:, in_start:in_end, :, :, :])
            # Force evaluation so each chunk is computed before the next one.
            mx.eval(chunk_out)
            outputs.append(chunk_out)
            produced = chunk_out.shape[1]
            out_idx += produced
            # Each output frame advances the input window by one stride.
            in_start += produced * stride_t

        if len(outputs) == 1:
            return outputs[0]
        return mx.concatenate(outputs, axis=1)
class CausalConv2d(nn.Module):
    """2D convolution over channels-first input with manual zero padding.

    The ``causal`` and ``spatial_padding_mode`` arguments are stored on the
    instance but are not consulted by ``__call__``, which always zero-pads.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int], str] = 0,
        causal: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        """Create the layer.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Kernel size as an int or an (H, W) tuple.
            stride: Stride as an int or an (H, W) tuple.
            padding: An int (same pad on H and W), the string "same"
                (``(kernel - 1) // 2`` per axis), or an (H, W) pair stored as given.
            causal: Stored on the instance.
            spatial_padding_mode: Stored on the instance.
        """
        super().__init__()
        self.causal = causal
        self.spatial_padding_mode = spatial_padding_mode

        # Expand scalar kernel/stride values to (H, W) pairs.
        kernel_hw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        stride_hw = (stride, stride) if isinstance(stride, int) else stride
        self.kernel_size = kernel_hw
        self.stride = stride_hw

        # Resolve the padding specification into an (H, W) pair.
        if isinstance(padding, int):
            self.padding = (padding, padding)
        elif isinstance(padding, str) and padding == "same":
            self.padding = ((kernel_hw[0] - 1) // 2, (kernel_hw[1] - 1) // 2)
        else:
            self.padding = padding

        # The wrapped convolution never pads; padding happens in __call__.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_hw,
            stride=stride_hw,
            padding=0,
            bias=True,
        )

    def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
        """Zero-pad and convolve a (B, C, H, W) tensor, returning (B, C, H', W')."""
        # (B, C, H, W) -> (B, H, W, C)
        channels_last = mx.transpose(x, (0, 2, 3, 1))

        pad_h, pad_w = self.padding
        if pad_h != 0 or pad_w != 0:
            channels_last = mx.pad(
                channels_last,
                [(0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)],
            )

        convolved = self.conv(channels_last)
        # (B, H, W, C) -> (B, C, H, W)
        return mx.transpose(convolved, (0, 3, 1, 2))

View File

@@ -0,0 +1,692 @@
"""Video VAE Decoder for LTX-2 with timestep conditioning.
Architecture (from PyTorch weights):
- conv_in: 128 -> 1024
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
- pixel_norm + timestep modulation (last_scale_shift_table)
- conv_out: 128 -> 48
- unpatchify: 48 -> 3 with patch_size=4
"""
import math
from typing import Optional, Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
def get_timestep_embedding(
    timesteps: mx.array,
    embedding_dim: int,
    flip_sin_to_cos: bool = True,
    downscale_freq_shift: float = 0,
    scale: float = 1,
    max_period: int = 10000,
) -> mx.array:
    """Create sinusoidal timestep embeddings.

    Args:
        timesteps: 1D array of timestep values.
        embedding_dim: Width of the returned embedding.
        flip_sin_to_cos: When True the cosine half comes first, otherwise the
            sine half comes first.
        downscale_freq_shift: Subtracted from ``embedding_dim // 2`` in the
            frequency exponent's denominator.
        scale: Multiplier applied to the angles before taking sin/cos.
        max_period: Base of the geometric frequency progression.

    Returns:
        Float32 array of shape (len(timesteps), embedding_dim); an odd
        ``embedding_dim`` gets one trailing zero column.
    """
    half_dim = embedding_dim // 2

    # Geometric frequencies from 1 down towards 1 / max_period.
    positions = mx.arange(0, half_dim, dtype=mx.float32)
    log_freqs = (-math.log(max_period) * positions) / (half_dim - downscale_freq_shift)
    freqs = mx.exp(log_freqs)

    angles = scale * (timesteps[:, None].astype(mx.float32) * freqs[None, :])
    sin_half = mx.sin(angles)
    cos_half = mx.cos(angles)

    halves = [cos_half, sin_half] if flip_sin_to_cos else [sin_half, cos_half]
    emb = mx.concatenate(halves, axis=-1)

    if embedding_dim % 2 == 1:
        # Pad one zero column so the width matches an odd embedding_dim.
        emb = mx.pad(emb, [(0, 0), (0, 1)])
    return emb
class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to a timestep projection."""

    def __init__(self, in_channels: int, time_embed_dim: int):
        """Create the MLP.

        Args:
            in_channels: Width of the incoming projection.
            time_embed_dim: Width of the hidden layer and of the output.
        """
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
        self.act = nn.SiLU()

    def __call__(self, sample: mx.array) -> mx.array:
        """Return ``linear_2(act(linear_1(sample)))``."""
        hidden = self.act(self.linear_1(sample))
        return self.linear_2(hidden)
class PixArtAlphaTimestepEmbedder(nn.Module):
    """Sinusoidal timestep projection (width 256) followed by an MLP."""

    def __init__(self, embedding_dim: int):
        """Create the embedder; ``embedding_dim`` is the MLP output width."""
        super().__init__()
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256,
            time_embed_dim=embedding_dim,
        )

    def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
        """Embed ``timestep``.

        Args:
            timestep: 1D array of timestep values.
            hidden_dtype: Dtype the sinusoidal projection is cast to before the MLP.

        Returns:
            The MLP output for each timestep.
        """
        projection = get_timestep_embedding(
            timestep,
            embedding_dim=256,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        return self.timestep_embedder(projection.astype(hidden_dtype))
class ResnetBlock3DSimple(nn.Module):
    """Residual block of two pixel-norm/SiLU/conv stages with optional timestep modulation.

    Weight keys: conv1.conv, conv2.conv, scale_shift_table
    """

    def __init__(
        self,
        channels: int,
        spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
        timestep_conditioning: bool = False,
    ):
        """Create the block.

        Args:
            channels: Channel count of both convolutions (input and output).
            spatial_padding_mode: Spatial padding mode of the convolutions.
            timestep_conditioning: When True, allocates ``scale_shift_table``.
        """
        super().__init__()
        self.timestep_conditioning = timestep_conditioning

        # Each conv sits inside a wrapper so parameters are named conv1.conv.*
        self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
        self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
        self.act = nn.SiLU()

        if timestep_conditioning:
            # Rows are [shift1, scale1, shift2, scale2].
            self.scale_shift_table = mx.zeros((4, channels))

    def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
        """Return a module whose only child, ``conv``, is a 3x3x3 CausalConv3d."""

        class ConvWrapper(nn.Module):
            def __init__(wrapper):
                super().__init__()
                wrapper.conv = CausalConv3d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    spatial_padding_mode=padding_mode,
                )

            def __call__(wrapper, x, causal=False):
                return wrapper.conv(x, causal=causal)

        return ConvWrapper()

    def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
        """Divide ``x`` by the root mean square taken over axis 1."""
        mean_square = mx.mean(x ** 2, axis=1, keepdims=True)
        return x / mx.sqrt(mean_square + eps)

    def __call__(
        self,
        x: mx.array,
        causal: bool = False,
        timestep_embed: Optional[mx.array] = None,
    ) -> mx.array:
        """Run the block.

        Args:
            x: Input tensor; axis 0 is batch and axis 1 is channels.
            causal: Forwarded to both convolutions.
            timestep_embed: Embedding reshapeable to (B, 4, C, 1, 1, 1); only
                used when the block was built with timestep conditioning.

        Returns:
            ``x`` plus the output of the two stages.
        """
        skip = x
        modulate = self.timestep_conditioning and timestep_embed is not None

        if modulate:
            n_channels = self.scale_shift_table.shape[1]
            # (1, 4, C, 1, 1, 1) table plus (B, 4, C, 1, 1, 1) embedding.
            table = self.scale_shift_table[None, :, :, None, None, None]
            ada_values = table + timestep_embed.reshape(x.shape[0], 4, n_channels, 1, 1, 1)
            shift1, scale1, shift2, scale2 = (ada_values[:, row] for row in range(4))

        # Stage 1
        h = self.pixel_norm(x)
        if modulate:
            h = h * (1 + scale1) + shift1
        h = self.conv1(self.act(h), causal=causal)

        # Stage 2
        h = self.pixel_norm(h)
        if modulate:
            h = h * (1 + scale2) + shift2
        h = self.conv2(self.act(h), causal=causal)

        return h + skip
class ResBlockGroup(nn.Module):
    """Sequence of ResnetBlock3DSimple blocks that share one timestep embedding.

    PyTorch naming: res_blocks.0, res_blocks.1, etc.
    """

    def __init__(
        self,
        channels: int,
        num_layers: int = 5,
        spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
        timestep_conditioning: bool = False,
    ):
        """Create the group.

        Args:
            channels: Channel count of every block.
            num_layers: Number of residual blocks.
            spatial_padding_mode: Spatial padding mode of the blocks' convolutions.
            timestep_conditioning: When True, adds a time embedder whose
                output width is ``4 * channels``.
        """
        super().__init__()
        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)

        # Dict with int keys so the children are named res_blocks.<index>.
        res_blocks = {}
        for layer_idx in range(num_layers):
            res_blocks[layer_idx] = ResnetBlock3DSimple(
                channels,
                spatial_padding_mode,
                timestep_conditioning=timestep_conditioning,
            )
        self.res_blocks = res_blocks

    def __call__(
        self,
        x: mx.array,
        causal: bool = False,
        timestep: Optional[mx.array] = None,
    ) -> mx.array:
        """Run every block in insertion order.

        Args:
            x: Input tensor; axis 0 is batch.
            causal: Forwarded to every block.
            timestep: Timestep values; embedded once and shared by all blocks
                when the group was built with timestep conditioning.

        Returns:
            The output of the last block.
        """
        shared_embed = None
        if self.timestep_conditioning and timestep is not None:
            embedded = self.time_embedder(timestep.flatten(), hidden_dtype=x.dtype)
            # (B, 4*C, 1, 1, 1) for broadcasting over D, H, W.
            shared_embed = embedded.reshape(x.shape[0], -1, 1, 1, 1)

        for block in self.res_blocks.values():
            x = block(x, causal=causal, timestep_embed=shared_embed)
        return x
class LTX2VideoDecoder(nn.Module):
    """LTX-2 Video VAE Decoder with timestep conditioning.

    Default architecture (``DEFAULT_BLOCKS``):
    - conv_in: 128 -> 1024
    - up_blocks.0: 5 ResBlocks at 1024 (with timestep)
    - up_blocks.1: Upsampler 1024 -> 512
    - up_blocks.2: 5 ResBlocks at 512 (with timestep)
    - up_blocks.3: Upsampler 512 -> 256
    - up_blocks.4: 5 ResBlocks at 256 (with timestep)
    - up_blocks.5: Upsampler 256 -> 128
    - up_blocks.6: 5 ResBlocks at 128 (with timestep)
    - conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)

    A different layout can be supplied through ``decoder_blocks``.
    """

    # Block definitions: ("res", channels, num_layers) or
    # ("d2s", in_channels, reduction, stride[, residual]); stride is a (D, H, W) tuple.
    DEFAULT_BLOCKS = [
        ("res", 1024, 5),
        ("d2s", 1024, 2, (2, 2, 2)),
        ("res", 512, 5),
        ("d2s", 512, 2, (2, 2, 2)),
        ("res", 256, 5),
        ("d2s", 256, 2, (2, 2, 2)),
        ("res", 128, 5),
    ]

    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 3,
        patch_size: int = 4,
        num_layers_per_block: int = 5,
        spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
        timestep_conditioning: bool = True,
        decoder_blocks: Optional[list] = None,
    ):
        """Build the decoder.

        Args:
            in_channels: Number of latent channels.
            out_channels: Number of output channels before patch expansion.
            patch_size: Spatial patch size used by the final unpatchify.
            num_layers_per_block: Layer count for "res" entries that omit one.
            spatial_padding_mode: Spatial padding mode for all convolutions.
            timestep_conditioning: Whether to build the timestep embedders
                and scale/shift tables.
            decoder_blocks: Block definitions in the ``DEFAULT_BLOCKS`` format;
                ``DEFAULT_BLOCKS`` is used when this is None or empty.
        """
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.timestep_conditioning = timestep_conditioning

        # Decode parameters; plain attributes that can be overwritten after
        # construction (they are not constructor arguments).
        self.decode_noise_scale = 0.025  # Set to 0.0 to disable noise
        self.decode_timestep = 0.05

        # Per-channel statistics for denormalization (loaded from weights).
        self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)

        blocks = decoder_blocks or self.DEFAULT_BLOCKS
        first_ch = blocks[0][1]
        last_ch = blocks[-1][1]

        # Initial conv: in_channels -> first block channels. The wrapper gives
        # the parameters the conv_in.conv.* prefix.
        class ConvInWrapper(nn.Module):
            def __init__(self_inner):
                super().__init__()
                self_inner.conv = CausalConv3d(
                    in_channels=in_channels,
                    out_channels=first_ch,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    spatial_padding_mode=spatial_padding_mode,
                )

            def __call__(self_inner, x, causal=False):
                return self_inner.conv(x, causal=causal)

        self.conv_in = ConvInWrapper()

        # Build up blocks from the block definitions, keyed by position.
        self.up_blocks = {}
        for idx, block_def in enumerate(blocks):
            block_type = block_def[0]
            ch = block_def[1]
            if block_type == "res":
                num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
                self.up_blocks[idx] = ResBlockGroup(
                    ch, num_layers, spatial_padding_mode, timestep_conditioning
                )
            elif block_type == "d2s":
                reduction = block_def[2] if len(block_def) > 2 else 2
                stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
                residual = block_def[4] if len(block_def) > 4 else True
                self.up_blocks[idx] = DepthToSpaceUpsample(
                    dims=3,
                    in_channels=ch,
                    stride=stride,
                    residual=residual,
                    out_channels_reduction_factor=reduction,
                    spatial_padding_mode=spatial_padding_mode,
                )

        final_out_channels = out_channels * patch_size * patch_size

        class ConvOutWrapper(nn.Module):
            def __init__(self_inner):
                super().__init__()
                self_inner.conv = CausalConv3d(
                    in_channels=last_ch,
                    out_channels=final_out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    spatial_padding_mode=spatial_padding_mode,
                )

            def __call__(self_inner, x, causal=False):
                return self_inner.conv(x, causal=causal)

        self.conv_out = ConvOutWrapper()
        self.act = nn.SiLU()

        if timestep_conditioning:
            self.timestep_scale_multiplier = mx.array(1000.0)
            self.last_time_embedder = PixArtAlphaTimestepEmbedder(
                embedding_dim=last_ch * 2
            )
            # Rows are [shift, scale] for the final modulation.
            self.last_scale_shift_table = mx.zeros((2, last_ch))

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Remap checkpoint keys and layouts to this module's parameter names.

        Weights that already contain ``per_channel_statistics.mean`` are
        returned unchanged. Otherwise only ``vae.*`` keys are kept, excluding
        ``vae.encoder.*``: the ``vae.decoder.`` prefix is stripped, the two
        used statistics keys are renamed, 5D ``.conv.weight`` tensors are
        transposed from (O, I, D, H, W) to (O, D, H, W, I), and
        ``.conv.weight`` / ``.conv.bias`` become ``.conv.conv.*`` to match the
        wrapper modules.

        Args:
            weights: Mapping from checkpoint key to array.

        Returns:
            The remapped mapping (or ``weights`` itself when already converted).
        """
        if "per_channel_statistics.mean" in weights:
            return weights

        sanitized = {}
        for key, value in weights.items():
            if not key.startswith("vae.") or key.startswith("vae.encoder."):
                continue

            new_key = key
            if key.startswith("vae.per_channel_statistics."):
                # Map per-channel statistics (exact key matching).
                if key == "vae.per_channel_statistics.mean-of-means":
                    new_key = "per_channel_statistics.mean"
                elif key == "vae.per_channel_statistics.std-of-means":
                    new_key = "per_channel_statistics.std"
                else:
                    continue  # Skip other statistics keys

            if key.startswith("vae.decoder."):
                new_key = key.replace("vae.decoder.", "")

            # Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I).
            # Biases need no transpose.
            if ".conv.weight" in key and value.ndim == 5:
                value = mx.transpose(value, (0, 2, 3, 4, 1))

            # Insert the extra ".conv" level introduced by the wrapper modules,
            # unless the key already has it.
            if ".conv.weight" in new_key or ".conv.bias" in new_key:
                if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
                    new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
                    new_key = new_key.replace(".conv.bias", ".conv.conv.bias")

            sanitized[new_key] = value
        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
        """Load a pretrained decoder from a directory or a single weights file.

        Args:
            model_path: Path to a directory containing ``config.json`` and
                ``*.safetensors`` files, or path to a single safetensors file.
                For a single file, ``config.json`` is looked up next to it.
            strict: Whether to require all weight keys to match.

        Returns:
            Loaded LTX2VideoDecoder instance.

        Raises:
            FileNotFoundError: If a directory contains no safetensors files.
        """
        import json

        model_path = Path(model_path)
        if model_path.is_file():
            # Single-file checkpoint: load just this file.
            weight_files = [model_path]
            config_path = model_path.parent / "config.json"
        else:
            weight_files = sorted(model_path.glob("*.safetensors"))
            config_path = model_path / "config.json"

        # A missing config falls back to the defaults used below.
        config_dict = {}
        if config_path.exists():
            with open(config_path) as f:
                config_dict = json.load(f)

        if not weight_files:
            raise FileNotFoundError(f"No safetensors files found in {model_path}")
        weights = {}
        for wf in weight_files:
            weights.update(mx.load(str(wf)))

        # Infer block structure from weights (None selects DEFAULT_BLOCKS).
        decoder_blocks = cls._infer_blocks(weights)

        spatial_padding_mode_str = config_dict.get("spatial_padding_mode", "reflect")
        spatial_padding_mode = PaddingModeType(spatial_padding_mode_str)

        model = cls(
            timestep_conditioning=config_dict.get("timestep_conditioning", False),
            decoder_blocks=decoder_blocks,
            spatial_padding_mode=spatial_padding_mode,
        )
        weights = model.sanitize(weights)
        model.load_weights(list(weights.items()), strict=strict)
        return model

    @staticmethod
    def _infer_blocks(weights: dict) -> Optional[list]:
        """Infer decoder block structure from weight keys and shapes.

        Args:
            weights: Mapping from weight key to array. 5D conv weights are
                read with the input channel on the last axis.

        Returns:
            Block definitions in the ``DEFAULT_BLOCKS`` format, with a
            residual flag appended to each "d2s" entry, or None when no
            ``up_blocks.<n>`` keys (or no usable blocks) are found.
        """
        block_indices = set()
        for k in weights:
            if "up_blocks." in k:
                idx_str = k.split("up_blocks.")[1].split(".")[0]
                if idx_str.isdigit():
                    block_indices.add(int(idx_str))
        if not block_indices:
            return None

        # First pass: classify each block index and read its channel counts.
        raw_blocks = []
        for idx in sorted(block_indices):
            conv_prefix = f"up_blocks.{idx}.conv."
            res_prefix = f"up_blocks.{idx}.res_blocks."
            has_conv = any(conv_prefix in k for k in weights)
            res_indices = set()
            for k in weights:
                if res_prefix in k:
                    res_idx = k.split(res_prefix)[1].split(".")[0]
                    if res_idx.isdigit():
                        res_indices.add(int(res_idx))

            if has_conv and not res_indices:
                # D2S block: take channel counts from the first conv weight.
                for k, v in weights.items():
                    if conv_prefix in k and "weight" in k:
                        in_ch = v.shape[-1] if v.ndim == 5 else v.shape[1]
                        conv_out_ch = v.shape[0]
                        raw_blocks.append(("d2s", in_ch, conv_out_ch))
                        break
            elif res_indices:
                num_res = max(res_indices) + 1
                first_conv = f"up_blocks.{idx}.res_blocks.0.conv1"
                for k, v in weights.items():
                    if first_conv in k and "weight" in k:
                        raw_blocks.append(("res", v.shape[0], num_res))
                        break

        # Second pass: derive d2s strides from the channel progression. The
        # next res block gives the expected output channels of each d2s block.
        blocks = []
        d2s_strides = []
        for i, block in enumerate(raw_blocks):
            if block[0] == "res":
                blocks.append(block)
            elif block[0] == "d2s":
                in_ch, conv_out_ch = block[1], block[2]
                next_ch = None
                for j in range(i + 1, len(raw_blocks)):
                    if raw_blocks[j][0] == "res":
                        next_ch = raw_blocks[j][1]
                        break
                if next_ch is None:
                    next_ch = in_ch // 2  # fallback
                # out_ch = in_ch // reduction
                reduction = in_ch // next_ch if next_ch > 0 else 2
                # conv_out = next_ch * multiplier -> multiplier = conv_out / next_ch
                multiplier = conv_out_ch // next_ch if next_ch > 0 else 8
                if multiplier == 8:
                    stride = (2, 2, 2)
                elif multiplier == 4:
                    stride = (1, 2, 2)
                elif multiplier == 2:
                    stride = (2, 1, 1)
                else:
                    stride = (2, 2, 2)
                d2s_strides.append(stride)
                blocks.append(("d2s", in_ch, reduction, stride))

        if not blocks:
            return None

        # Residual flag: LTX-2 has uniform (2,2,2) strides with reduction=2 -> residual=True
        # LTX-2.3 has mixed strides or reduction=1 -> residual=False
        has_mixed_strides = len(set(d2s_strides)) > 1
        has_non_standard_reduction = any(b[2] != 2 for b in blocks if b[0] == "d2s")
        use_residual = not has_mixed_strides and not has_non_standard_reduction

        # Apply the residual flag to all d2s blocks.
        final_blocks = []
        for block in blocks:
            if block[0] == "d2s":
                final_blocks.append(("d2s", block[1], block[2], block[3], use_residual))
            else:
                final_blocks.append(block)
        return final_blocks

    def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
        """Divide ``x`` by the root mean square taken over axis 1."""
        return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)

    def __call__(
        self,
        sample: mx.array,
        causal: bool = False,
        timestep: Optional[mx.array] = None,
        debug: bool = False,
        chunked_conv: bool = False,
    ) -> mx.array:
        """Decode latents to video.

        Args:
            sample: Latents; axis 0 is batch and axis 1 is channels.
            causal: Forwarded to every convolution and block.
            timestep: Per-sample timestep; defaults to ``decode_timestep``
                when timestep conditioning is enabled.
            debug: Accepted for interface compatibility; not used here.
            chunked_conv: Forwarded to the depth-to-space upsamplers.

        Returns:
            The decoded tensor after ``unpatchify``.
        """
        batch_size = sample.shape[0]

        # Blend in noise when timestep conditioning is enabled.
        if self.timestep_conditioning:
            noise = mx.random.normal(sample.shape) * self.decode_noise_scale
            sample = noise + (1.0 - self.decode_noise_scale) * sample

        sample = self.per_channel_statistics.un_normalize(sample)

        if timestep is None and self.timestep_conditioning:
            timestep = mx.full((batch_size,), self.decode_timestep)
        scaled_timestep = None
        if self.timestep_conditioning and timestep is not None:
            scaled_timestep = timestep * self.timestep_scale_multiplier

        x = self.conv_in(sample, causal=causal)
        for block in self.up_blocks.values():
            if isinstance(block, ResBlockGroup):
                x = block(x, causal=causal, timestep=scaled_timestep)
            elif isinstance(block, DepthToSpaceUpsample):
                x = block(x, causal=causal, chunked_conv=chunked_conv)
            else:
                x = block(x, causal=causal)

        x = self.pixel_norm(x)
        if self.timestep_conditioning and scaled_timestep is not None:
            embedded_timestep = self.last_time_embedder(
                scaled_timestep.flatten(),
                hidden_dtype=x.dtype,
            )
            # Take the channel count from the table itself so decoders whose
            # last block is not 128 channels wide still work.
            last_ch = self.last_scale_shift_table.shape[1]
            ada_values = self.last_scale_shift_table[None, :, :, None, None, None]  # (1, 2, C, 1, 1, 1)
            ts_reshaped = embedded_timestep.reshape(batch_size, 2, last_ch, 1, 1, 1)
            ada_values = ada_values + ts_reshaped
            shift = ada_values[:, 0]  # (B, C, 1, 1, 1)
            scale = ada_values[:, 1]
            x = x * (1 + scale) + shift

        x = self.act(x)
        x = self.conv_out(x, causal=causal)
        # Move the patch pixels from the channel axis back to H and W.
        x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
        return x

    def decode_tiled(
        self,
        sample: mx.array,
        tiling_config: Optional[TilingConfig] = None,
        tiling_mode: str = "auto",
        causal: bool = False,
        timestep: Optional[mx.array] = None,
        debug: bool = False,
        on_frames_ready: Optional[callable] = None,
    ) -> mx.array:
        """Decode latents using tiling to reduce memory usage.

        Useful for large videos that would otherwise run out of memory. When
        the latents already fit in a single tile in every configured
        dimension, the regular decode path is used instead.

        Args:
            sample: Input latents of shape (B, C, F, H, W).
            tiling_config: Tiling configuration. If None, uses TilingConfig.default().
            tiling_mode: Name of the tiling mode; only used here to decide
                whether chunked convolution is enabled.
            causal: Whether to use causal convolutions.
            timestep: Optional timestep for conditioning.
            debug: Forwarded to the regular decode path only.
            on_frames_ready: Optional callback forwarded to
                ``decode_with_tiling``; not invoked on the non-tiled path.

        Returns:
            Decoded video as returned by the regular decode or by
            ``decode_with_tiling``.
        """
        if tiling_config is None:
            tiling_config = TilingConfig.default()

        _, _, f, h, w = sample.shape

        # Spatial scale is 32 (8x VAE upsample + 4x unpatchify); temporal scale is 8.
        spatial_scale = 32
        temporal_scale = 8

        # Tiling is needed when the latents exceed one tile in any configured dimension.
        needs_spatial_tiling = False
        if tiling_config.spatial_config is not None:
            s_cfg = tiling_config.spatial_config
            tile_size_latent = s_cfg.tile_size_in_pixels // spatial_scale
            needs_spatial_tiling = h > tile_size_latent or w > tile_size_latent

        needs_temporal_tiling = False
        if tiling_config.temporal_config is not None:
            t_cfg = tiling_config.temporal_config
            tile_size_latent = t_cfg.tile_size_in_frames // temporal_scale
            needs_temporal_tiling = f > tile_size_latent

        # Auto-enable chunked conv for modes where it helps (larger tiles).
        # Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks.
        use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")

        if not needs_spatial_tiling and not needs_temporal_tiling:
            # No tiling needed, use regular decode.
            return self(
                sample,
                causal=causal,
                timestep=timestep,
                debug=debug,
                chunked_conv=use_chunked_conv,
            )

        return decode_with_tiling(
            decoder_fn=self,
            latents=sample,
            tiling_config=tiling_config,
            spatial_scale=spatial_scale,
            temporal_scale=temporal_scale,
            causal=causal,
            timestep=timestep,
            chunked_conv=use_chunked_conv,
            on_frames_ready=on_frames_ready,
        )
# Backward-compatible alias: VideoDecoder is the same class object as LTX2VideoDecoder.
VideoDecoder = LTX2VideoDecoder

View File

@@ -0,0 +1,44 @@
"""Video VAE Encoder for LTX-2 Image-to-Video.
The encoder compresses input images/videos to latent representations.
Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation.
"""
import mlx.core as mx
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
def encode_image(
    image: mx.array,
    encoder: VideoEncoder,
) -> mx.array:
    """Encode an image (or batch of images) to latent space.

    Args:
        image: Channels-last image tensor of shape (H, W, 3) or (B, H, W, 3).
            Values above 1.0 anywhere in the tensor cause the whole tensor to
            be divided by 255 before rescaling to [-1, 1].
        encoder: Loaded VAE encoder; called with a (B, 3, 1, H, W) tensor.

    Returns:
        Whatever ``encoder`` returns for the single-frame input.
    """
    # A lone image gets a leading batch axis: (H, W, 3) -> (1, H, W, 3)
    batched = mx.expand_dims(image, axis=0) if image.ndim == 3 else image

    # Channels first: (B, H, W, C) -> (B, C, H, W)
    pixels = mx.transpose(batched, (0, 3, 1, 2))

    # Bring 0..255 data into 0..1, then map to [-1, 1].
    if pixels.max() > 1.0:
        pixels = pixels / 255.0
    pixels = pixels * 2.0 - 1.0

    # Insert a single-frame temporal axis: (B, C, H, W) -> (B, C, 1, H, W)
    single_frame = mx.expand_dims(pixels, axis=2)
    return encoder(single_frame)

View File

@@ -0,0 +1,125 @@
"""Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
    """Convert video to patches.

    Moves the pixels of each spatial (and temporal) patch into the channel
    dimension.

    Args:
        x: Input tensor of shape (B, C, F, H, W).
        patch_size_hw: Spatial patch size; must divide H and W.
        patch_size_t: Temporal patch size; must divide F.

    Returns:
        Patched tensor of shape
        (B, C * patch_size_t * patch_size_hw^2, F / patch_size_t,
        H / patch_size_hw, W / patch_size_hw).

    Raises:
        ValueError: If H or W is not divisible by ``patch_size_hw`` or F is
            not divisible by ``patch_size_t``.
    """
    b, c, f, h, w = x.shape

    # Explicit checks instead of assert, so they survive `python -O`.
    if h % patch_size_hw != 0 or w % patch_size_hw != 0:
        raise ValueError(
            f"Spatial size ({h}, {w}) is not divisible by patch_size_hw={patch_size_hw}"
        )
    if f % patch_size_t != 0:
        raise ValueError(
            f"Frame count {f} is not divisible by patch_size_t={patch_size_t}"
        )

    new_h = h // patch_size_hw
    new_w = w // patch_size_hw
    new_f = f // patch_size_t
    new_c = c * patch_size_hw * patch_size_hw * patch_size_t

    # Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
    x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
    # Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
    # PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so pw comes before ph.
    x = mx.transpose(x, (0, 1, 3, 7, 5, 2, 4, 6))
    # Reshape: (B, C, pt, pw, ph, F', H', W') -> (B, C*pt*pw*ph, F', H', W')
    x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
    return x
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
    """Convert patches back to video.

    Moves patch pixels from the channel dimension back to the temporal and
    spatial dimensions. Matches the PyTorch einops pattern
    "b (c p r q) f h w -> b c (f p) (h q) (w r)", where p is the temporal
    patch, r the width patch and q the height patch.

    Args:
        x: Patched tensor of shape (B, C * patch_size_t * patch_size_hw^2, F, H, W).
        patch_size_hw: Spatial patch size.
        patch_size_t: Temporal patch size.

    Returns:
        Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw).
    """
    batch, packed_channels, frames, height, width = x.shape
    pixels_per_patch = patch_size_hw * patch_size_hw * patch_size_t
    channels = packed_channels // pixels_per_patch

    # Split the packed channel axis into (C, pt, pr, pq); the packed order is
    # (c, temporal, width, height).
    split_shape = (
        batch,
        channels,
        patch_size_t,
        patch_size_hw,
        patch_size_hw,
        frames,
        height,
        width,
    )
    split = mx.reshape(x, split_shape)

    # (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
    interleaved = mx.transpose(split, (0, 1, 5, 2, 6, 4, 7, 3))

    # Merge each patch axis into its neighbouring temporal/spatial axis.
    merged_shape = (
        batch,
        channels,
        frames * patch_size_t,
        height * patch_size_hw,
        width * patch_size_hw,
    )
    return mx.reshape(interleaved, merged_shape)
class PerChannelStatistics(nn.Module):
    """Per-channel mean and std used to normalize and denormalize 5D latents."""

    def __init__(self, latent_channels: int = 128):
        """Initialize with zero mean and unit std for ``latent_channels`` channels."""
        super().__init__()
        self.latent_channels = latent_channels
        # Per-channel mean and std, stored as 1D arrays of length latent_channels.
        self.mean = mx.zeros((latent_channels,))
        self.std = mx.ones((latent_channels,))

    def _broadcast_stats(self):
        """Return (mean, std) as float32 arrays shaped (1, C, 1, 1, 1)."""
        mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
        std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
        return mean, std

    def normalize(self, x: mx.array) -> mx.array:
        """Normalize latents using per-channel statistics.

        Args:
            x: Input tensor with channels on axis 1, broadcastable against
                (1, C, 1, 1, 1).

        Returns:
            ``(x - mean) / std`` cast back to the dtype of ``x``.
        """
        mean, std = self._broadcast_stats()
        normalized = (x - mean) / std
        return normalized.astype(x.dtype)

    def un_normalize(self, x: mx.array) -> mx.array:
        """Denormalize latents using per-channel statistics.

        Args:
            x: Normalized tensor with channels on axis 1, broadcastable
                against (1, C, 1, 1, 1).

        Returns:
            ``x * std + mean`` cast back to the dtype of ``x``.
        """
        mean, std = self._broadcast_stats()
        restored = x * std + mean
        return restored.astype(x.dtype)

View File

@@ -0,0 +1,172 @@
"""ResNet blocks for Video VAE."""
from enum import Enum
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.utils import PixelNorm
class NormLayerType(Enum):
    """Normalization layers that ``get_norm_layer`` can construct."""

    # Built as nn.GroupNorm.
    GROUP_NORM = "group_norm"
    # Built as PixelNorm.
    PIXEL_NORM = "pixel_norm"
def get_norm_layer(
    norm_type: NormLayerType,
    num_channels: int,
    num_groups: int = 32,
    eps: float = 1e-6,
) -> nn.Module:
    """Create the normalization layer selected by ``norm_type``.

    Args:
        norm_type: Which normalization to build.
        num_channels: Channel count; only used for group norm.
        num_groups: Number of groups; only used for group norm.
        eps: Epsilon passed to the created layer.

    Returns:
        An ``nn.GroupNorm`` or ``PixelNorm`` instance.

    Raises:
        ValueError: If ``norm_type`` is not a known NormLayerType member.
    """
    if norm_type == NormLayerType.PIXEL_NORM:
        return PixelNorm(eps=eps)
    if norm_type == NormLayerType.GROUP_NORM:
        return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
    raise ValueError(f"Unknown norm type: {norm_type}")
class ResnetBlock3D(nn.Module):
    """Residual block: two norm/SiLU/3x3x3-conv stages plus an optional 1x1x1 shortcut.

    The ``dims`` and ``timestep_conditioning`` arguments are accepted but not
    used by this implementation.
    """

    def __init__(
        self,
        dims: int,
        in_channels: int,
        out_channels: Optional[int] = None,
        eps: float = 1e-6,
        groups: int = 32,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        """Create the block.

        Args:
            dims: Unused.
            in_channels: Number of input channels.
            out_channels: Number of output channels; defaults to ``in_channels``.
            eps: Epsilon for the normalization layers.
            groups: Group count for group normalization.
            norm_layer: Which normalization to use in both stages.
            inject_noise: Whether noise may be added after the first convolution.
            timestep_conditioning: Unused.
            spatial_padding_mode: Spatial padding mode of all convolutions.
        """
        super().__init__()
        out_channels = out_channels or in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise

        def build_conv(conv_in: int, kernel: int, pad: int) -> CausalConv3d:
            # All convolutions in this block share stride, output width and padding mode.
            return CausalConv3d(
                in_channels=conv_in,
                out_channels=out_channels,
                kernel_size=kernel,
                stride=1,
                padding=pad,
                spatial_padding_mode=spatial_padding_mode,
            )

        # Stage 1
        self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
        self.conv1 = build_conv(in_channels, 3, 1)

        # Stage 2
        self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
        self.conv2 = build_conv(out_channels, 3, 1)

        # A pointwise conv matches the skip path to the output width when needed.
        self.shortcut = build_conv(in_channels, 1, 0) if in_channels != out_channels else None

        self.act = nn.SiLU()

    def __call__(
        self,
        x: mx.array,
        causal: bool = True,
        generator: Optional[int] = None,
    ) -> mx.array:
        """Run the block.

        Args:
            x: Input tensor.
            causal: Forwarded to every convolution.
            generator: Noise is only injected when this is not None; its value
                is not otherwise used.

        Returns:
            The stage output added to the (possibly projected) input.
        """
        skip = x

        hidden = self.conv1(self.act(self.norm1(x)), causal=causal)

        if self.inject_noise and generator is not None:
            hidden = hidden + mx.random.normal(hidden.shape) * 0.01

        hidden = self.conv2(self.act(self.norm2(hidden)), causal=causal)

        if self.shortcut is not None:
            skip = self.shortcut(skip, causal=causal)
        return hidden + skip
class UNetMidBlock3D(nn.Module):
    """A stack of ``num_layers`` channel-preserving ``ResnetBlock3D`` layers.

    Args:
        dims: Forwarded to each ``ResnetBlock3D``.
        in_channels: Channel count of every layer (input and output).
        num_layers: Number of residual blocks in the stack.
        resnet_eps: Normalization epsilon for each block.
        resnet_groups: Group count for each block's group normalization.
        norm_layer: Which normalization layer each block uses.
        inject_noise: Forwarded to each block.
        timestep_conditioning: Forwarded to each block.
        attention_head_dim: Accepted for signature compatibility; not used here.
        spatial_padding_mode: Padding mode forwarded to each block.
    """

    def __init__(
        self,
        dims: int,
        in_channels: int,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        attention_head_dim: Optional[int] = None,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()
        self.num_layers = num_layers

        # Blocks live in a dict keyed by layer index (MLX parameter tracking);
        # the attribute is named res_blocks to match PyTorch weight keys.
        blocks = {}
        for layer_idx in range(num_layers):
            blocks[layer_idx] = ResnetBlock3D(
                dims=dims,
                in_channels=in_channels,
                out_channels=in_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                norm_layer=norm_layer,
                inject_noise=inject_noise,
                timestep_conditioning=timestep_conditioning,
                spatial_padding_mode=spatial_padding_mode,
            )
        self.res_blocks = blocks

    def __call__(
        self,
        x: mx.array,
        causal: bool = True,
        timestep: Optional[mx.array] = None,
        generator: Optional[int] = None,
    ) -> mx.array:
        """Run the input through every residual block in insertion order.

        Args:
            x: Input tensor.
            causal: Forwarded to each block.
            timestep: Accepted for signature compatibility; not used here.
            generator: Forwarded to each block.

        Returns:
            Output of the last residual block.
        """
        for block in self.res_blocks.values():
            x = block(x, causal=causal, generator=generator)
        return x

View File

@@ -0,0 +1,275 @@
"""Sampling operations for Video VAE (upsampling/downsampling)."""
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
class SpaceToDepthDownsample(nn.Module):
    """Space-to-depth downsampling with 3x3 conv and skip connection.

    PyTorch-compatible implementation:
    1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
    2. Space-to-depth on conv output: channels * prod(stride)
    3. Space-to-depth on input with group averaging for skip connection
    4. Add skip connection

    Args:
        dims: Stored on the module; not otherwise used here.
        in_channels: Input channel count.
        out_channels: Output channel count after space-to-depth.
        stride: Downsampling factor per (time, height, width) axis; an int is
            applied to all three axes.
        spatial_padding_mode: Padding mode forwarded to the convolution.
    """

    def __init__(
        self,
        dims: int,
        in_channels: int,
        out_channels: int,
        stride: Union[int, Tuple[int, int, int]],
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()

        if isinstance(stride, int):
            stride = (stride, stride, stride)

        self.stride = stride
        self.dims = dims
        self.out_channels = out_channels

        # Space-to-depth multiplies the channel count by prod(stride).
        factor = stride[0] * stride[1] * stride[2]
        # Number of rearranged input channels averaged into one skip channel.
        self.group_size = in_channels * factor // out_channels

        # 3x3 convolution (not 1x1); its output reaches out_channels only after space-to-depth.
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels // factor,
            kernel_size=3,
            stride=1,
            padding=1,
            spatial_padding_mode=spatial_padding_mode,
        )

    def _space_to_depth(self, x: mx.array) -> mx.array:
        """Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w"""
        batch, channels, depth, height, width = x.shape
        st, sh, sw = self.stride
        out_d, out_h, out_w = depth // st, height // sh, width // sw

        # Split each strided axis into (outer, factor).
        x = mx.reshape(x, (batch, channels, out_d, st, out_h, sh, out_w, sw))
        # (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
        x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
        # Fold the three factors into the channel axis.
        return mx.reshape(x, (batch, channels * st * sh * sw, out_d, out_h, out_w))

    def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
        """Downsample ``x`` of shape (B, C, D, H, W).

        Args:
            x: Input tensor.
            causal: Forwarded to the convolution.

        Returns:
            Conv branch plus group-averaged skip branch, both after space-to-depth.
        """
        st, sh, sw = self.stride

        if st == 2:
            # Temporal padding: duplicate the first frame.
            x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)

        _, _, depth, height, width = x.shape

        # Trailing padding so that every axis is divisible by its stride.
        pad_d = -depth % st
        pad_h = -height % sh
        pad_w = -width % sw
        if any((pad_d, pad_h, pad_w)):
            x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])

        # Skip branch: space-to-depth on the input, then mean over channel groups.
        skip = self._space_to_depth(x)
        sb, _, sd, s_h, s_w = skip.shape
        skip = mx.reshape(skip, (sb, self.out_channels, self.group_size, sd, s_h, s_w))
        skip = mx.mean(skip, axis=2)  # (b, out_channels, d, h, w)

        # Conv branch: convolution followed by space-to-depth.
        conv_out = self._space_to_depth(self.conv(x, causal=causal))

        return conv_out + skip
class DepthToSpaceUpsample(nn.Module):
    """Upsampling by a 3x3x3 causal conv followed by depth-to-space.

    The convolution expands the channels to ``out_channels * prod(stride)``; the
    extra channels are then unpacked into the temporal/spatial axes. An optional
    residual path applies the same unpacking directly to the input.

    Args:
        dims: Stored on the module; not otherwise used here.
        in_channels: Input channel count.
        stride: Upsampling factor per (time, height, width) axis; an int is
            applied to all three axes.
        residual: Whether to add the depth-to-space of the input to the output.
        out_channels_reduction_factor: ``out_channels = in_channels // factor``.
        spatial_padding_mode: Padding mode forwarded to the convolution.
    """

    def __init__(
        self,
        dims: int,
        in_channels: int,
        stride: Union[int, Tuple[int, int, int]],
        residual: bool = False,
        out_channels_reduction_factor: int = 1,
        spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
    ):
        super().__init__()

        if isinstance(stride, int):
            stride = (stride, stride, stride)

        self.stride = stride
        self.dims = dims
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

        factor = stride[0] * stride[1] * stride[2]
        self.out_channels = in_channels // out_channels_reduction_factor

        # 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
        self.conv = CausalConv3d(
            in_channels=in_channels,
            out_channels=self.out_channels * factor,
            kernel_size=3,
            stride=1,
            padding=1,
            spatial_padding_mode=spatial_padding_mode,
        )

    def _depth_to_space(self, x: mx.array) -> mx.array:
        """Rearrange: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"""
        batch, packed, depth, height, width = x.shape
        st, sh, sw = self.stride
        channels = packed // (st * sh * sw)

        # (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
        x = mx.reshape(x, (batch, channels, st, sh, sw, depth, height, width))
        # (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
        x = mx.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
        # (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
        return mx.reshape(x, (batch, channels, depth * st, height * sh, width * sw))

    def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
        """Upsample ``x`` of shape (B, C, D, H, W).

        Args:
            x: Input tensor.
            causal: Forwarded to the convolution.
            chunked_conv: Use the temporally chunked path when the input has
                more than 4 frames, to reduce peak memory.

        Returns:
            Upsampled tensor; with temporal upsampling the first output frame is dropped.
        """
        _, _, depth, _, _ = x.shape
        st, sh, sw = self.stride

        x_residual = None
        if self.residual:
            # Treat channels as spatial factors:
            # "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
            x_residual = self._depth_to_space(x)
            # Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
            num_repeat = (st * sh * sw) // self.out_channels_reduction_factor
            x_residual = mx.tile(x_residual, (1, num_repeat, 1, 1, 1))
            if st > 1:
                # Drop the first temporal frame, as on the main path below.
                x_residual = x_residual[:, :, 1:, :, :]

        if chunked_conv and depth > 4:
            # Chunked mode for large tensors to reduce peak memory
            x = self._chunked_conv_depth_to_space(x, causal)
        else:
            x = self._depth_to_space(self.conv(x, causal=causal))

        if st > 1:
            # Remove first frame for causal temporal upsampling
            x = x[:, :, 1:, :, :]

        if x_residual is not None:
            x = x + x_residual

        return x

    def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
        """Chunked conv + depth_to_space that processes in temporal chunks.

        This reduces peak memory by avoiding the full high-channel intermediate tensor.
        Instead of materializing (B, 4096, D, H, W), we process temporal chunks and
        immediately apply depth_to_space.

        Args:
            x: Input tensor of shape (B, C, D, H, W)
            causal: Whether to use causal convolutions

        Returns:
            Output tensor after conv + depth_to_space
        """
        _, _, depth, _, _ = x.shape
        st, _, _ = self.stride

        chunk_size = 4  # input frames processed per chunk
        kernel_t = 3  # temporal kernel size
        halo = kernel_t - 1

        # Extra input frames each chunk needs so its kept outputs see real
        # neighbours: all on the left for causal conv, split evenly otherwise.
        if causal:
            pad_start, pad_end = halo, 0
        else:
            pad_start, pad_end = halo // 2, halo // 2

        outputs = []
        for t_pos in range(0, depth, chunk_size):
            t_end = min(t_pos + chunk_size, depth)

            # Input window including the context frames, clipped to the tensor.
            in_start = max(0, t_pos - pad_start)
            in_end = min(depth, t_end + pad_end)

            chunk_out = self._depth_to_space(self.conv(x[:, :, in_start:in_end, :, :], causal=causal))

            # Keep only the outputs of frames [t_pos, t_end); each input frame
            # produces st output frames.
            out_start = (t_pos - in_start) * st
            out_end = out_start + (t_end - t_pos) * st
            chunk_out = chunk_out[:, :, out_start:out_end, :, :]

            outputs.append(chunk_out)
            # Evaluate to free intermediate memory
            mx.eval(chunk_out)

        if len(outputs) == 1:
            return outputs[0]
        return mx.concatenate(outputs, axis=2)

View File

@@ -0,0 +1,492 @@
"""VAE Tiling Configuration for decoding large videos.
Implements spatial and temporal tiling with trapezoidal blending masks
to decode large videos without running out of memory.
Default configuration (from PyTorch):
- Spatial: 512px tiles with 64px overlap
- Temporal: 64 frames with 24 frame overlap
"""
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import mlx.core as mx
def compute_trapezoidal_mask_1d(
    length: int,
    ramp_left: int,
    ramp_right: int,
    left_starts_from_0: bool = False,
) -> mx.array:
    """Generate a 1D trapezoidal blending mask with linear ramps.

    Args:
        length: Output length of the mask.
        ramp_left: Fade-in length on the left (clamped to [0, length]).
        ramp_right: Fade-out length on the right (clamped to [0, length]).
        left_starts_from_0: Whether the ramp starts from 0 or first non-zero value.
            Useful for temporal tiles where the first tile is causal.

    Returns:
        A 1D array of shape (length,) with values in [0, 1].

    Raises:
        ValueError: If ``length`` is not positive.
    """
    if length <= 0:
        raise ValueError("Mask length must be positive.")

    ramp_left = max(0, min(ramp_left, length))
    ramp_right = max(0, min(ramp_right, length))

    weights = [1.0] * length

    if ramp_left > 0:
        if left_starts_from_0:
            # linspace(0, 1, ramp_left + 1)[:-1]: starts at exactly 0
            fade_in = [step / ramp_left for step in range(ramp_left)]
        else:
            # linspace(0, 1, ramp_left + 2)[1:-1]: both endpoints excluded
            fade_in = [step / (ramp_left + 1) for step in range(1, ramp_left + 1)]
        for pos, factor in enumerate(fade_in):
            weights[pos] *= factor

    if ramp_right > 0:
        # linspace(1, 0, ramp_right + 2)[1:-1]; multiplied in so that an
        # overlapping left ramp is combined rather than overwritten.
        denom = ramp_right + 1
        tail_start = length - ramp_right
        for offset in range(ramp_right):
            weights[tail_start + offset] *= (ramp_right - offset) / denom

    return mx.clip(mx.array(weights), 0, 1)
@dataclass(frozen=True)
class SpatialTilingConfig:
    """Configuration for dividing each frame into spatial tiles with optional overlap.

    Raises:
        ValueError: If the tile size is below 64 or not a multiple of 32, or if
            the overlap is negative, not a multiple of 32, or not smaller than
            the tile size.
    """

    # Edge length of a square tile, in output pixels.
    tile_size_in_pixels: int
    # Number of pixels shared by neighbouring tiles.
    tile_overlap_in_pixels: int = 0

    def __post_init__(self) -> None:
        if self.tile_size_in_pixels < 64:
            raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
        if self.tile_size_in_pixels % 32 != 0:
            raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
        # A negative multiple of 32 passes the divisibility check below, but would
        # make the tile stride larger than the tile and leave gaps between tiles.
        if self.tile_overlap_in_pixels < 0:
            raise ValueError(f"tile_overlap_in_pixels must be non-negative, got {self.tile_overlap_in_pixels}")
        if self.tile_overlap_in_pixels % 32 != 0:
            raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
        if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
            )
@dataclass(frozen=True)
class TemporalTilingConfig:
    """Configuration for dividing a video into temporal tiles.

    Raises:
        ValueError: If the tile size is below 16 or not a multiple of 8, or if
            the overlap is negative, not a multiple of 8, or not smaller than
            the tile size.
    """

    # Length of a temporal tile, in output frames.
    tile_size_in_frames: int
    # Number of frames shared by consecutive tiles.
    tile_overlap_in_frames: int = 0

    def __post_init__(self) -> None:
        if self.tile_size_in_frames < 16:
            raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
        if self.tile_size_in_frames % 8 != 0:
            raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
        # A negative multiple of 8 passes the divisibility check below, but would
        # make the tile stride larger than the tile and leave gaps between tiles.
        if self.tile_overlap_in_frames < 0:
            raise ValueError(f"tile_overlap_in_frames must be non-negative, got {self.tile_overlap_in_frames}")
        if self.tile_overlap_in_frames % 8 != 0:
            raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
        if self.tile_overlap_in_frames >= self.tile_size_in_frames:
            raise ValueError(
                f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
            )
@dataclass(frozen=True)
class TilingConfig:
    """Configuration for splitting video into tiles with optional overlap.

    Either axis can be left untiled by setting its config to None.
    """

    # Spatial tiling parameters, or None for no spatial tiling.
    spatial_config: Optional[SpatialTilingConfig] = None
    # Temporal tiling parameters, or None for no temporal tiling.
    temporal_config: Optional[TemporalTilingConfig] = None

    @classmethod
    def default(cls) -> "TilingConfig":
        """Default tiling: 512px spatial, 64 frame temporal."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
        """Spatial tiling only (for short videos with large resolution)."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap)
        return cls(spatial_config=spatial, temporal_config=None)

    @classmethod
    def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig":
        """Temporal tiling only (for long videos with small resolution)."""
        temporal = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
        return cls(spatial_config=None, temporal_config=temporal)

    @classmethod
    def aggressive(cls) -> "TilingConfig":
        """Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def conservative(cls) -> "TilingConfig":
        """Conservative tiling (larger tiles, less memory savings but faster)."""
        spatial = SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64)
        temporal = TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24)
        return cls(spatial_config=spatial, temporal_config=temporal)

    @classmethod
    def auto(
        cls,
        height: int,
        width: int,
        num_frames: int,
        spatial_threshold: int = 512,
        temporal_threshold: int = 65,
    ) -> Optional["TilingConfig"]:
        """Automatically determine tiling config based on video dimensions.

        Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
        enough context for CausalConv3d and sufficient overlap for clean blending.

        Args:
            height: Video height in pixels
            width: Video width in pixels
            num_frames: Number of video frames
            spatial_threshold: Enable spatial tiling if either dimension exceeds this
            temporal_threshold: Enable temporal tiling if frames exceed this

        Returns:
            TilingConfig if tiling is needed, None otherwise
        """
        needs_spatial = height > spatial_threshold or width > spatial_threshold
        needs_temporal = num_frames > temporal_threshold

        if not (needs_spatial or needs_temporal):
            return None

        # Same defaults as PyTorch (512px spatial, 64f temporal). Smaller tiles
        # cause quality degradation because CausalConv3d needs sufficient
        # temporal context and overlap for clean blending.
        spatial = (
            SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
            if needs_spatial
            else None
        )
        temporal = (
            TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
            if needs_temporal
            else None
        )
        return cls(spatial_config=spatial, temporal_config=temporal)
@dataclass
class DimensionIntervals:
    """Tile intervals along one dimension, stored as four parallel lists.

    Entry ``i`` of each list describes tile ``i``.
    """

    # Inclusive start index of each tile.
    starts: List[int]
    # Exclusive end index of each tile.
    ends: List[int]
    # Fade-in length at the start of each tile (0 for no ramp).
    left_ramps: List[int]
    # Fade-out length at the end of each tile (0 for no ramp).
    right_ramps: List[int]
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
    """Split a spatial dimension into overlapping intervals.

    Args:
        size: Tile length.
        overlap: Number of elements shared by neighbouring tiles.
        dimension_size: Total length of the dimension.

    Returns:
        A single un-ramped interval when the dimension fits in one tile.
        Otherwise tiles spaced ``size - overlap`` apart, the last one clipped to
        ``dimension_size``, with ramps of length ``overlap`` on every interior edge.
    """
    if dimension_size <= size:
        return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])

    step = size - overlap
    # Ceiling of (dimension_size - overlap) / step
    tile_count = (dimension_size + size - 2 * overlap - 1) // step

    starts = [index * step for index in range(tile_count)]
    ends = [begin + size for begin in starts]
    ends[-1] = dimension_size  # the last tile stops at the end of the dimension

    interior_edges = [overlap] * (tile_count - 1)
    return DimensionIntervals(
        starts=starts,
        ends=ends,
        left_ramps=[0] + interior_edges,
        right_ramps=interior_edges + [0],
    )
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
    """Split a temporal dimension into intervals with causal adjustment.

    Starts from the spatial split. Every tile after the first then begins one
    element earlier and gets a left ramp one element longer; ends and right
    ramps are unchanged.

    Args:
        size: Tile length.
        overlap: Number of elements shared by consecutive tiles.
        dimension_size: Total length of the dimension.

    Returns:
        The adjusted intervals, or a single un-ramped interval when the
        dimension fits in one tile.
    """
    if dimension_size <= size:
        return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])

    base = split_in_spatial(size, overlap, dimension_size)

    # starts[1:] -= 1, left_ramps[1:] += 1
    shifted_starts = base.starts[:1] + [begin - 1 for begin in base.starts[1:]]
    longer_left_ramps = base.left_ramps[:1] + [ramp + 1 for ramp in base.left_ramps[1:]]

    return DimensionIntervals(
        starts=shifted_starts,
        ends=base.ends,
        left_ramps=longer_left_ramps,
        right_ramps=base.right_ramps,
    )
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
    """Map temporal latent interval to output coordinates and mask.

    Args:
        begin: First latent frame of the tile (inclusive).
        end: Last latent frame of the tile (exclusive).
        left_ramp: Fade-in length in latent frames (0 for none).
        right_ramp: Fade-out length in latent frames.
        scale: Temporal upsampling factor.

    Returns:
        The output-frame slice ``[begin * scale, 1 + (end - 1) * scale)`` and a
        blending mask of the same length whose left ramp starts from 0.
    """
    out_slice = slice(begin * scale, 1 + (end - 1) * scale)

    if left_ramp > 0:
        fade_in = 1 + (left_ramp - 1) * scale
    else:
        fade_in = 0
    fade_out = right_ramp * scale

    mask = compute_trapezoidal_mask_1d(out_slice.stop - out_slice.start, fade_in, fade_out, True)
    return out_slice, mask
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
    """Map spatial latent interval to output coordinates and mask.

    Args:
        begin: First latent index of the tile (inclusive).
        end: Last latent index of the tile (exclusive).
        left_ramp: Fade-in length in latent units.
        right_ramp: Fade-out length in latent units.
        scale: Spatial upsampling factor; all four values are multiplied by it.

    Returns:
        The output-pixel slice and a blending mask of the same length whose
        left ramp starts from the first non-zero value.
    """
    out_slice = slice(begin * scale, end * scale)
    mask = compute_trapezoidal_mask_1d(
        out_slice.stop - out_slice.start,
        left_ramp * scale,
        right_ramp * scale,
        False,
    )
    return out_slice, mask
def decode_with_tiling(
    decoder_fn,
    latents: mx.array,
    tiling_config: TilingConfig,
    spatial_scale: int = 32,
    temporal_scale: int = 8,
    causal: bool = False,
    timestep: Optional[mx.array] = None,
    chunked_conv: bool = False,
    on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
) -> mx.array:
    """Decode latents using tiling to reduce memory usage.

    The latents are split into overlapping temporal/height/width tiles. Each tile
    is decoded on its own, multiplied by a trapezoidal blend mask and accumulated
    into a float32 output buffer, which is finally divided by the accumulated
    mask weights.

    Args:
        decoder_fn: Decoder function to call for each tile. It is invoked as
            ``decoder_fn(tile, causal=..., timestep=..., debug=False, chunked_conv=...)``.
        latents: Input latents of shape (B, C, F, H, W).
        tiling_config: Tiling configuration.
        spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
        temporal_scale: Temporal scale factor (8 for LTX VAE).
        causal: Whether to use causal convolutions.
        timestep: Optional timestep for conditioning.
        chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
        on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
            frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
            start_idx: Starting frame index in the full video.

    Returns:
        Decoded video of shape (B, 3, 1 + (F - 1) * temporal_scale,
        H * spatial_scale, W * spatial_scale), cast to the dtype of ``latents``.
    """
    import gc

    b, _, f_latent, h_latent, w_latent = latents.shape

    # Compute output shape
    out_f = 1 + (f_latent - 1) * temporal_scale
    out_h = h_latent * spatial_scale
    out_w = w_latent * spatial_scale

    # Get tile size and overlap in latent space
    if tiling_config.spatial_config is not None:
        s_cfg = tiling_config.spatial_config
        spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
        spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
    else:
        spatial_tile_size = max(h_latent, w_latent)
        spatial_overlap = 0

    if tiling_config.temporal_config is not None:
        t_cfg = tiling_config.temporal_config
        temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
        temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
    else:
        temporal_tile_size = f_latent
        temporal_overlap = 0

    # Compute intervals for each dimension
    temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
    height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
    width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)

    num_t_tiles = len(temporal_intervals.starts)
    num_h_tiles = len(height_intervals.starts)
    num_w_tiles = len(width_intervals.starts)

    # Initialize output and weight accumulator
    # Use float32 for accumulation to avoid precision issues
    output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
    weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
    mx.eval(output, weights)

    # Output frames [0, emitted_frames) have already been handed to on_frames_ready.
    # Kept in a local so that concurrent calls, or a call that raised part-way
    # through, cannot leak streaming progress into another call.
    emitted_frames = 0

    tile_idx = 0
    for t_idx in range(num_t_tiles):
        t_start = temporal_intervals.starts[t_idx]
        t_end = temporal_intervals.ends[t_idx]
        t_left = temporal_intervals.left_ramps[t_idx]
        t_right = temporal_intervals.right_ramps[t_idx]

        # Map temporal coordinates
        out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)

        for h_idx in range(num_h_tiles):
            h_start = height_intervals.starts[h_idx]
            h_end = height_intervals.ends[h_idx]
            h_left = height_intervals.left_ramps[h_idx]
            h_right = height_intervals.right_ramps[h_idx]

            # Map height coordinates
            out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)

            for w_idx in range(num_w_tiles):
                w_start = width_intervals.starts[w_idx]
                w_end = width_intervals.ends[w_idx]
                w_left = width_intervals.left_ramps[w_idx]
                w_right = width_intervals.right_ramps[w_idx]

                # Map width coordinates
                out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)

                # Extract tile latents (small slice)
                tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]

                # Decode tile
                tile_output = decoder_fn(
                    tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv
                )
                mx.eval(tile_output)

                # Clear tile_latents reference
                del tile_latents

                # Get actual decoded dimensions
                _, _, decoded_t, decoded_h, decoded_w = tile_output.shape
                expected_t = out_t_slice.stop - out_t_slice.start
                expected_h = out_h_slice.stop - out_h_slice.start
                expected_w = out_w_slice.stop - out_w_slice.start

                # Handle potential size mismatches (use minimum)
                actual_t = min(decoded_t, expected_t)
                actual_h = min(decoded_h, expected_h)
                actual_w = min(decoded_w, expected_w)

                # Build blend mask; each 1-D mask has the expected length, so
                # slicing to the actual length is a no-op unless the tile came
                # back smaller than expected.
                t_mask_slice = t_mask[:actual_t]
                h_mask_slice = h_mask[:actual_h]
                w_mask_slice = w_mask[:actual_w]

                blend_mask = (
                    t_mask_slice.reshape(1, 1, -1, 1, 1)
                    * h_mask_slice.reshape(1, 1, 1, -1, 1)
                    * w_mask_slice.reshape(1, 1, 1, 1, -1)
                )

                # Slice tile output to match
                tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)

                # Clear full tile_output
                del tile_output

                # Compute output coordinates
                t_out_start = out_t_slice.start
                t_out_end = t_out_start + actual_t
                h_out_start = out_h_slice.start
                h_out_end = h_out_start + actual_h
                w_out_start = out_w_slice.start
                w_out_end = w_out_start + actual_w

                # Weighted accumulation
                weighted_tile = tile_output_slice * blend_mask

                # Update output using slice assignment (MLX supports this)
                output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
                    output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
                )
                weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
                    weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
                )

                # Force evaluation to free memory
                mx.eval(output, weights)

                # Clean up tile-specific arrays
                del tile_output_slice, weighted_tile, blend_mask
                del t_mask_slice, h_mask_slice, w_mask_slice

                tile_idx += 1

                # Periodic garbage collection and cache clearing
                if tile_idx % 4 == 0:
                    gc.collect()
                    try:
                        mx.clear_cache()
                    except Exception:
                        pass  # Best effort: may not be available on all platforms

        # After completing all spatial tiles for this temporal tile,
        # check if any frames are now finalized (no future tiles will contribute)
        if on_frames_ready is not None and num_t_tiles > 1 and t_idx < num_t_tiles - 1:
            # Frames before the start of the next tile's output region are finalized
            next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
            if next_tile_start_latent == 0:
                next_tile_start_out = 0
            else:
                next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale

            if next_tile_start_out > emitted_frames:
                # Normalize and emit frames [emitted_frames, next_tile_start_out)
                finalized_weights = weights[:, :, emitted_frames:next_tile_start_out, :, :]
                finalized_weights = mx.maximum(finalized_weights, 1e-8)
                finalized_output = output[:, :, emitted_frames:next_tile_start_out, :, :] / finalized_weights
                finalized_output = finalized_output.astype(latents.dtype)
                mx.eval(finalized_output)

                on_frames_ready(finalized_output, emitted_frames)
                emitted_frames = next_tile_start_out

                del finalized_output, finalized_weights
                gc.collect()

    # Normalize by weights
    weights = mx.maximum(weights, 1e-8)
    output = output / weights
    mx.eval(output)

    # Emit remaining frames if callback provided
    if on_frames_ready is not None and emitted_frames < out_f:
        remaining_output = output[:, :, emitted_frames:, :, :].astype(latents.dtype)
        mx.eval(remaining_output)
        on_frames_ready(remaining_output, emitted_frames)
        del remaining_output

    # Clean up weights
    del weights
    gc.collect()

    # Convert back to original dtype if needed
    return output.astype(latents.dtype)

View File

@@ -0,0 +1,597 @@
"""Video VAE Encoder and Decoder for LTX-2."""
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample,
SpaceToDepthDownsample,
)
from mlx_video.utils import PixelNorm
class LogVarianceType(Enum):
    """Log variance mode for VAE."""

    # The four recognised mode names; each member's value is its config string.
    PER_CHANNEL = "per_channel"
    UNIFORM = "uniform"
    CONSTANT = "constant"
    NONE = "none"
def _make_encoder_block(
    block_name: str,
    block_config: Dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    """Create an encoder block.

    Supported block names:
        - ``res_x``: ``UNetMidBlock3D`` with ``block_config["num_layers"]`` layers.
        - ``res_x_y``: ``ResnetBlock3D`` that multiplies the channel count.
        - ``compress_time`` / ``compress_space`` / ``compress_all``: strided
          ``CausalConv3d`` that keeps the channel count.
        - ``compress_all_x_y``: strided ``CausalConv3d`` that multiplies the channel count.
        - ``compress_all_res`` / ``compress_space_res`` / ``compress_time_res``:
          ``SpaceToDepthDownsample`` that multiplies the channel count.

    The channel multiplier is ``block_config.get("multiplier", 2)``.

    Args:
        block_name: Type of block
        block_config: Block configuration
        in_channels: Input channels
        convolution_dimensions: Number of dimensions
        norm_layer: Normalization layer type
        norm_num_groups: Number of groups for group norm
        spatial_padding_mode: Padding mode

    Returns:
        Tuple of (block, output_channels)

    Raises:
        ValueError: If ``block_name`` is not one of the supported names.
    """
    # Strided convolutions that keep the channel count.
    conv_strides = {
        "compress_time": (2, 1, 1),
        "compress_space": (1, 2, 2),
        "compress_all": (2, 2, 2),
    }
    # Space-to-depth downsamplers; these always widen the channels.
    residual_strides = {
        "compress_all_res": (2, 2, 2),
        "compress_space_res": (1, 2, 2),
        "compress_time_res": (2, 1, 1),
    }

    def strided_conv(stride: Tuple[int, int, int], channels_out: int) -> CausalConv3d:
        return CausalConv3d(
            in_channels=in_channels,
            out_channels=channels_out,
            kernel_size=3,
            stride=stride,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

    if block_name == "res_x":
        mid_block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
        return mid_block, in_channels

    if block_name == "res_x_y":
        widened = in_channels * block_config.get("multiplier", 2)
        res_block = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=widened,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            spatial_padding_mode=spatial_padding_mode,
        )
        return res_block, widened

    if block_name in conv_strides:
        return strided_conv(conv_strides[block_name], in_channels), in_channels

    if block_name == "compress_all_x_y":
        widened = in_channels * block_config.get("multiplier", 2)
        return strided_conv((2, 2, 2), widened), widened

    if block_name in residual_strides:
        widened = in_channels * block_config.get("multiplier", 2)
        downsample = SpaceToDepthDownsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=widened,
            stride=residual_strides[block_name],
            spatial_padding_mode=spatial_padding_mode,
        )
        return downsample, widened

    raise ValueError(f"Unknown encoder block: {block_name}")
def _make_decoder_block(
    block_name: str,
    block_config: Dict[str, Any],
    in_channels: int,
    convolution_dimensions: int,
    norm_layer: NormLayerType,
    timestep_conditioning: bool,
    norm_num_groups: int,
    spatial_padding_mode: PaddingModeType,
) -> Tuple[nn.Module, int]:
    """Create a decoder block.

    Supported block names:
        - ``res_x``: ``UNetMidBlock3D`` with ``block_config["num_layers"]`` layers.
        - ``res_x_y``: ``ResnetBlock3D`` dividing the channels by
          ``block_config.get("multiplier", 2)``; always built without timestep conditioning.
        - ``compress_time`` / ``compress_space``: ``DepthToSpaceUpsample`` that
          keeps the channel count.
        - ``compress_all``: ``DepthToSpaceUpsample`` with stride (2, 2, 2),
          dividing the channels by ``block_config.get("multiplier", 1)`` and
          using ``block_config.get("residual", False)``.

    Args:
        block_name: Type of block
        block_config: Block configuration
        in_channels: Input channels
        convolution_dimensions: Number of dimensions
        norm_layer: Normalization layer type
        timestep_conditioning: Forwarded to ``res_x`` blocks only
        norm_num_groups: Number of groups for group norm
        spatial_padding_mode: Padding mode

    Returns:
        Tuple of (block, output_channels)

    Raises:
        ValueError: If ``block_name`` is not one of the supported names.
    """
    # Upsamplers that keep the channel count and take no extra options.
    plain_upsample_strides = {
        "compress_time": (2, 1, 1),
        "compress_space": (1, 2, 2),
    }

    if block_name == "res_x":
        mid_block = UNetMidBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            num_layers=block_config["num_layers"],
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=timestep_conditioning,
            spatial_padding_mode=spatial_padding_mode,
        )
        return mid_block, in_channels

    if block_name == "res_x_y":
        narrowed = in_channels // block_config.get("multiplier", 2)
        res_block = ResnetBlock3D(
            dims=convolution_dimensions,
            in_channels=in_channels,
            out_channels=narrowed,
            eps=1e-6,
            groups=norm_num_groups,
            norm_layer=norm_layer,
            inject_noise=block_config.get("inject_noise", False),
            timestep_conditioning=False,
            spatial_padding_mode=spatial_padding_mode,
        )
        return res_block, narrowed

    if block_name in plain_upsample_strides:
        upsample = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=plain_upsample_strides[block_name],
            spatial_padding_mode=spatial_padding_mode,
        )
        return upsample, in_channels

    if block_name == "compress_all":
        reduction = block_config.get("multiplier", 1)
        upsample = DepthToSpaceUpsample(
            dims=convolution_dimensions,
            in_channels=in_channels,
            stride=(2, 2, 2),
            residual=block_config.get("residual", False),
            out_channels_reduction_factor=reduction,
            spatial_padding_mode=spatial_padding_mode,
        )
        return upsample, in_channels // reduction

    raise ValueError(f"Unknown decoder block: {block_name}")
class VideoEncoder(nn.Module):
    """Video encoder producing normalized latent means.

    Pipeline: patchify -> ``conv_in`` -> configured down blocks ->
    output norm + SiLU -> ``conv_out``. The leading ``latent_channels``
    channels of the result are normalized with per-channel statistics and
    returned.
    """

    _DEFAULT_NORM_NUM_GROUPS = 32

    def __init__(self, config: "VideoEncoderModelConfig"):
        """Initialize VideoEncoder from config.

        Args:
            config: VideoEncoderModelConfig with encoder parameters

        Raises:
            ValueError: If ``config.norm_layer`` is neither
                ``NormLayerType.GROUP_NORM`` nor ``NormLayerType.PIXEL_NORM``.
        """
        super().__init__()

        self.patch_size = config.patch_size
        self.norm_layer = config.norm_layer
        self.latent_channels = config.out_channels
        self.latent_log_var = config.latent_log_var
        self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

        encoder_blocks = config.encoder_blocks or []
        encoder_spatial_padding_mode = config.encoder_spatial_padding_mode

        # Per-channel statistics for normalizing latents
        self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)

        # After patchify, channels increase by patch_size^2
        in_channels = config.in_channels * config.patch_size ** 2
        feature_channels = config.out_channels

        # Initial convolution
        self.conv_in = CausalConv3d(
            in_channels=in_channels,
            out_channels=feature_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

        # Build encoder blocks
        # Use dict with int keys for MLX to track parameters (lists are NOT tracked)
        self.down_blocks = {}
        for idx, (block_name, block_params) in enumerate(encoder_blocks):
            # A bare int is shorthand for {"num_layers": <int>}.
            block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
            block, feature_channels = _make_encoder_block(
                block_name=block_name,
                block_config=block_config,
                in_channels=feature_channels,
                convolution_dimensions=config.convolution_dimensions,
                norm_layer=config.norm_layer,
                norm_num_groups=self._norm_num_groups,
                spatial_padding_mode=encoder_spatial_padding_mode,
            )
            self.down_blocks[idx] = block

        # Output normalization and convolution
        if config.norm_layer == NormLayerType.GROUP_NORM:
            self.conv_norm_out = nn.GroupNorm(
                num_groups=self._norm_num_groups,
                dims=feature_channels,
                eps=1e-6,
            )
        elif config.norm_layer == NormLayerType.PIXEL_NORM:
            self.conv_norm_out = PixelNorm()
        else:
            # Fail here rather than leaving conv_norm_out undefined, which
            # would only surface later as an attribute error in __call__.
            raise ValueError(f"Unsupported norm_layer: {config.norm_layer!r}")
        self.conv_act = nn.SiLU()

        # Calculate output convolution channels
        conv_out_channels = config.out_channels
        if config.latent_log_var == LogVarianceType.PER_CHANNEL:
            # One log-variance channel per mean channel.
            conv_out_channels *= 2
        elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
            # A single extra channel shared by all mean channels.
            conv_out_channels += 1

        self.conv_out = CausalConv3d(
            in_channels=feature_channels,
            out_channels=conv_out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=encoder_spatial_padding_mode,
        )

    def __call__(self, sample: mx.array) -> mx.array:
        """Encode video to latent representation.

        Args:
            sample: Input video of shape (B, C, F, H, W).
                F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)

        Returns:
            Normalized latent means of shape (B, 128, F', H', W')

        Raises:
            ValueError: If the frame count is not of the form 1 + 8*k.
        """
        # Validate frame count
        frames_count = sample.shape[2]
        if (frames_count - 1) % 8 != 0:
            raise ValueError(
                "Invalid number of frames: Encode input must have 1 + 8 * x frames "
                f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
            )

        # Initial patchify
        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample, causal=True)

        # Blocks are stored under consecutive integer keys in __init__, so
        # look them up by index to run them in construction order. Every
        # block type is invoked the same way.
        for i in range(len(self.down_blocks)):
            sample = self.down_blocks[i](sample, causal=True)

        # Output processing
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=True)

        # Handle log variance modes
        if self.latent_log_var == LogVarianceType.UNIFORM:
            # Last channel is a shared log-variance: repeat it once per mean channel.
            means = sample[:, :-1, ...]
            logvar = sample[:, -1:, ...]
            num_channels = means.shape[1]
            repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
            sample = mx.concatenate([means, repeated_logvar], axis=1)
        elif self.latent_log_var == LogVarianceType.CONSTANT:
            # Drop the extra channel and substitute a fixed, very negative log-variance.
            sample = sample[:, :-1, ...]
            approx_ln_0 = -30
            sample = mx.concatenate(
                [
                    sample,
                    mx.full_like(sample, approx_ln_0),
                ],
                axis=1,
            )

        # Keep only the mean channels and normalize them; the log-variance
        # channels are not returned.
        means = sample[:, :self.latent_channels, ...]
        return self.per_channel_statistics.normalize(means)

    def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
        """Sanitize VAE encoder weights from PyTorch format to MLX format.

        Args:
            weights: Mapping from checkpoint key to array.

        Returns:
            ``weights`` unchanged if it already contains the key
            ``"per_channel_statistics.mean"``. Otherwise a new dict that
            keeps only ``vae.encoder.*`` entries (prefix stripped) and the
            ``mean-of-means`` / ``std-of-means`` statistics (renamed to
            ``per_channel_statistics.mean`` / ``.std``), with convolution
            weights transposed to channels-last.
        """
        # Presence of the renamed statistics key marks an already-converted checkpoint.
        if "per_channel_statistics.mean" in weights:
            return weights

        sanitized = {}
        for key, value in weights.items():
            new_key = key

            if "position_ids" in key:
                continue

            # Only process VAE encoder weights
            if not key.startswith("vae."):
                continue

            # Handle per-channel statistics
            if "vae.per_channel_statistics" in key:
                if key == "vae.per_channel_statistics.mean-of-means":
                    new_key = "per_channel_statistics.mean"
                elif key == "vae.per_channel_statistics.std-of-means":
                    new_key = "per_channel_statistics.std"
                else:
                    continue
            elif key.startswith("vae.encoder."):
                new_key = key.replace("vae.encoder.", "")
            else:
                continue

            is_conv_weight = "conv" in new_key.lower() and "weight" in new_key

            # Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
            if is_conv_weight and value.ndim == 5:
                value = mx.transpose(value, (0, 2, 3, 4, 1))

            # Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
            if is_conv_weight and value.ndim == 4:
                value = mx.transpose(value, (0, 2, 3, 1))

            sanitized[new_key] = value

        return sanitized

    @classmethod
    def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
        """Load a pretrained VideoEncoder from a directory with weights and config.

        Args:
            model_path: Path to directory containing safetensors weights and
                config.json. A path to a single weights file is also
                accepted; a plain string is converted to ``Path``.

        Returns:
            Loaded VideoEncoder instance

        Raises:
            FileNotFoundError: If ``model_path`` contains no ``*.safetensors``
                files and is not itself a file.
        """
        import json

        from mlx_video.models.ltx_2.config import VideoEncoderModelConfig

        # Accept str as well as Path; the `/` and `.glob` calls below need a Path.
        model_path = Path(model_path)

        # Load config, falling back to the default configuration when absent.
        config_path = model_path / "config.json"
        if config_path.exists():
            with open(config_path, encoding="utf-8") as f:
                config_dict = json.load(f)
            config = VideoEncoderModelConfig(**config_dict)
        else:
            config = VideoEncoderModelConfig()

        # Load weights: merge all shards in sorted order, or treat the path
        # itself as a single weights file.
        weight_files = sorted(model_path.glob("*.safetensors"))
        if weight_files:
            weights = {}
            for wf in weight_files:
                weights.update(mx.load(str(wf)))
        elif model_path.is_file():
            weights = mx.load(str(model_path))
        else:
            raise FileNotFoundError(f"No safetensors files found in {model_path}")

        # Create model, sanitize and load weights (non-strict loading).
        model = cls(config)
        weights = model.sanitize(weights)
        model.load_weights(list(weights.items()), strict=False)

        return model
class VideoDecoder(nn.Module):
_DEFAULT_NORM_NUM_GROUPS = 32
def __init__(
    self,
    convolution_dimensions: int = 3,
    in_channels: int = 128,
    out_channels: int = 3,
    decoder_blocks: Optional[List[Tuple[str, Any]]] = None,
    patch_size: int = 4,
    norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
    causal: bool = False,
    timestep_conditioning: bool = False,
    decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
):
    """Initialize VideoDecoder.

    Args:
        convolution_dimensions: Number of dimensions
        in_channels: Input latent channels
        out_channels: Output channels (3 for RGB)
        decoder_blocks: List of (block_name, config) tuples; ``None`` is
            treated as an empty list. Blocks are instantiated in reverse
            list order.
        patch_size: Spatial patch size
        norm_layer: Normalization layer type
        causal: Stored as ``self.causal``. The convolution layers built
            here are always constructed with ``causal=True``.
        timestep_conditioning: Whether to use timestep conditioning
        decoder_spatial_padding_mode: Padding mode

    Raises:
        ValueError: If ``norm_layer`` is neither
            ``NormLayerType.GROUP_NORM`` nor ``NormLayerType.PIXEL_NORM``.
    """
    super().__init__()

    if decoder_blocks is None:
        decoder_blocks = []

    self.patch_size = patch_size
    # conv_out emits patch_size^2 values per output channel.
    out_channels = out_channels * patch_size ** 2
    self.causal = causal
    self.timestep_conditioning = timestep_conditioning
    self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS

    # Per-channel statistics for denormalizing latents
    self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)

    # Noise and timestep parameters
    self.decode_noise_scale = 0.025
    self.decode_timestep = 0.05

    # Compute initial feature channels: scale in_channels by the multiplier
    # of every channel-reducing block. This is a plain product, so the
    # iteration order does not matter.
    feature_channels = in_channels
    for block_name, block_params in decoder_blocks:
        block_config = block_params if isinstance(block_params, dict) else {}
        if block_name == "res_x_y":
            feature_channels = feature_channels * block_config.get("multiplier", 2)
        elif block_name == "compress_all":
            feature_channels = feature_channels * block_config.get("multiplier", 1)

    # Initial convolution
    self.conv_in = CausalConv3d(
        in_channels=in_channels,
        out_channels=feature_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        causal=True,
        spatial_padding_mode=decoder_spatial_padding_mode,
    )

    # Build decoder blocks (reversed order)
    # Use dict with int keys for MLX to track parameters (lists are NOT tracked)
    self.up_blocks = {}
    for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
        # A bare int is shorthand for {"num_layers": <int>}.
        block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
        block, feature_channels = _make_decoder_block(
            block_name=block_name,
            block_config=block_config,
            in_channels=feature_channels,
            convolution_dimensions=convolution_dimensions,
            norm_layer=norm_layer,
            timestep_conditioning=timestep_conditioning,
            norm_num_groups=self._norm_num_groups,
            spatial_padding_mode=decoder_spatial_padding_mode,
        )
        self.up_blocks[idx] = block

    # Output normalization
    if norm_layer == NormLayerType.GROUP_NORM:
        self.conv_norm_out = nn.GroupNorm(
            num_groups=self._norm_num_groups,
            dims=feature_channels,
            eps=1e-6,
        )
    elif norm_layer == NormLayerType.PIXEL_NORM:
        self.conv_norm_out = PixelNorm()
    else:
        # Fail here rather than leaving conv_norm_out undefined, which
        # would only surface later as an attribute error in __call__.
        raise ValueError(f"Unsupported norm_layer: {norm_layer!r}")
    self.conv_act = nn.SiLU()

    self.conv_out = CausalConv3d(
        in_channels=feature_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        causal=True,
        spatial_padding_mode=decoder_spatial_padding_mode,
    )
def __call__(
    self,
    sample: mx.array,
    timestep: Optional[mx.array] = None,
) -> mx.array:
    """Decode latent to video.

    Args:
        sample: Latent tensor of shape (B, 128, F', H', W')
        timestep: Optional timestep for conditioning. When omitted and
            timestep conditioning is enabled, a per-batch default filled
            with ``self.decode_timestep`` is created.

    Returns:
        Decoded video of shape (B, 3, F, H, W)
    """
    batch_size = sample.shape[0]
    conditioned = self.timestep_conditioning

    # With timestep conditioning, blend scaled Gaussian noise into the
    # still-normalized latent.
    if conditioned:
        noise_scale = self.decode_noise_scale
        noise = mx.random.normal(sample.shape) * noise_scale
        sample = noise + (1.0 - noise_scale) * sample

    # Undo the per-channel latent normalization.
    sample = self.per_channel_statistics.un_normalize(sample)

    # NOTE(review): `timestep` is not read again after this point in this
    # method, so it never reaches the blocks below. Confirm whether they
    # are meant to receive it.
    if timestep is None and conditioned:
        timestep = mx.full((batch_size,), self.decode_timestep)

    use_causal = self.causal
    sample = self.conv_in(sample, causal=use_causal)

    # Blocks are stored under consecutive integer keys; every block type
    # is called with the same arguments.
    for block_idx in range(len(self.up_blocks)):
        sample = self.up_blocks[block_idx](sample, causal=use_causal)

    # Output head: norm -> activation -> projection.
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample, causal=use_causal)

    # Unpatchify to restore spatial resolution.
    return unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)