Refactor LTX-2 model structure
This commit is contained in:
8
mlx_video/models/ltx_2/__init__.py
Normal file
8
mlx_video/models/ltx_2/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
TransformerConfig,
|
||||
LTXModelType,
|
||||
)
|
||||
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
|
||||
161
mlx_video/models/ltx_2/adaln.py
Normal file
161
mlx_video/models/ltx_2/adaln.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.utils import get_timestep_embedding
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
embedding_coefficient: int = 6,
|
||||
use_additional_conditions: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim=embedding_dim,
|
||||
size_emb_dim=0 if not use_additional_conditions else embedding_dim // 3,
|
||||
use_additional_conditions=use_additional_conditions,
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
added_cond_kwargs: dict | None = None,
|
||||
batch_size: int | None = None,
|
||||
hidden_dtype: mx.Dtype | None = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
added_cond_kwargs = added_cond_kwargs or {}
|
||||
|
||||
embedded_timestep = self.emb(
|
||||
timestep,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
**added_cond_kwargs,
|
||||
)
|
||||
|
||||
scale_shift_params = self.linear(self.silu(embedded_timestep))
|
||||
return scale_shift_params, embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
size_emb_dim: int = 0,
|
||||
use_additional_conditions: bool = False,
|
||||
timestep_proj_dim: int = 256,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.size_emb_dim = size_emb_dim
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
|
||||
|
||||
if use_additional_conditions and size_emb_dim > 0:
|
||||
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
resolution: mx.array | None = None,
|
||||
aspect_ratio: mx.array | None = None,
|
||||
batch_size: int | None = None,
|
||||
hidden_dtype: mx.Dtype | None = None,
|
||||
) -> mx.array:
|
||||
# Project timestep
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
if hidden_dtype is not None:
|
||||
timesteps_proj = timesteps_proj.astype(hidden_dtype)
|
||||
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj)
|
||||
|
||||
# Add additional conditions if enabled
|
||||
if self.use_additional_conditions and self.size_emb_dim > 0:
|
||||
if resolution is not None and aspect_ratio is not None:
|
||||
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
|
||||
timesteps_emb = timesteps_emb + additional_embeds
|
||||
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def __call__(self, timesteps: mx.array) -> mx.array:
|
||||
return get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_dim = out_dim or time_embed_dim
|
||||
self.linear1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
|
||||
self.linear2 = nn.Linear(time_embed_dim, out_dim)
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
sample = self.linear1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class ConditionEmbedding(nn.Module):
|
||||
def __init__(self, size_emb_dim: int, embedding_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.resolution_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
|
||||
self.aspect_ratio_embedder = TimestepEmbedding(size_emb_dim, embedding_dim)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
resolution: mx.array,
|
||||
aspect_ratio: mx.array,
|
||||
hidden_dtype: mx.Dtype | None = None,
|
||||
) -> mx.array:
|
||||
resolution_emb = self.resolution_embedder(resolution)
|
||||
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio)
|
||||
|
||||
if hidden_dtype is not None:
|
||||
resolution_emb = resolution_emb.astype(hidden_dtype)
|
||||
aspect_ratio_emb = aspect_ratio_emb.astype(hidden_dtype)
|
||||
|
||||
return resolution_emb + aspect_ratio_emb
|
||||
154
mlx_video/models/ltx_2/attention.py
Normal file
154
mlx_video/models/ltx_2/attention.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Attention module for LTX-2."""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType
|
||||
from mlx_video.models.ltx_2.rope import apply_rotary_emb
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
v: mx.array,
|
||||
heads: int,
|
||||
mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
|
||||
b, q_seq_len, dim = q.shape
|
||||
_, kv_seq_len, _ = k.shape
|
||||
dim_head = dim // heads
|
||||
|
||||
# Reshape to (B, seq_len, heads, dim_head)
|
||||
q = mx.reshape(q, (b, q_seq_len, heads, dim_head))
|
||||
k = mx.reshape(k, (b, kv_seq_len, heads, dim_head))
|
||||
v = mx.reshape(v, (b, kv_seq_len, heads, dim_head))
|
||||
|
||||
# Transpose to (B, heads, seq_len, dim_head)
|
||||
q = mx.swapaxes(q, 1, 2)
|
||||
k = mx.swapaxes(k, 1, 2)
|
||||
v = mx.swapaxes(v, 1, 2)
|
||||
|
||||
# Handle mask dimensions
|
||||
if mask is not None:
|
||||
# Add batch dimension if needed
|
||||
if mask.ndim == 2:
|
||||
mask = mx.expand_dims(mask, axis=0)
|
||||
# Add heads dimension if needed
|
||||
if mask.ndim == 3:
|
||||
mask = mx.expand_dims(mask, axis=1)
|
||||
|
||||
# Compute scaled dot-product attention
|
||||
scale = 1.0 / math.sqrt(dim_head)
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
|
||||
|
||||
# Reshape back to (B, q_seq_len, heads * dim_head)
|
||||
out = mx.swapaxes(out, 1, 2)
|
||||
out = mx.reshape(out, (b, q_seq_len, heads * dim_head))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head attention with rotary position embeddings.
|
||||
|
||||
Supports both self-attention and cross-attention.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
context_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
norm_eps: float = 1e-6,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
has_gate_logits: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.rope_type = rope_type
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = query_dim if context_dim is None else context_dim
|
||||
|
||||
# Q, K, V projections
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=True)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=True)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=True)
|
||||
|
||||
# Q and K normalization
|
||||
self.q_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
|
||||
self.k_norm = nn.RMSNorm(inner_dim, eps=norm_eps)
|
||||
|
||||
# Output projection
|
||||
self.to_out = nn.Linear(inner_dim, query_dim, bias=True)
|
||||
|
||||
# Per-head gating (LTX-2.3)
|
||||
if has_gate_logits:
|
||||
self.to_gate_logits = nn.Linear(query_dim, heads, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
context: Optional[mx.array] = None,
|
||||
mask: Optional[mx.array] = None,
|
||||
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
k_pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
skip_attention: bool = False,
|
||||
) -> mx.array:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Query input of shape (B, seq_len, query_dim)
|
||||
context: Context for cross-attention. If None, uses x (self-attention)
|
||||
mask: Attention mask
|
||||
pe: Position embeddings for query (and key if k_pe is None)
|
||||
k_pe: Position embeddings for key (optional, uses pe if None)
|
||||
skip_attention: If True, bypass Q*K*V attention and use value projection
|
||||
only (for STG perturbation). Matches PyTorch all_perturbed=True.
|
||||
|
||||
Returns:
|
||||
Attention output of shape (B, seq_len, query_dim)
|
||||
"""
|
||||
# Compute per-head gate early (from original input)
|
||||
gate = None
|
||||
if hasattr(self, "to_gate_logits"):
|
||||
gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads)
|
||||
|
||||
context = x if context is None else context
|
||||
v = self.to_v(context)
|
||||
|
||||
if skip_attention:
|
||||
# STG: bypass Q*K*V attention, use value projection only
|
||||
out = v
|
||||
else:
|
||||
# Standard attention
|
||||
q = self.to_q(x)
|
||||
k = self.to_k(context)
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe, self.rope_type)
|
||||
k_pe_to_use = pe if k_pe is None else k_pe
|
||||
k = apply_rotary_emb(k, k_pe_to_use, self.rope_type)
|
||||
|
||||
out = scaled_dot_product_attention(q, k, v, self.heads, mask)
|
||||
|
||||
# Apply per-head gating
|
||||
if gate is not None:
|
||||
b, seq_len, _ = out.shape
|
||||
out = mx.reshape(out, (b, seq_len, self.heads, self.dim_head))
|
||||
out = out * gate[..., None]
|
||||
out = mx.reshape(out, (b, seq_len, -1))
|
||||
|
||||
# Project output
|
||||
return self.to_out(out)
|
||||
48
mlx_video/models/ltx_2/audio_vae/__init__.py
Normal file
48
mlx_video/models/ltx_2/audio_vae/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Audio VAE module for LTX-2 audio generation."""
|
||||
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
|
||||
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, ResnetBlock
|
||||
from .upsample import Upsample, build_upsampling_path
|
||||
from .vocoder import Vocoder, load_vocoder
|
||||
|
||||
__all__ = [
|
||||
# Main components
|
||||
"AudioEncoder",
|
||||
"AudioDecoder",
|
||||
"Vocoder",
|
||||
"load_vocoder",
|
||||
"decode_audio",
|
||||
# Audio processing
|
||||
"load_audio",
|
||||
"ensure_stereo",
|
||||
"waveform_to_mel",
|
||||
# Ops
|
||||
"AudioLatentShape",
|
||||
"AudioPatchifier",
|
||||
"PerChannelStatistics",
|
||||
# Building blocks
|
||||
"AttentionType",
|
||||
"AttnBlock",
|
||||
"make_attn",
|
||||
"CausalConv2d",
|
||||
"make_conv2d",
|
||||
"CausalityAxis",
|
||||
"Downsample",
|
||||
"build_downsampling_path",
|
||||
"NormType",
|
||||
"PixelNorm",
|
||||
"build_normalization_layer",
|
||||
"ResBlock1",
|
||||
"ResBlock2",
|
||||
"ResnetBlock",
|
||||
"LRELU_SLOPE",
|
||||
"Upsample",
|
||||
"build_upsampling_path",
|
||||
]
|
||||
108
mlx_video/models/ltx_2/audio_vae/attention.py
Normal file
108
mlx_video/models/ltx_2/audio_vae/attention.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Attention blocks for audio VAE."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
|
||||
class AttentionType(Enum):
|
||||
"""Enum for specifying the attention mechanism type."""
|
||||
|
||||
VANILLA = "vanilla"
|
||||
LINEAR = "linear"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
"""Self-attention block for audio VAE."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
norm_type: NormType = NormType.GROUP,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = build_normalization_layer(in_channels, normtype=norm_type)
|
||||
# Using Conv2d with kernel_size=1 for Q, K, V projections
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass through attention block.
|
||||
Args:
|
||||
x: Input tensor of shape (B, H, W, C) in MLX channels-last format
|
||||
Returns:
|
||||
Output tensor with attention applied (residual connection)
|
||||
"""
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# Compute attention
|
||||
# x shape: (B, H, W, C)
|
||||
b, h, w, c = q.shape
|
||||
|
||||
# Reshape for attention: (B, H*W, C)
|
||||
q = q.reshape(b, h * w, c)
|
||||
k = k.reshape(b, h * w, c)
|
||||
v = v.reshape(b, h * w, c)
|
||||
|
||||
# Attention: Q @ K^T / sqrt(d)
|
||||
# q: (B, HW, C), k: (B, HW, C) -> k^T: (B, C, HW)
|
||||
# w_: (B, HW, HW)
|
||||
scale = float(c) ** (-0.5)
|
||||
w_ = mx.matmul(q, k.transpose(0, 2, 1)) * scale
|
||||
w_ = mx.softmax(w_, axis=-1)
|
||||
|
||||
# Attend to values
|
||||
# w_: (B, HW, HW), v: (B, HW, C) -> h_: (B, HW, C)
|
||||
h_ = mx.matmul(w_, v)
|
||||
|
||||
# Reshape back to spatial dims
|
||||
h_ = h_.reshape(b, h, w, c)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class Identity(nn.Module):
|
||||
"""Identity module that returns input unchanged."""
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return x
|
||||
|
||||
|
||||
def make_attn(
|
||||
in_channels: int,
|
||||
attn_type: AttentionType = AttentionType.VANILLA,
|
||||
norm_type: NormType = NormType.GROUP,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create an attention module based on type.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
attn_type: Type of attention mechanism
|
||||
norm_type: Type of normalization
|
||||
Returns:
|
||||
Attention module
|
||||
"""
|
||||
if attn_type == AttentionType.VANILLA:
|
||||
return AttnBlock(in_channels, norm_type=norm_type)
|
||||
elif attn_type == AttentionType.NONE:
|
||||
return Identity()
|
||||
elif attn_type == AttentionType.LINEAR:
|
||||
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
|
||||
else:
|
||||
raise ValueError(f"Unknown attention type: {attn_type}")
|
||||
135
mlx_video/models/ltx_2/audio_vae/audio_processor.py
Normal file
135
mlx_video/models/ltx_2/audio_vae/audio_processor.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Audio processing utilities for loading audio files and computing mel-spectrograms.
|
||||
|
||||
Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrogram)
|
||||
using librosa for macOS/MLX compatibility.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def load_audio(
|
||||
path: str,
|
||||
target_sr: int = 16000,
|
||||
start_time: float = 0.0,
|
||||
max_duration: float | None = None,
|
||||
mono: bool = False,
|
||||
) -> tuple[np.ndarray, int]:
|
||||
"""Load audio file, resample to target sample rate.
|
||||
|
||||
Args:
|
||||
path: Path to audio file (WAV, FLAC, MP3, OGG, or video with audio track).
|
||||
target_sr: Target sample rate (default 16000 Hz).
|
||||
start_time: Start time in seconds.
|
||||
max_duration: Maximum duration in seconds. None = read to end.
|
||||
mono: If True, convert to mono. Default False (preserve channels).
|
||||
|
||||
Returns:
|
||||
(waveform, sample_rate) where waveform is (channels, samples) float32 numpy array.
|
||||
"""
|
||||
import librosa
|
||||
|
||||
# librosa.load returns mono by default; we want to preserve stereo
|
||||
y, sr = librosa.load(
|
||||
path,
|
||||
sr=target_sr,
|
||||
mono=mono,
|
||||
offset=start_time,
|
||||
duration=max_duration,
|
||||
)
|
||||
|
||||
# Ensure 2D: (channels, samples)
|
||||
if y.ndim == 1:
|
||||
y = y[np.newaxis, :] # (1, samples)
|
||||
|
||||
return y.astype(np.float32), sr
|
||||
|
||||
|
||||
def ensure_stereo(waveform: np.ndarray) -> np.ndarray:
|
||||
"""Ensure waveform is stereo (2, samples). Duplicates mono if needed."""
|
||||
if waveform.ndim == 1:
|
||||
waveform = waveform[np.newaxis, :]
|
||||
if waveform.shape[0] == 1:
|
||||
waveform = np.concatenate([waveform, waveform], axis=0)
|
||||
elif waveform.shape[0] > 2:
|
||||
waveform = waveform[:2]
|
||||
return waveform
|
||||
|
||||
|
||||
def waveform_to_mel(
|
||||
waveform: np.ndarray,
|
||||
sample_rate: int = 16000,
|
||||
n_fft: int = 1024,
|
||||
hop_length: int = 160,
|
||||
win_length: int = 1024,
|
||||
n_mels: int = 64,
|
||||
fmin: float = 0.0,
|
||||
fmax: float = 8000.0,
|
||||
) -> mx.array:
|
||||
"""Convert waveform to log-mel spectrogram matching PyTorch MelSpectrogram.
|
||||
|
||||
PyTorch reference:
|
||||
MelSpectrogram(sample_rate=16000, n_fft=1024, win_length=1024, hop_length=160,
|
||||
f_min=0.0, f_max=8000.0, n_mels=64, power=1.0,
|
||||
mel_scale="slaney", norm="slaney", center=True, pad_mode="reflect")
|
||||
|
||||
Args:
|
||||
waveform: (channels, samples) float32 numpy array.
|
||||
sample_rate: Sample rate of the waveform.
|
||||
n_fft: FFT size.
|
||||
hop_length: Hop length.
|
||||
win_length: Window length.
|
||||
n_mels: Number of mel bins.
|
||||
fmin: Minimum frequency for mel filterbank.
|
||||
fmax: Maximum frequency for mel filterbank.
|
||||
|
||||
Returns:
|
||||
Log-mel spectrogram as mx.array of shape (1, channels, time, n_mels).
|
||||
"""
|
||||
import librosa
|
||||
|
||||
# Ensure 2D
|
||||
if waveform.ndim == 1:
|
||||
waveform = waveform[np.newaxis, :]
|
||||
|
||||
channels = waveform.shape[0]
|
||||
mels = []
|
||||
|
||||
for ch in range(channels):
|
||||
# Magnitude spectrogram (power=1.0)
|
||||
S = np.abs(librosa.stft(
|
||||
waveform[ch],
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
))
|
||||
|
||||
# Mel filterbank with slaney normalization
|
||||
mel_basis = librosa.filters.mel(
|
||||
sr=sample_rate,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
norm="slaney",
|
||||
)
|
||||
mel = mel_basis @ S
|
||||
|
||||
# Log scale
|
||||
mel = np.log(np.clip(mel, a_min=1e-5, a_max=None))
|
||||
|
||||
# Transpose: (n_mels, time) -> (time, n_mels)
|
||||
mel = mel.T
|
||||
mels.append(mel)
|
||||
|
||||
# Stack channels: (channels, time, n_mels)
|
||||
mel_spec = np.stack(mels, axis=0)
|
||||
|
||||
# Add batch dim: (1, channels, time, n_mels)
|
||||
mel_spec = mel_spec[np.newaxis, ...]
|
||||
|
||||
return mx.array(mel_spec, dtype=mx.float32)
|
||||
532
mlx_video/models/ltx_2/audio_vae/audio_vae.py
Normal file
532
mlx_video/models/ltx_2/audio_vae/audio_vae.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""Audio VAE encoder and decoder for LTX-2."""
|
||||
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
from .resnet import ResnetBlock
|
||||
from .upsample import build_upsampling_path
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
def build_mid_block(
|
||||
channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float,
|
||||
norm_type: NormType,
|
||||
causality_axis: CausalityAxis,
|
||||
attn_type: AttentionType,
|
||||
add_attention: bool,
|
||||
) -> dict:
|
||||
"""Build the middle block with two ResNet blocks and optional attention."""
|
||||
mid = {}
|
||||
mid["block_1"] = ResnetBlock(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
temb_channels=temb_channels,
|
||||
dropout=dropout,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
mid["attn_1"] = (
|
||||
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
|
||||
)
|
||||
mid["block_2"] = ResnetBlock(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
temb_channels=temb_channels,
|
||||
dropout=dropout,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
return mid
|
||||
|
||||
|
||||
def run_mid_block(mid: dict, features: mx.array) -> mx.array:
|
||||
"""Run features through the middle block."""
|
||||
features = mid["block_1"](features, temb=None)
|
||||
if mid["attn_1"] is not None:
|
||||
features = mid["attn_1"](features)
|
||||
return mid["block_2"](features, temb=None)
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
"""Encoder that compresses audio spectrograms into latent representations."""
|
||||
|
||||
def __init__(self, config: AudioEncoderModelConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
|
||||
self.sample_rate = config.sample_rate
|
||||
self.mel_hop_length = config.mel_hop_length
|
||||
self.is_causal = config.is_causal
|
||||
self.mel_bins = config.mel_bins
|
||||
|
||||
self.patchifier = AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=config.sample_rate,
|
||||
hop_length=config.mel_hop_length,
|
||||
is_causal=config.is_causal,
|
||||
)
|
||||
|
||||
self.ch = config.ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(config.ch_mult)
|
||||
self.num_res_blocks = config.num_res_blocks
|
||||
self.resolution = config.resolution
|
||||
self.in_channels = config.in_channels
|
||||
self.z_channels = config.z_channels
|
||||
self.double_z = config.double_z
|
||||
self.norm_type = config.norm_type
|
||||
self.causality_axis = config.causality_axis
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.in_channels, self.ch, kernel_size=3, stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.down, block_in = build_downsampling_path(
|
||||
ch=config.ch,
|
||||
ch_mult=config.ch_mult,
|
||||
num_resolutions=self.num_resolutions,
|
||||
num_res_blocks=config.num_res_blocks,
|
||||
resolution=config.resolution,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
attn_resolutions=config.attn_resolutions or set(),
|
||||
resamp_with_conv=config.resamp_with_conv,
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
add_attention=config.mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
|
||||
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
|
||||
self.conv_out = make_conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio encoder weights from PyTorch format."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
if key.startswith("audio_vae.encoder."):
|
||||
new_key = key.replace("audio_vae.encoder.", "")
|
||||
elif key.startswith("encoder."):
|
||||
new_key = key.replace("encoder.", "")
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
elif "per_channel_statistics" in key:
|
||||
if "mean-of-means" in key or "latents_mean" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key or "latents_std" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
elif key == "latents_mean":
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif key == "latents_std":
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
|
||||
"""Load audio encoder from pretrained weights."""
|
||||
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
|
||||
import json
|
||||
|
||||
model_path = Path(model_path)
|
||||
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
encoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
encoder.load_weights(list(weights.items()), strict=True)
|
||||
return encoder
|
||||
|
||||
def __call__(self, spectrogram: mx.array) -> mx.array:
|
||||
"""Encode audio spectrogram into normalized latent representation.
|
||||
|
||||
Args:
|
||||
spectrogram: (B, C, T, F) PyTorch format or (B, T, F, C) MLX format.
|
||||
Returns:
|
||||
Normalized latent (B, T', F', z_channels) in MLX channels-last format.
|
||||
"""
|
||||
if spectrogram.ndim == 4 and spectrogram.shape[1] == self.in_channels:
|
||||
spectrogram = mx.transpose(spectrogram, (0, 2, 3, 1))
|
||||
|
||||
h = self.conv_in(spectrogram)
|
||||
h = self._run_downsampling_path(h)
|
||||
h = run_mid_block(self.mid, h)
|
||||
h = self._finalize_output(h)
|
||||
return self._normalize_latents(h)
|
||||
|
||||
def _run_downsampling_path(self, h: mx.array) -> mx.array:
|
||||
for level in range(self.num_resolutions):
|
||||
stage = self.down[level]
|
||||
for block_idx in range(self.num_res_blocks):
|
||||
h = stage["block"][block_idx](h, temb=None)
|
||||
if block_idx in stage["attn"]:
|
||||
h = stage["attn"][block_idx](h)
|
||||
if level != self.num_resolutions - 1 and "downsample" in stage:
|
||||
h = stage["downsample"](h)
|
||||
return h
|
||||
|
||||
def _finalize_output(self, h: mx.array) -> mx.array:
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
return self.conv_out(h)
|
||||
|
||||
def _normalize_latents(self, h: mx.array) -> mx.array:
|
||||
"""Normalize encoder output using per-channel statistics.
|
||||
|
||||
Takes first half of channels (mean) when double_z=True,
|
||||
then patchifies, normalizes, and unpatchifies.
|
||||
"""
|
||||
# h shape: (B, T', F', 2*z_channels) in MLX format
|
||||
z_channels = self.z_channels
|
||||
means = h[..., :z_channels]
|
||||
|
||||
latent_shape = AudioLatentShape(
|
||||
batch=means.shape[0],
|
||||
channels=means.shape[3],
|
||||
frames=means.shape[1],
|
||||
mel_bins=means.shape[2],
|
||||
)
|
||||
|
||||
patched = self.patchifier.patchify(means)
|
||||
normalized = self.per_channel_statistics.normalize(patched)
|
||||
return self.patchifier.unpatchify(normalized, latent_shape)
|
||||
|
||||
|
||||
class AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
The decoder mirrors the encoder structure with configurable channel multipliers,
|
||||
attention resolutions, and causal convolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AudioDecoderModelConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the AudioDecoder.
|
||||
Args:
|
||||
ch: Base number of feature channels
|
||||
out_ch: Number of output channels (2 for stereo)
|
||||
ch_mult: Multiplicative factors for channels at each resolution
|
||||
num_res_blocks: Number of residual blocks per resolution
|
||||
attn_resolutions: Resolutions at which to apply attention
|
||||
resolution: Input spatial resolution
|
||||
z_channels: Number of latent channels
|
||||
norm_type: Normalization type
|
||||
causality_axis: Axis for causal convolutions
|
||||
dropout: Dropout probability
|
||||
mid_block_add_attention: Whether to add attention in middle block
|
||||
sample_rate: Audio sample rate
|
||||
mel_hop_length: Hop length for mel spectrogram
|
||||
is_causal: Whether to use causal convolutions
|
||||
mel_bins: Number of mel frequency bins
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
# Uses ch (base channel count) to match the patchified latent dimension
|
||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
|
||||
# ch=128 matches this dimension, so use ch for per_channel_statistics
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
|
||||
self.sample_rate = config.sample_rate
|
||||
self.mel_hop_length = config.mel_hop_length
|
||||
self.is_causal = config.is_causal
|
||||
self.mel_bins = config.mel_bins
|
||||
|
||||
self.patchifier = AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=config.sample_rate,
|
||||
hop_length=config.mel_hop_length,
|
||||
is_causal=config.is_causal,
|
||||
)
|
||||
|
||||
self.ch = config.ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(config.ch_mult)
|
||||
self.num_res_blocks = config.num_res_blocks
|
||||
self.resolution = config.resolution
|
||||
self.out_ch = config.out_ch
|
||||
self.give_pre_end = config.give_pre_end
|
||||
self.tanh_out = config.tanh_out
|
||||
self.norm_type = config.norm_type
|
||||
self.z_channels = config.z_channels
|
||||
self.channel_multipliers = config.ch_mult
|
||||
self.attn_resolutions = config.attn_resolutions
|
||||
self.causality_axis = config.causality_axis
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
base_block_channels = config.ch * self.channel_multipliers[-1]
|
||||
base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
add_attention=config.mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.up, final_block_channels = build_upsampling_path(
|
||||
ch=config.ch,
|
||||
ch_mult=config.ch_mult,
|
||||
num_resolutions=self.num_resolutions,
|
||||
num_res_blocks=config.num_res_blocks,
|
||||
resolution=config.resolution,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
attn_resolutions=config.attn_resolutions,
|
||||
resamp_with_conv=config.resamp_with_conv,
|
||||
initial_block_channels=base_block_channels,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(
|
||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for audio VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle audio_vae.decoder weights
|
||||
if key.startswith("audio_vae.decoder."):
|
||||
new_key = key.replace("audio_vae.decoder.", "")
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
# Map per-channel statistics
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
else:
|
||||
continue # Skip non-decoder keys
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
decoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
# weights = decoder.sanitize(weights)
|
||||
decoder.load_weights(list(weights.items()), strict=True)
|
||||
return decoder
|
||||
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
Args:
|
||||
sample: Encoded latent representation of shape (B, H, W, C) in MLX format
|
||||
or (B, C, H, W) in PyTorch format (will be transposed)
|
||||
Returns:
|
||||
Reconstructed audio spectrogram
|
||||
"""
|
||||
# Handle input format - if channels are in dim 1, transpose to channels-last
|
||||
if sample.shape[1] == self.z_channels and sample.ndim == 4:
|
||||
# PyTorch format (B, C, H, W) -> MLX format (B, H, W, C)
|
||||
sample = mx.transpose(sample, (0, 2, 3, 1))
|
||||
|
||||
sample, target_shape = self._denormalize_latents(sample)
|
||||
|
||||
h = self.conv_in(sample)
|
||||
h = run_mid_block(self.mid, h)
|
||||
h = self._run_upsampling_path(h)
|
||||
h = self._finalize_output(h)
|
||||
|
||||
return self._adjust_output_shape(h, target_shape)
|
||||
|
||||
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
# sample shape: (B, H, W, C) in MLX format
|
||||
latent_shape = AudioLatentShape(
|
||||
batch=sample.shape[0],
|
||||
channels=sample.shape[3], # channels last
|
||||
frames=sample.shape[1], # height = frames
|
||||
mel_bins=sample.shape[2], # width = mel_bins
|
||||
)
|
||||
|
||||
sample_patched = self.patchifier.patchify(sample)
|
||||
sample_denormalized = self.per_channel_statistics.un_normalize(sample_patched)
|
||||
sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
|
||||
|
||||
target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
|
||||
if self.causality_axis != CausalityAxis.NONE:
|
||||
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
||||
|
||||
target_shape = AudioLatentShape(
|
||||
batch=latent_shape.batch,
|
||||
channels=self.out_ch,
|
||||
frames=target_frames,
|
||||
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
|
||||
)
|
||||
|
||||
return sample, target_shape
|
||||
|
||||
def _adjust_output_shape(
|
||||
self,
|
||||
decoded_output: mx.array,
|
||||
target_shape: AudioLatentShape,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Adjust output shape to match target dimensions for variable-length audio.
|
||||
Args:
|
||||
decoded_output: Tensor of shape (B, H, W, C) in MLX format
|
||||
target_shape: AudioLatentShape describing target dimensions
|
||||
Returns:
|
||||
Tensor adjusted to match target_shape exactly
|
||||
"""
|
||||
# Current output shape: (batch, frames, mel_bins, channels) in MLX format
|
||||
_, current_time, current_freq, _ = decoded_output.shape
|
||||
target_channels = target_shape.channels
|
||||
target_time = target_shape.frames
|
||||
target_freq = target_shape.mel_bins
|
||||
|
||||
# Step 1: Crop first to avoid exceeding target dimensions
|
||||
decoded_output = decoded_output[
|
||||
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
|
||||
]
|
||||
|
||||
# Step 2: Calculate padding needed for time and frequency dimensions
|
||||
time_padding_needed = target_time - decoded_output.shape[1]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[2]
|
||||
|
||||
# Step 3: Apply padding if needed
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
# MLX pad: [(before_0, after_0), ...]
|
||||
# For (B, H, W, C): H=time, W=freq
|
||||
padding = [
|
||||
(0, 0), # batch
|
||||
(0, max(time_padding_needed, 0)), # time
|
||||
(0, max(freq_padding_needed, 0)), # freq
|
||||
(0, 0), # channels
|
||||
]
|
||||
decoded_output = mx.pad(decoded_output, padding)
|
||||
|
||||
# Step 4: Final safety crop to ensure exact target shape
|
||||
decoded_output = decoded_output[:, :target_time, :target_freq, :target_channels]
|
||||
|
||||
# Transpose back to PyTorch format (B, C, H, W) for vocoder compatibility
|
||||
decoded_output = mx.transpose(decoded_output, (0, 3, 1, 2))
|
||||
|
||||
return decoded_output
|
||||
|
||||
def _run_upsampling_path(self, h: mx.array) -> mx.array:
|
||||
"""Run through upsampling path."""
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = self.up[level]
|
||||
for block_idx in range(len(stage["block"])):
|
||||
h = stage["block"][block_idx](h, temb=None)
|
||||
if block_idx in stage["attn"]:
|
||||
h = stage["attn"][block_idx](h)
|
||||
|
||||
if level != 0 and "upsample" in stage:
|
||||
h = stage["upsample"](h)
|
||||
|
||||
return h
|
||||
|
||||
def _finalize_output(self, h: mx.array) -> mx.array:
|
||||
"""Apply final normalization and convolution."""
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return mx.tanh(h) if self.tanh_out else h
|
||||
|
||||
|
||||
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
|
||||
"""
|
||||
Decode an audio latent representation using the provided audio decoder and vocoder.
|
||||
Args:
|
||||
latent: Input audio latent tensor
|
||||
audio_decoder: Model to decode the latent to spectrogram
|
||||
vocoder: Model to convert spectrogram to audio waveform
|
||||
Returns:
|
||||
Decoded audio as a float tensor
|
||||
"""
|
||||
decoded_audio = audio_decoder(latent)
|
||||
decoded_audio = vocoder(decoded_audio)
|
||||
# Remove batch dimension if present
|
||||
if decoded_audio.shape[0] == 1:
|
||||
decoded_audio = decoded_audio[0]
|
||||
return decoded_audio.astype(mx.float32)
|
||||
146
mlx_video/models/ltx_2/audio_vae/causal_conv_2d.py
Normal file
146
mlx_video/models/ltx_2/audio_vae/causal_conv_2d.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Causal 2D convolutions for audio VAE."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..config import CausalityAxis
|
||||
|
||||
|
||||
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
|
||||
"""Convert int or tuple to tuple pair."""
|
||||
if isinstance(x, int):
|
||||
return (x, x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution.
|
||||
This layer ensures that the output at time `t` only depends on inputs
|
||||
at time `t` and earlier. It achieves this by applying asymmetric padding
|
||||
to the time dimension before the convolution.
|
||||
|
||||
Note: MLX Conv2d expects input shape (N, H, W, C) - channels last.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
# Ensure kernel_size and dilation are tuples
|
||||
kernel_size = _pair(kernel_size)
|
||||
dilation = _pair(dilation)
|
||||
|
||||
# Calculate padding dimensions
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
# Store padding for manual application
|
||||
# MLX pad order: [(before_axis0, after_axis0), (before_axis1, after_axis1), ...]
|
||||
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
# Non-causal: symmetric padding
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
|
||||
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
|
||||
# Causal on width: pad left (before width axis)
|
||||
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
# Causal on height: pad top (before height axis)
|
||||
self.padding = (pad_h, 0, pad_w // 2, pad_w - pad_w // 2)
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
# The internal convolution layer uses no padding, as we handle it manually
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass with causal padding.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||
Returns:
|
||||
Output tensor after causal convolution
|
||||
"""
|
||||
# Apply causal padding before convolution
|
||||
# padding format: (pad_h_top, pad_h_bottom, pad_w_left, pad_w_right)
|
||||
pad_h_top, pad_h_bottom, pad_w_left, pad_w_right = self.padding
|
||||
|
||||
if any(p > 0 for p in self.padding):
|
||||
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
|
||||
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
|
||||
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
|
||||
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def make_conv2d(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
padding: Union[int, Tuple[int, int], None] = None,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: CausalityAxis | None = None,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create a 2D convolution layer that can be either causal or non-causal.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
out_channels: Number of output channels
|
||||
kernel_size: Size of the convolution kernel
|
||||
stride: Convolution stride
|
||||
padding: Padding (if None, will be calculated based on causal flag)
|
||||
dilation: Dilation rate
|
||||
groups: Number of groups for grouped convolution
|
||||
bias: Whether to use bias
|
||||
causality_axis: Dimension along which to apply causality.
|
||||
Returns:
|
||||
Either a regular Conv2d or CausalConv2d layer
|
||||
"""
|
||||
if causality_axis is not None:
|
||||
# For causal convolution, padding is handled internally by CausalConv2d
|
||||
return CausalConv2d(
|
||||
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
|
||||
)
|
||||
else:
|
||||
# For non-causal convolution, use symmetric padding if not specified
|
||||
if padding is None:
|
||||
if isinstance(kernel_size, int):
|
||||
padding = kernel_size // 2
|
||||
else:
|
||||
padding = tuple(k // 2 for k in kernel_size)
|
||||
|
||||
return nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
)
|
||||
127
mlx_video/models/ltx_2/audio_vae/downsample.py
Normal file
127
mlx_video/models/ltx_2/audio_vae/downsample.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Downsampling layers for audio VAE."""
|
||||
|
||||
from typing import Set, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer that can use either a strided convolution
|
||||
or average pooling. Supports standard and causal padding for the
|
||||
convolutional mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
with_conv: bool,
|
||||
causality_axis: CausalityAxis = CausalityAxis.WIDTH,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
|
||||
raise ValueError("causality is only supported when `with_conv=True`.")
|
||||
|
||||
if self.with_conv:
|
||||
# Do time downsampling here
|
||||
# no asymmetric padding in MLX conv, must do it ourselves
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass with downsampling.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||
Returns:
|
||||
Downsampled tensor
|
||||
"""
|
||||
if self.with_conv:
|
||||
# Padding tuple is in the order: (left, right, top, bottom) for PyTorch
|
||||
# For MLX pad: [(before_axis0, after_axis0), ...]
|
||||
# x shape: (N, H, W, C) -> pad on H and W axes
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
# pad: (left=0, right=1, top=0, bottom=1)
|
||||
pad = [(0, 0), (0, 1), (0, 1), (0, 0)]
|
||||
elif self.causality_axis == CausalityAxis.WIDTH:
|
||||
# pad: (left=2, right=0, top=0, bottom=1)
|
||||
pad = [(0, 0), (0, 1), (2, 0), (0, 0)]
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
# pad: (left=0, right=1, top=2, bottom=0)
|
||||
pad = [(0, 0), (2, 0), (0, 1), (0, 0)]
|
||||
elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
# pad: (left=1, right=0, top=0, bottom=1)
|
||||
pad = [(0, 0), (0, 1), (1, 0), (0, 0)]
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
x = mx.pad(x, pad, constant_values=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# Average pooling with 2x2 kernel and stride 2
|
||||
# MLX doesn't have built-in avg_pool2d, implement manually
|
||||
# x shape: (N, H, W, C)
|
||||
n, h, w, c = x.shape
|
||||
# Reshape to (N, H//2, 2, W//2, 2, C) and mean over pooling dims
|
||||
x = x.reshape(n, h // 2, 2, w // 2, 2, c)
|
||||
x = mx.mean(x, axis=(2, 4))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def build_downsampling_path(
|
||||
*,
|
||||
ch: int,
|
||||
ch_mult: Tuple[int, ...],
|
||||
num_resolutions: int,
|
||||
num_res_blocks: int,
|
||||
resolution: int,
|
||||
temb_channels: int,
|
||||
dropout: float,
|
||||
norm_type: NormType,
|
||||
causality_axis: CausalityAxis,
|
||||
attn_type: AttentionType,
|
||||
attn_resolutions: Set[int],
|
||||
resamp_with_conv: bool,
|
||||
) -> tuple[dict, int]:
|
||||
"""Build the downsampling path with residual blocks, attention, and downsampling layers."""
|
||||
down_modules = {}
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, *tuple(ch_mult))
|
||||
block_in = ch
|
||||
|
||||
for i_level in range(num_resolutions):
|
||||
stage = {}
|
||||
stage["block"] = {}
|
||||
stage["attn"] = {}
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for i_block in range(num_res_blocks):
|
||||
stage["block"][i_block] = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=temb_channels,
|
||||
dropout=dropout,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
|
||||
if i_level != num_resolutions - 1:
|
||||
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
down_modules[i_level] = stage
|
||||
|
||||
return down_modules, block_in
|
||||
59
mlx_video/models/ltx_2/audio_vae/normalization.py
Normal file
59
mlx_video/models/ltx_2/audio_vae/normalization.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Normalization layers for audio VAE."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class NormType(Enum):
|
||||
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
||||
|
||||
GROUP = "group"
|
||||
PIXEL = "pixel"
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
For each element along the chosen dimension, this layer normalizes the tensor
|
||||
by the root-mean-square of its values across that dimension:
|
||||
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim: Dimension along which to compute the RMS (typically channels).
|
||||
eps: Small constant added for numerical stability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Apply RMS normalization along the configured dimension."""
|
||||
mean_sq = mx.mean(x**2, axis=self.dim, keepdims=True)
|
||||
rms = mx.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
def build_normalization_layer(
|
||||
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create a normalization layer based on the normalization type.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
num_groups: Number of groups for group normalization
|
||||
normtype: Type of normalization: "group" or "pixel"
|
||||
Returns:
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
|
||||
if normtype == NormType.PIXEL:
|
||||
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
|
||||
# PyTorch uses dim=1 for channels-first format (B, C, H, W)
|
||||
return PixelNorm(dim=-1, eps=1e-6)
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
98
mlx_video/models/ltx_2/audio_vae/ops.py
Normal file
98
mlx_video/models/ltx_2/audio_vae/ops.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Audio processing utilities for audio VAE."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioLatentShape:
|
||||
"""Shape descriptor for audio latent representations."""
|
||||
|
||||
batch: int
|
||||
channels: int
|
||||
frames: int
|
||||
mel_bins: int
|
||||
|
||||
|
||||
class PerChannelStatistics(nn.Module):
|
||||
"""
|
||||
Per-channel statistics for normalizing and denormalizing the latent representation.
|
||||
This statistics is computed over the entire dataset and stored in model's checkpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, latent_channels: int = 128) -> None:
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
# Initialize buffers - will be loaded from weights
|
||||
# Using underscores for MLX compatibility with weight loading
|
||||
self.std_of_means = mx.ones((latent_channels,))
|
||||
self.mean_of_means = mx.zeros((latent_channels,))
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latent representation."""
|
||||
# Broadcast statistics to match x shape
|
||||
# x shape: (B, C, ...) or (B, ..., C)
|
||||
std = self.std_of_means.astype(x.dtype)
|
||||
mean = self.mean_of_means.astype(x.dtype)
|
||||
return (x * std) + mean
|
||||
|
||||
def normalize(self, x: mx.array) -> mx.array:
|
||||
"""Normalize latent representation."""
|
||||
std = self.std_of_means.astype(x.dtype)
|
||||
mean = self.mean_of_means.astype(x.dtype)
|
||||
return (x - mean) / std
|
||||
|
||||
|
||||
class AudioPatchifier:
|
||||
"""
|
||||
Audio patchifier for converting between audio latents and patches.
|
||||
Combines channels and mel_bins dimensions for per-channel statistics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
):
|
||||
self.patch_size = patch_size
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.sample_rate = sample_rate
|
||||
self.hop_length = hop_length
|
||||
self.is_causal = is_causal
|
||||
|
||||
def patchify(self, x: mx.array) -> mx.array:
|
||||
"""Convert audio latents to patches.
|
||||
|
||||
Input shape: (B, T, F, C) in MLX format (channels last)
|
||||
Output shape: (B, T, C*F) - flattened for per-channel statistics
|
||||
|
||||
The output order is (c f) to match PyTorch's "b c t f -> b t (c f)".
|
||||
"""
|
||||
# x shape: (B, T, F, C) e.g., (1, 68, 16, 8)
|
||||
b, t, f, c = x.shape
|
||||
# Transpose to (B, T, C, F) for correct (c f) ordering
|
||||
x = mx.transpose(x, (0, 1, 3, 2))
|
||||
# Reshape to (B, T, C*F) e.g., (1, 68, 128)
|
||||
return x.reshape(b, t, c * f)
|
||||
|
||||
def unpatchify(self, x: mx.array, latent_shape: AudioLatentShape) -> mx.array:
|
||||
"""Convert patches back to audio latents.
|
||||
|
||||
Input shape: (B, T, C*F)
|
||||
Output shape: (B, T, F, C) in MLX format
|
||||
|
||||
Reverses patchify's "b t (c f) -> b c t f" then transposes to MLX format.
|
||||
"""
|
||||
# x shape: (B, T, C*F) e.g., (1, 68, 128)
|
||||
b, t, cf = x.shape
|
||||
c = latent_shape.channels
|
||||
f = latent_shape.mel_bins
|
||||
# Reshape to (B, T, C, F)
|
||||
x = x.reshape(b, t, c, f)
|
||||
# Transpose to MLX format (B, T, F, C)
|
||||
return mx.transpose(x, (0, 1, 3, 2))
|
||||
185
mlx_video/models/ltx_2/audio_vae/resnet.py
Normal file
185
mlx_video/models/ltx_2/audio_vae/resnet.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""ResNet blocks for audio VAE and vocoder."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def leaky_relu(x: mx.array, negative_slope: float = LRELU_SLOPE) -> mx.array:
|
||||
"""Leaky ReLU activation."""
|
||||
return mx.maximum(x, x * negative_slope)
|
||||
|
||||
|
||||
class ResBlock1(nn.Module):
|
||||
"""1D ResNet block for vocoder with dilated convolutions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: Tuple[int, int, int] = (1, 3, 5),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# First set of convolutions with different dilations
|
||||
self.convs1 = {
|
||||
i: nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=(kernel_size - 1) * d // 2,
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
# Second set of convolutions with dilation=1
|
||||
self.convs2 = {
|
||||
i: nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
for i in range(len(dilation))
|
||||
}
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Forward pass through residual blocks."""
|
||||
for i in range(len(self.convs1)):
|
||||
xt = leaky_relu(x, LRELU_SLOPE)
|
||||
xt = self.convs1[i](xt)
|
||||
xt = leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = self.convs2[i](xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock2(nn.Module):
|
||||
"""1D ResNet block for vocoder (alternative version)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: Tuple[int, int] = (1, 3),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.convs = {
|
||||
i: nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
dilation=d,
|
||||
padding=(kernel_size - 1) * d // 2,
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Forward pass through residual blocks."""
|
||||
for i in range(len(self.convs)):
|
||||
xt = leaky_relu(x, LRELU_SLOPE)
|
||||
xt = self.convs[i](xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
"""2D ResNet block for audio VAE encoder/decoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels: int,
|
||||
out_channels: int | None = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
norm_type: NormType = NormType.GROUP,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis != CausalityAxis.NONE and norm_type == NormType.GROUP:
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.temb_channels = temb_channels
|
||||
|
||||
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
|
||||
self.conv1 = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
|
||||
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
|
||||
self.dropout_rate = dropout
|
||||
self.conv2 = make_conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = make_conv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
temb: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Forward pass through ResNet block.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||
temb: Optional time embedding tensor
|
||||
Returns:
|
||||
Output tensor
|
||||
"""
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nn.silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None and self.temb_channels > 0:
|
||||
# temb: (B, temb_channels) -> (B, out_channels)
|
||||
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
|
||||
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nn.silu(h)
|
||||
if self.dropout_rate > 0:
|
||||
h = nn.Dropout(self.dropout_rate)(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
135
mlx_video/models/ltx_2/audio_vae/upsample.py
Normal file
135
mlx_video/models/ltx_2/audio_vae/upsample.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Upsampling layers for audio VAE."""
|
||||
|
||||
from typing import Set, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
def nearest_neighbor_upsample(x: mx.array, scale_factor: int = 2) -> mx.array:
|
||||
"""
|
||||
Nearest neighbor upsampling for 4D tensors.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C)
|
||||
scale_factor: Upsampling factor
|
||||
Returns:
|
||||
Upsampled tensor of shape (N, H*scale_factor, W*scale_factor, C)
|
||||
"""
|
||||
n, h, w, c = x.shape
|
||||
# Repeat along height and width
|
||||
x = mx.repeat(x, scale_factor, axis=1)
|
||||
x = mx.repeat(x, scale_factor, axis=2)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""Upsampling layer with optional convolution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
with_conv: bool,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
self.conv = make_conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
Forward pass with upsampling.
|
||||
Args:
|
||||
x: Input tensor of shape (N, H, W, C) in MLX channels-last format
|
||||
Returns:
|
||||
Upsampled tensor
|
||||
"""
|
||||
# Nearest neighbor 2x upsampling
|
||||
x = nearest_neighbor_upsample(x, scale_factor=2)
|
||||
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
|
||||
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
|
||||
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
|
||||
# So the output elements rely on the following windows:
|
||||
# 0: [-,-,0]
|
||||
# 1: [-,0,0]
|
||||
# 2: [0,0,1]
|
||||
# 3: [0,1,1]
|
||||
# 4: [1,1,2]
|
||||
# 5: [1,2,2]
|
||||
# Notice that the first and second elements in the output rely only on the first element in the input,
|
||||
# while all other elements rely on two elements in the input.
|
||||
# So we can drop the first element to undo the padding (rather than the last element).
|
||||
# This is a no-op for non-causal convolutions.
|
||||
if self.causality_axis == CausalityAxis.NONE:
|
||||
pass # x remains unchanged
|
||||
elif self.causality_axis == CausalityAxis.HEIGHT:
|
||||
x = x[:, 1:, :, :]
|
||||
elif self.causality_axis == CausalityAxis.WIDTH:
|
||||
x = x[:, :, 1:, :]
|
||||
elif self.causality_axis == CausalityAxis.WIDTH_COMPATIBILITY:
|
||||
pass # x remains unchanged
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def build_upsampling_path(
|
||||
*,
|
||||
ch: int,
|
||||
ch_mult: Tuple[int, ...],
|
||||
num_resolutions: int,
|
||||
num_res_blocks: int,
|
||||
resolution: int,
|
||||
temb_channels: int,
|
||||
dropout: float,
|
||||
norm_type: NormType,
|
||||
causality_axis: CausalityAxis,
|
||||
attn_type: AttentionType,
|
||||
attn_resolutions: Set[int],
|
||||
resamp_with_conv: bool,
|
||||
initial_block_channels: int,
|
||||
) -> tuple[dict, int]:
|
||||
"""Build the upsampling path with residual blocks, attention, and upsampling layers."""
|
||||
up_modules = {}
|
||||
block_in = initial_block_channels
|
||||
curr_res = resolution // (2 ** (num_resolutions - 1))
|
||||
|
||||
for level in reversed(range(num_resolutions)):
|
||||
stage = {}
|
||||
stage["block"] = {}
|
||||
stage["attn"] = {}
|
||||
block_out = ch * ch_mult[level]
|
||||
|
||||
for i_block in range(num_res_blocks + 1):
|
||||
stage["block"][i_block] = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=temb_channels,
|
||||
dropout=dropout,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
)
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
|
||||
|
||||
if level != 0:
|
||||
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
|
||||
curr_res *= 2
|
||||
|
||||
up_modules[level] = stage
|
||||
|
||||
return up_modules, block_in
|
||||
689
mlx_video/models/ltx_2/audio_vae/vocoder.py
Normal file
689
mlx_video/models/ltx_2/audio_vae/vocoder.py
Normal file
@@ -0,0 +1,689 @@
|
||||
"""Vocoder for converting mel spectrograms to audio waveforms.
|
||||
|
||||
Supports:
|
||||
- HiFi-GAN (LTX-2): ResBlock1 with LeakyReLU
|
||||
- BigVGAN v2 (LTX-2.3): AMPBlock1 with Snake/SnakeBeta + anti-aliased resampling
|
||||
- VocoderWithBWE (LTX-2.3): Base vocoder + bandwidth extension (16kHz -> 48kHz)
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..config import VocoderModelConfig
|
||||
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
|
||||
|
||||
|
||||
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snake / SnakeBeta activations (BigVGAN v2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
"""Snake activation: x + (1/alpha) * sin^2(alpha * x)."""
|
||||
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
alpha = self.alpha # (C,)
|
||||
if self.alpha_logscale:
|
||||
alpha = mx.exp(alpha)
|
||||
return x + (1.0 / (alpha + 1e-9)) * mx.power(mx.sin(x * alpha), 2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
"""SnakeBeta activation: x + (1/beta) * sin^2(alpha * x)."""
|
||||
|
||||
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
alpha = self.alpha
|
||||
beta = self.beta
|
||||
if self.alpha_logscale:
|
||||
alpha = mx.exp(alpha)
|
||||
beta = mx.exp(beta)
|
||||
return x + (1.0 / (beta + 1e-9)) * mx.power(mx.sin(x * alpha), 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling (Kaiser-sinc filters)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: mx.array) -> mx.array:
|
||||
return mx.where(
|
||||
x == 0,
|
||||
mx.ones_like(x),
|
||||
mx.sin(mx.array(math.pi) * x) / (mx.array(math.pi) * x),
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
|
||||
"""Compute a Kaiser-windowed sinc filter."""
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
delta_f = 4 * half_width
|
||||
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if amplitude > 50.0:
|
||||
beta = 0.1102 * (amplitude - 8.7)
|
||||
elif amplitude >= 21.0:
|
||||
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
|
||||
# Kaiser window - compute using scipy-compatible formula
|
||||
import numpy as np
|
||||
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
|
||||
|
||||
if even:
|
||||
time = mx.arange(-half_size, half_size).astype(mx.float32) + 0.5
|
||||
else:
|
||||
time = mx.arange(kernel_size).astype(mx.float32) - half_size
|
||||
|
||||
if cutoff == 0:
|
||||
filter_ = mx.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||
filter_ = filter_ / mx.sum(filter_)
|
||||
|
||||
return filter_.reshape(1, 1, kernel_size)
|
||||
|
||||
|
||||
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
|
||||
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
|
||||
import numpy as np
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
kernel_size = 2 * width * ratio + 1
|
||||
pad = width
|
||||
pad_left = 2 * width * ratio
|
||||
pad_right = kernel_size - ratio
|
||||
|
||||
time = (np.arange(kernel_size) / ratio - width) * rolloff
|
||||
time_clamped = np.clip(time, -lowpass_filter_width, lowpass_filter_width)
|
||||
window = np.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
sinc_filter = np.sinc(time) * window * rolloff / ratio
|
||||
|
||||
filter_ = mx.array(sinc_filter.astype(np.float32)).reshape(1, 1, kernel_size)
|
||||
return filter_, pad, pad_left, pad_right
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
"""Low-pass filter using depthwise convolution with Kaiser-sinc kernel."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cutoff: float = 0.5,
|
||||
half_width: float = 0.6,
|
||||
stride: int = 1,
|
||||
kernel_size: int = 12,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
# Filter buffer - shape (1, 1, K) in PyTorch format, loaded from weights
|
||||
self.filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
n, l, c = x.shape
|
||||
|
||||
# Pad with edge values: replicate first/last value
|
||||
first = mx.repeat(x[:, :1, :], self.pad_left, axis=1)
|
||||
last = mx.repeat(x[:, -1:, :], self.pad_right, axis=1)
|
||||
x = mx.concatenate([first, x, last], axis=1)
|
||||
|
||||
# Expand filter for depthwise conv: (1, 1, K) -> (C, K, 1) for groups=C
|
||||
# Filter is stored in PyTorch format (1, 1, K), need (C, K, 1) MLX format
|
||||
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
||||
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
||||
filt = mx.repeat(filt, c, axis=0) # (C, K, 1)
|
||||
|
||||
# Transpose x for depthwise conv: (N, L, C) -> (N*C, L, 1) then conv
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, C, L)
|
||||
x = x.reshape(n * c, -1, 1) # (N*C, L, 1)
|
||||
|
||||
x = mx.conv1d(x, filt[:1], stride=self.stride, groups=1) # (N*C, L', 1)
|
||||
|
||||
x = x.reshape(n, c, -1) # (N, C, L')
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
||||
return x
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
"""Anti-aliased upsampling using transposed convolution with sinc filter."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int = None,
|
||||
window_type: str = "kaiser",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
filt, self.pad, self.pad_left, self.pad_right = hann_sinc_filter1d(ratio)
|
||||
self.kernel_size = filt.shape[2]
|
||||
self.filter = filt
|
||||
else:
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
self.filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, L, C) in MLX format
|
||||
n, l, c = x.shape
|
||||
|
||||
# Pad with edge values
|
||||
first = mx.repeat(x[:, :1, :], self.pad, axis=1)
|
||||
last = mx.repeat(x[:, -1:, :], self.pad, axis=1)
|
||||
x = mx.concatenate([first, x, last], axis=1)
|
||||
|
||||
# Process per-channel via reshape: (N, L, C) -> (N*C, L, 1)
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, C, L)
|
||||
x = x.reshape(n * c, -1, 1) # (N*C, L, 1)
|
||||
|
||||
# Transposed conv for upsampling
|
||||
# Filter: (1, 1, K) PyTorch -> (1, K, 1) MLX format for conv_transpose1d
|
||||
filt = self.filter.astype(x.dtype) # (1, 1, K)
|
||||
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
|
||||
|
||||
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1)
|
||||
|
||||
# Trim padding
|
||||
x = x[:, self.pad_left:-self.pad_right, :]
|
||||
|
||||
x = x.reshape(n, c, -1) # (N, C, L')
|
||||
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
"""Anti-aliased downsampling using low-pass filter."""
|
||||
|
||||
def __init__(self, ratio: int = 2, kernel_size: int = None) -> None:
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=kernel_size,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
"""Anti-aliased activation: upsample -> activate -> downsample."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation: nn.Module,
|
||||
up_ratio: int = 2,
|
||||
down_ratio: int = 2,
|
||||
up_kernel_size: int = 12,
|
||||
down_kernel_size: int = 12,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
return self.downsample(x)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AMPBlock1 (BigVGAN v2 residual block)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AMPBlock1(nn.Module):
|
||||
"""BigVGAN v2 residual block with anti-aliased Snake activations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
dilation: Tuple[int, int, int] = (1, 3, 5),
|
||||
activation: str = "snakebeta",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
|
||||
self.convs1 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=d, padding=get_padding(kernel_size, d),
|
||||
)
|
||||
for i, d in enumerate(dilation)
|
||||
}
|
||||
|
||||
self.convs2 = {
|
||||
i: nn.Conv1d(
|
||||
channels, channels, kernel_size, stride=1,
|
||||
dilation=1, padding=get_padding(kernel_size, 1),
|
||||
)
|
||||
for i in range(len(dilation))
|
||||
}
|
||||
|
||||
self.acts1 = {i: Activation1d(act_cls(channels)) for i in range(len(dilation))}
|
||||
self.acts2 = {i: Activation1d(act_cls(channels)) for i in range(len(dilation))}
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
for i in range(len(self.convs1)):
|
||||
xt = self.acts1[i](x)
|
||||
xt = self.convs1[i](xt)
|
||||
xt = self.acts2[i](xt)
|
||||
xt = self.convs2[i](xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# STFT and MelSTFT (for BWE)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class STFTFn(nn.Module):
|
||||
"""STFT via conv1d with precomputed DFT x window bases (loaded from checkpoint)."""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int) -> None:
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
# Buffers loaded from checkpoint - PyTorch format (n_freqs*2, 1, filter_length)
|
||||
self.forward_basis = mx.zeros((n_freqs * 2, 1, filter_length))
|
||||
self.inverse_basis = mx.zeros((n_freqs * 2, 1, filter_length))
|
||||
|
||||
def __call__(self, y: mx.array) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute magnitude and phase from waveform.
|
||||
|
||||
Args:
|
||||
y: (B, T) waveform
|
||||
|
||||
Returns:
|
||||
magnitude: (B, n_freqs, T_frames)
|
||||
phase: (B, n_freqs, T_frames)
|
||||
"""
|
||||
if y.ndim == 2:
|
||||
y = mx.expand_dims(y, -1) # (B, T, 1)
|
||||
|
||||
left_pad = max(0, self.win_length - self.hop_length)
|
||||
if left_pad > 0:
|
||||
first = mx.repeat(y[:, :1, :], left_pad, axis=1)
|
||||
y = mx.concatenate([first, y], axis=1)
|
||||
|
||||
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
|
||||
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1)
|
||||
|
||||
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
|
||||
spec = mx.conv1d(y, basis, stride=self.hop_length)
|
||||
|
||||
# Split real and imaginary
|
||||
n_freqs = spec.shape[-1] // 2
|
||||
real = spec[..., :n_freqs]
|
||||
imag = spec[..., n_freqs:]
|
||||
|
||||
magnitude = mx.sqrt(real ** 2 + imag ** 2)
|
||||
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
|
||||
|
||||
# Output: (B, T_frames, n_freqs) in MLX channels-last
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram from precomputed STFT bases."""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
|
||||
super().__init__()
|
||||
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.mel_basis = mx.zeros((n_mel_channels, n_freqs))
|
||||
|
||||
def mel_spectrogram(self, y: mx.array) -> mx.array:
|
||||
"""Compute log-mel spectrogram.
|
||||
|
||||
Args:
|
||||
y: (B, T) waveform
|
||||
|
||||
Returns:
|
||||
log_mel: (B, n_mels, T_frames) in channels-first for compatibility
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
# magnitude: (B, T_frames, n_freqs)
|
||||
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels)
|
||||
log_mel = mx.log(mx.clip(mel, 1e-5, None))
|
||||
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
|
||||
return mx.transpose(log_mel, (0, 2, 1))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Vocoder (supports both HiFi-GAN and BigVGAN v2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Vocoder(nn.Module):
|
||||
"""Vocoder for mel-to-waveform synthesis.
|
||||
|
||||
Supports resblock="1" (HiFi-GAN / LTX-2) and resblock="AMP1" (BigVGAN v2 / LTX-2.3).
|
||||
"""
|
||||
|
||||
def __init__(self, config: VocoderModelConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.output_sampling_rate = config.output_sample_rate
|
||||
self.num_kernels = len(config.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(config.upsample_rates)
|
||||
self.upsample_rates = config.upsample_rates
|
||||
self.is_amp = config.resblock == "AMP1"
|
||||
self.use_tanh_at_final = config.use_tanh_at_final
|
||||
self.apply_final_activation = config.apply_final_activation
|
||||
|
||||
in_channels = 128 if config.stereo else 64
|
||||
self.conv_pre = nn.Conv1d(
|
||||
in_channels, config.upsample_initial_channel,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
)
|
||||
|
||||
# Upsampling layers
|
||||
self.ups = {}
|
||||
for i, (stride, kernel_size) in enumerate(
|
||||
zip(config.upsample_rates, config.upsample_kernel_sizes)
|
||||
):
|
||||
in_ch = config.upsample_initial_channel // (2 ** i)
|
||||
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups[i] = nn.ConvTranspose1d(
|
||||
in_ch, out_ch,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
|
||||
# Residual blocks
|
||||
if self.is_amp:
|
||||
self.resblocks = {}
|
||||
block_idx = 0
|
||||
for i in range(len(self.ups)):
|
||||
ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = AMPBlock1(
|
||||
ch, kernel_size, tuple(dilations),
|
||||
activation=config.activation,
|
||||
)
|
||||
block_idx += 1
|
||||
else:
|
||||
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
|
||||
self.resblocks = {}
|
||||
block_idx = 0
|
||||
for i in range(len(self.ups)):
|
||||
ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(
|
||||
config.resblock_kernel_sizes, config.resblock_dilation_sizes
|
||||
):
|
||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||
block_idx += 1
|
||||
|
||||
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
|
||||
|
||||
# Post-activation
|
||||
if self.is_amp:
|
||||
act_cls = SnakeBeta if config.activation == "snakebeta" else Snake
|
||||
self.act_post = Activation1d(act_cls(final_channels))
|
||||
|
||||
# Final conv
|
||||
out_channels = 2 if config.stereo else 1
|
||||
self.conv_post = nn.Conv1d(
|
||||
final_channels, out_channels,
|
||||
kernel_size=7, stride=1, padding=3,
|
||||
bias=config.use_bias_at_final,
|
||||
)
|
||||
|
||||
self.upsample_factor = math.prod(config.upsample_rates)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Mel spectrogram (B, C, T, mel_bins) for stereo or (B, T, mel_bins) mono.
|
||||
|
||||
Returns:
|
||||
Waveform (B, out_channels, T_audio) in channels-first format.
|
||||
"""
|
||||
# (B, C, T, mel) -> (B, C, mel, T)
|
||||
x = mx.transpose(x, (0, 1, 3, 2))
|
||||
|
||||
if x.ndim == 4: # stereo: (B, 2, mel, T) -> (B, 2*mel, T)
|
||||
b, s, c, t = x.shape
|
||||
x = x.reshape(b, s * c, t)
|
||||
|
||||
# Channels-first (B, C, T) -> channels-last (B, T, C) for MLX conv
|
||||
x = mx.transpose(x, (0, 2, 1))
|
||||
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
if not self.is_amp:
|
||||
x = leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
|
||||
start = i * self.num_kernels
|
||||
end = start + self.num_kernels
|
||||
|
||||
block_outputs = mx.stack(
|
||||
[self.resblocks[idx](x) for idx in range(start, end)],
|
||||
axis=0,
|
||||
)
|
||||
x = mx.mean(block_outputs, axis=0)
|
||||
|
||||
if self.is_amp:
|
||||
x = self.act_post(x)
|
||||
else:
|
||||
x = nn.leaky_relu(x)
|
||||
|
||||
x = self.conv_post(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
x = mx.tanh(x) if self.use_tanh_at_final else mx.clip(x, -1, 1)
|
||||
|
||||
# Back to channels-first (B, T, C) -> (B, C, T)
|
||||
x = mx.transpose(x, (0, 2, 1))
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# VocoderWithBWE (Bandwidth Extension)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VocoderWithBWE(nn.Module):
|
||||
"""Vocoder + bandwidth extension upsampling (16kHz -> 48kHz).
|
||||
|
||||
Chains a base vocoder with a BWE generator that predicts a residual
|
||||
added to a sinc-resampled skip connection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocoder: Vocoder,
|
||||
bwe_generator: Vocoder,
|
||||
mel_stft: MelSTFT,
|
||||
input_sampling_rate: int = 16000,
|
||||
output_sampling_rate: int = 48000,
|
||||
hop_length: int = 80,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vocoder = vocoder
|
||||
self.bwe_generator = bwe_generator
|
||||
self.mel_stft = mel_stft
|
||||
self.input_sampling_rate = input_sampling_rate
|
||||
self.output_sampling_rate = output_sampling_rate
|
||||
self.hop_length = hop_length
|
||||
# Hann-windowed sinc resampler (not stored in checkpoint)
|
||||
self.resampler = UpSample1d(
|
||||
ratio=output_sampling_rate // input_sampling_rate,
|
||||
window_type="hann",
|
||||
)
|
||||
|
||||
@property
|
||||
def output_sample_rate(self) -> int:
|
||||
return self.output_sampling_rate
|
||||
|
||||
def _compute_mel(self, audio: mx.array) -> mx.array:
|
||||
"""Compute log-mel spectrogram from waveform.
|
||||
|
||||
Args:
|
||||
audio: (B, C, T) waveform in channels-first
|
||||
|
||||
Returns:
|
||||
mel: (B, C, n_mels, T_frames)
|
||||
"""
|
||||
batch, n_channels, _ = audio.shape
|
||||
flat = audio.reshape(batch * n_channels, -1) # (B*C, T)
|
||||
mel = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
|
||||
return mel.reshape(batch, n_channels, mel.shape[1], mel.shape[2])
|
||||
|
||||
def __call__(self, mel_spec: mx.array) -> mx.array:
|
||||
"""Run vocoder + BWE.
|
||||
|
||||
Args:
|
||||
mel_spec: Mel spectrogram, same format as Vocoder.forward input.
|
||||
|
||||
Returns:
|
||||
Waveform (B, out_channels, T_audio) at output_sampling_rate.
|
||||
"""
|
||||
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
|
||||
_, _, length_low_rate = x.shape
|
||||
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
|
||||
|
||||
# Pad to hop_length multiple
|
||||
remainder = length_low_rate % self.hop_length
|
||||
if remainder != 0:
|
||||
pad_amount = self.hop_length - remainder
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_amount)])
|
||||
|
||||
# Compute mel from vocoder output: (B, C, n_mels, T_frames)
|
||||
mel = self._compute_mel(x)
|
||||
|
||||
# BWE expects (B, C, T_frames, mel_bins) -> transpose last two dims
|
||||
mel_for_bwe = mx.transpose(mel, (0, 1, 3, 2)) # (B, C, T_frames, n_mels)
|
||||
residual = self.bwe_generator(mel_for_bwe) # (B, C, T_high)
|
||||
|
||||
# Sinc upsample skip connection
|
||||
# resampler expects (N, L, C): transpose from (B, C, T) -> (B, T, C)
|
||||
x_for_resample = mx.transpose(x, (0, 2, 1))
|
||||
skip = self.resampler(x_for_resample)
|
||||
skip = mx.transpose(skip, (0, 2, 1)) # back to (B, C, T)
|
||||
|
||||
return mx.clip(residual + skip, -1, 1)[..., :output_length]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory / from_pretrained
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def load_vocoder(model_path: Path) -> nn.Module:
|
||||
"""Load vocoder from pretrained model directory.
|
||||
|
||||
Automatically detects whether to load a simple Vocoder or VocoderWithBWE.
|
||||
"""
|
||||
import json
|
||||
|
||||
config_path = model_path / "config.json"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"No config.json found in {model_path}")
|
||||
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
|
||||
has_bwe = config_dict.get("has_bwe_generator", False)
|
||||
|
||||
if has_bwe:
|
||||
return _load_vocoder_with_bwe(config_dict, weights)
|
||||
else:
|
||||
config = VocoderModelConfig.from_dict(config_dict)
|
||||
model = Vocoder(config)
|
||||
model.load_weights(list(weights.items()), strict=True)
|
||||
return model
|
||||
|
||||
|
||||
def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
|
||||
"""Load VocoderWithBWE from config and weights."""
|
||||
# Build vocoder from config
|
||||
vocoder_cfg = config_dict.get("vocoder", {})
|
||||
vocoder_config = VocoderModelConfig.from_dict(vocoder_cfg)
|
||||
vocoder = Vocoder(vocoder_config)
|
||||
|
||||
# Build BWE generator from config
|
||||
bwe_cfg = config_dict.get("bwe", {})
|
||||
bwe_config = VocoderModelConfig.from_dict(bwe_cfg)
|
||||
bwe_config.apply_final_activation = False
|
||||
bwe_generator = Vocoder(bwe_config)
|
||||
|
||||
# MelSTFT from weight shapes
|
||||
stft_basis = weights.get("mel_stft.stft_fn.forward_basis")
|
||||
filter_length = stft_basis.shape[2] if stft_basis is not None else 512
|
||||
mel_basis = weights.get("mel_stft.mel_basis")
|
||||
n_mel_channels = mel_basis.shape[0] if mel_basis is not None else 64
|
||||
|
||||
hop_length = bwe_cfg.get("hop_length", 80)
|
||||
input_sr = bwe_cfg.get("input_sampling_rate", 16000)
|
||||
output_sr = bwe_cfg.get("output_sampling_rate", 48000)
|
||||
|
||||
mel_stft = MelSTFT(
|
||||
filter_length=filter_length,
|
||||
hop_length=hop_length,
|
||||
win_length=filter_length,
|
||||
n_mel_channels=n_mel_channels,
|
||||
)
|
||||
|
||||
model = VocoderWithBWE(
|
||||
vocoder=vocoder,
|
||||
bwe_generator=bwe_generator,
|
||||
mel_stft=mel_stft,
|
||||
input_sampling_rate=input_sr,
|
||||
output_sampling_rate=output_sr,
|
||||
hop_length=hop_length,
|
||||
)
|
||||
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
3
mlx_video/models/ltx_2/conditioning/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Conditioning modules for LTX-2 video generation."""
|
||||
|
||||
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||
199
mlx_video/models/ltx_2/conditioning/latent.py
Normal file
199
mlx_video/models/ltx_2/conditioning/latent.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Latent-based conditioning for I2V (Image-to-Video) generation.
|
||||
|
||||
This module provides conditioning that injects encoded image latents into
|
||||
the video generation process at specific frame positions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoConditionByLatentIndex:
|
||||
"""Condition video generation by injecting latents at a specific frame index.
|
||||
|
||||
This replaces the latent at the specified frame index with the conditioned
|
||||
latent and controls how much denoising is applied via the strength parameter.
|
||||
|
||||
Args:
|
||||
latent: Encoded image latent of shape (B, C, 1, H, W)
|
||||
frame_idx: Frame index to condition (0 = first frame)
|
||||
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
|
||||
"""
|
||||
latent: mx.array
|
||||
frame_idx: int = 0
|
||||
strength: float = 1.0
|
||||
|
||||
def get_num_latent_frames(self) -> int:
|
||||
"""Get number of latent frames in the conditioning."""
|
||||
return self.latent.shape[2]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LatentState:
|
||||
"""State for latent diffusion with conditioning support.
|
||||
|
||||
Attributes:
|
||||
latent: Current noisy latent (B, C, F, H, W)
|
||||
clean_latent: Clean conditioning latent (B, C, F, H, W)
|
||||
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
|
||||
1.0 = full denoise, 0.0 = keep clean
|
||||
"""
|
||||
latent: mx.array
|
||||
clean_latent: mx.array
|
||||
denoise_mask: mx.array
|
||||
|
||||
def clone(self) -> "LatentState":
|
||||
"""Create a copy of the state."""
|
||||
return LatentState(
|
||||
latent=self.latent,
|
||||
clean_latent=self.clean_latent,
|
||||
denoise_mask=self.denoise_mask,
|
||||
)
|
||||
|
||||
|
||||
def create_initial_state(
|
||||
shape: Tuple[int, ...],
|
||||
seed: Optional[int] = None,
|
||||
noise_scale: float = 1.0,
|
||||
) -> LatentState:
|
||||
"""Create initial noisy latent state.
|
||||
|
||||
Args:
|
||||
shape: Shape of latent (B, C, F, H, W)
|
||||
seed: Optional random seed
|
||||
noise_scale: Scale for initial noise (sigma)
|
||||
|
||||
Returns:
|
||||
Initial LatentState with random noise
|
||||
"""
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
noise = mx.random.normal(shape)
|
||||
|
||||
return LatentState(
|
||||
latent=noise * noise_scale,
|
||||
clean_latent=mx.zeros(shape),
|
||||
denoise_mask=mx.ones((shape[0], 1, shape[2], 1, 1)), # Full denoise by default
|
||||
)
|
||||
|
||||
|
||||
def apply_conditioning(
|
||||
state: LatentState,
|
||||
conditionings: List[VideoConditionByLatentIndex],
|
||||
) -> LatentState:
|
||||
"""Apply conditioning items to a latent state.
|
||||
|
||||
Args:
|
||||
state: Current latent state
|
||||
conditionings: List of conditioning items to apply
|
||||
|
||||
Returns:
|
||||
Updated LatentState with conditioning applied
|
||||
"""
|
||||
state = state.clone()
|
||||
dtype = state.latent.dtype
|
||||
b, c, f, h, w = state.latent.shape
|
||||
|
||||
for cond in conditionings:
|
||||
cond_latent = cond.latent
|
||||
frame_idx = cond.frame_idx
|
||||
strength = cond.strength
|
||||
|
||||
# Validate shapes
|
||||
_, cond_c, cond_f, cond_h, cond_w = cond_latent.shape
|
||||
if (cond_c, cond_h, cond_w) != (c, h, w):
|
||||
raise ValueError(
|
||||
f"Conditioning latent spatial shape ({cond_c}, {cond_h}, {cond_w}) "
|
||||
f"does not match target shape ({c}, {h}, {w})"
|
||||
)
|
||||
|
||||
if frame_idx >= f:
|
||||
raise ValueError(
|
||||
f"Frame index {frame_idx} is out of bounds for latent with {f} frames"
|
||||
)
|
||||
|
||||
# Get the conditioning frames count
|
||||
num_cond_frames = cond_f
|
||||
end_idx = min(frame_idx + num_cond_frames, f)
|
||||
|
||||
# Replace latent at conditioning position
|
||||
# state.latent[:, :, frame_idx:end_idx] = cond_latent[:, :, :end_idx - frame_idx]
|
||||
latent_list = []
|
||||
clean_list = []
|
||||
mask_list = []
|
||||
|
||||
for i in range(f):
|
||||
if frame_idx <= i < end_idx:
|
||||
# Use conditioning latent
|
||||
cond_idx = i - frame_idx
|
||||
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
||||
else:
|
||||
# Keep original
|
||||
latent_list.append(state.latent[:, :, i:i+1])
|
||||
clean_list.append(state.clean_latent[:, :, i:i+1])
|
||||
mask_list.append(state.denoise_mask[:, :, i:i+1])
|
||||
|
||||
state.latent = mx.concatenate(latent_list, axis=2)
|
||||
state.clean_latent = mx.concatenate(clean_list, axis=2)
|
||||
state.denoise_mask = mx.concatenate(mask_list, axis=2)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def apply_denoise_mask(
|
||||
denoised: mx.array,
|
||||
clean: mx.array,
|
||||
denoise_mask: mx.array,
|
||||
) -> mx.array:
|
||||
"""Blend denoised output with clean state based on mask.
|
||||
|
||||
Args:
|
||||
denoised: Denoised latent (B, C, F, H, W)
|
||||
clean: Clean conditioning latent (B, C, F, H, W)
|
||||
denoise_mask: Mask where 1.0 = use denoised, 0.0 = use clean
|
||||
|
||||
Returns:
|
||||
Blended latent
|
||||
"""
|
||||
one = mx.array(1.0, dtype=denoised.dtype)
|
||||
return denoised * denoise_mask + clean * (one - denoise_mask)
|
||||
|
||||
|
||||
def add_noise_with_state(
|
||||
state: LatentState,
|
||||
noise_scale: float,
|
||||
) -> LatentState:
|
||||
"""Add noise to state while respecting conditioning.
|
||||
|
||||
For conditioned frames (mask < 1.0), adds noise proportionally
|
||||
to allow some refinement while preserving the conditioning.
|
||||
|
||||
Args:
|
||||
state: Current latent state
|
||||
noise_scale: Scale for noise (sigma)
|
||||
|
||||
Returns:
|
||||
Updated state with noise added
|
||||
"""
|
||||
state = state.clone()
|
||||
|
||||
# Generate noise
|
||||
noise = mx.random.normal(state.latent.shape)
|
||||
|
||||
# For fully conditioned frames (mask=0), we want to add minimal noise
|
||||
# For unconditioned frames (mask=1), we want full noise
|
||||
# noisy = noise * sigma + latent * (1 - sigma)
|
||||
# But we scale sigma by the mask for conditioned regions
|
||||
|
||||
effective_scale = noise_scale * state.denoise_mask
|
||||
one = mx.array(1.0, dtype=state.latent.dtype)
|
||||
state.latent = noise * effective_scale + state.latent * (one - effective_scale)
|
||||
|
||||
return state
|
||||
380
mlx_video/models/ltx_2/config.py
Normal file
380
mlx_video/models/ltx_2/config.py
Normal file
@@ -0,0 +1,380 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
|
||||
class LTXModelType(Enum):
|
||||
AudioVideo = "ltx av model"
|
||||
VideoOnly = "ltx video only model"
|
||||
AudioOnly = "ltx audio only model"
|
||||
|
||||
def is_video_enabled(self) -> bool:
|
||||
return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly)
|
||||
|
||||
def is_audio_enabled(self) -> bool:
|
||||
return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly)
|
||||
|
||||
|
||||
class LTXRopeType(Enum):
|
||||
INTERLEAVED = "interleaved"
|
||||
SPLIT = "split"
|
||||
TWO_D = "2d"
|
||||
|
||||
class AttentionType(Enum):
|
||||
DEFAULT = "default"
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig":
|
||||
"""Create config from dictionary, filtering only valid parameters."""
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k in inspect.signature(cls).parameters
|
||||
}
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Export config to dictionary."""
|
||||
result = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if v is not None:
|
||||
if isinstance(v, Enum):
|
||||
result[k] = v.value
|
||||
elif hasattr(v, 'to_dict'):
|
||||
result[k] = v.to_dict()
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerConfig(BaseModelConfig):
|
||||
dim: int
|
||||
heads: int
|
||||
d_head: int
|
||||
context_dim: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoVAEConfig(BaseModelConfig):
|
||||
convolution_dimensions: int = 3
|
||||
in_channels: int = 3
|
||||
out_channels: int = 128
|
||||
latent_channels: int = 128
|
||||
patch_size: int = 4
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
])
|
||||
decoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
])
|
||||
|
||||
|
||||
@dataclass
|
||||
class LTXModelConfig(BaseModelConfig):
|
||||
|
||||
# Model type
|
||||
model_type: LTXModelType = LTXModelType.AudioVideo
|
||||
|
||||
# Video transformer config
|
||||
num_attention_heads: int = 32
|
||||
attention_head_dim: int = 128
|
||||
in_channels: int = 128
|
||||
out_channels: int = 128
|
||||
num_layers: int = 48
|
||||
cross_attention_dim: int = 4096
|
||||
caption_channels: int = 3840
|
||||
|
||||
# Audio transformer config
|
||||
audio_num_attention_heads: int = 32
|
||||
audio_attention_head_dim: int = 64
|
||||
audio_in_channels: int = 128
|
||||
audio_out_channels: int = 128
|
||||
audio_cross_attention_dim: int = 2048
|
||||
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video)
|
||||
|
||||
# Positional embedding config
|
||||
positional_embedding_theta: float = 10000.0
|
||||
positional_embedding_max_pos: Optional[List[int]] = None
|
||||
audio_positional_embedding_max_pos: Optional[List[int]] = None
|
||||
use_middle_indices_grid: bool = True
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED
|
||||
double_precision_rope: bool = False
|
||||
|
||||
# Timestep config
|
||||
timestep_scale_multiplier: int = 1000
|
||||
av_ca_timestep_scale_multiplier: int = 1000
|
||||
|
||||
# Normalization
|
||||
norm_eps: float = 1e-6
|
||||
|
||||
# Attention type
|
||||
attention_type: AttentionType = AttentionType.DEFAULT
|
||||
|
||||
# LTX-2.3: prompt-conditioned adaptive layer norm
|
||||
# Controls: gate_logits in attention, 9-param scale_shift_table,
|
||||
# prompt_adaln_single, per-block prompt_scale_shift_table,
|
||||
# removal of caption_projection
|
||||
has_prompt_adaln: bool = False
|
||||
|
||||
# VAE config
|
||||
vae_config: Optional[VideoVAEConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default values after initialization."""
|
||||
if self.positional_embedding_max_pos is None:
|
||||
self.positional_embedding_max_pos = [20, 2048, 2048]
|
||||
if self.audio_positional_embedding_max_pos is None:
|
||||
self.audio_positional_embedding_max_pos = [20]
|
||||
|
||||
# PyTorch LTX-2 configurator reads "frequencies_precision" (not
|
||||
# "double_precision_rope") from the config. For LTX-2 (no prompt adaln)
|
||||
# the key is absent, so double_precision_rope = False. For LTX-2.3
|
||||
# (has_prompt_adaln=True) the safetensors config has
|
||||
# frequencies_precision="float64", so double_precision_rope = True.
|
||||
if not self.has_prompt_adaln:
|
||||
self.double_precision_rope = False
|
||||
|
||||
# Convert string enum values if loading from dict
|
||||
if isinstance(self.model_type, str):
|
||||
self.model_type = LTXModelType(self.model_type)
|
||||
if isinstance(self.rope_type, str):
|
||||
self.rope_type = LTXRopeType(self.rope_type)
|
||||
if isinstance(self.attention_type, str):
|
||||
self.attention_type = AttentionType(self.attention_type)
|
||||
|
||||
@property
|
||||
def inner_dim(self) -> int:
|
||||
"""Video inner dimension."""
|
||||
return self.num_attention_heads * self.attention_head_dim
|
||||
|
||||
@property
|
||||
def audio_inner_dim(self) -> int:
|
||||
"""Audio inner dimension."""
|
||||
return self.audio_num_attention_heads * self.audio_attention_head_dim
|
||||
|
||||
def get_video_config(self) -> Optional[TransformerConfig]:
|
||||
"""Get video transformer configuration."""
|
||||
if not self.model_type.is_video_enabled():
|
||||
return None
|
||||
return TransformerConfig(
|
||||
dim=self.inner_dim,
|
||||
heads=self.num_attention_heads,
|
||||
d_head=self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
)
|
||||
|
||||
def get_audio_config(self) -> Optional[TransformerConfig]:
|
||||
"""Get audio transformer configuration."""
|
||||
if not self.model_type.is_audio_enabled():
|
||||
return None
|
||||
return TransformerConfig(
|
||||
dim=self.audio_inner_dim,
|
||||
heads=self.audio_num_attention_heads,
|
||||
d_head=self.audio_attention_head_dim,
|
||||
context_dim=self.audio_cross_attention_dim,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class CausalityAxis(Enum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
NONE = None
|
||||
WIDTH = "width"
|
||||
HEIGHT = "height"
|
||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioDecoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
out_ch: int = 2
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4)
|
||||
num_res_blocks: int = 2
|
||||
attn_resolutions: Optional[List[int]] = None
|
||||
resolution: int = 256
|
||||
z_channels: int = 8
|
||||
norm_type: Enum = None
|
||||
causality_axis: Enum = None
|
||||
dropout: float = 0.0
|
||||
mid_block_add_attention: bool = True
|
||||
sample_rate: int = 16000
|
||||
mel_hop_length: int = 160
|
||||
is_causal: bool = True
|
||||
mel_bins: int | None = None
|
||||
resamp_with_conv: bool = True
|
||||
attn_type: str = None
|
||||
give_pre_end: bool = False
|
||||
tanh_out: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if self.attn_resolutions is not None:
|
||||
result["attn_resolutions"] = list(self.attn_resolutions)
|
||||
return result
|
||||
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
# Import here to avoid circular imports
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
|
||||
# Convert causality_axis string to enum
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
|
||||
# Convert norm_type string to enum
|
||||
if isinstance(self.norm_type, str):
|
||||
self.norm_type = NormType(self.norm_type)
|
||||
|
||||
# Convert attn_type string to enum
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
@dataclass
|
||||
class AudioEncoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
in_channels: int = 2
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4)
|
||||
num_res_blocks: int = 2
|
||||
attn_resolutions: Optional[List[int]] = None
|
||||
resolution: int = 256
|
||||
z_channels: int = 8
|
||||
double_z: bool = True
|
||||
n_fft: int = 1024
|
||||
norm_type: Enum = None
|
||||
causality_axis: Enum = None
|
||||
dropout: float = 0.0
|
||||
mid_block_add_attention: bool = True
|
||||
sample_rate: int = 16000
|
||||
mel_hop_length: int = 160
|
||||
is_causal: bool = True
|
||||
mel_bins: int = 64
|
||||
resamp_with_conv: bool = True
|
||||
attn_type: str = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if self.attn_resolutions is not None:
|
||||
result["attn_resolutions"] = list(self.attn_resolutions)
|
||||
return result
|
||||
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
if isinstance(self.norm_type, str):
|
||||
self.norm_type = NormType(self.norm_type)
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocoderModelConfig(BaseModelConfig):
|
||||
resblock_kernel_sizes: Optional[List[int]] = None
|
||||
upsample_rates: Optional[List[int]] = None
|
||||
upsample_kernel_sizes: Optional[List[int]] = None
|
||||
resblock_dilation_sizes: Optional[List[List[int]]] = None
|
||||
upsample_initial_channel: int = 1024
|
||||
stereo: bool = True
|
||||
resblock: str = "1"
|
||||
output_sample_rate: int = 24000
|
||||
activation: str = "snake"
|
||||
use_tanh_at_final: bool = True
|
||||
apply_final_activation: bool = True
|
||||
use_bias_at_final: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
if self.resblock_kernel_sizes is None:
|
||||
self.resblock_kernel_sizes = [3, 7, 11]
|
||||
if self.upsample_rates is None:
|
||||
self.upsample_rates = [6, 5, 2, 2, 2]
|
||||
if self.upsample_kernel_sizes is None:
|
||||
self.upsample_kernel_sizes = [16, 15, 8, 4, 4]
|
||||
if self.resblock_dilation_sizes is None:
|
||||
self.resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoDecoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
out_ch: int = 2
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4)
|
||||
num_res_blocks: int = 2
|
||||
attn_resolutions: Optional[List[int]] = None
|
||||
resolution: int = 256
|
||||
z_channels: int = 8
|
||||
norm_type: Enum = None
|
||||
causality_axis: Enum = None
|
||||
dropout: float = 0.0
|
||||
timestep_conditioning: bool = False
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderModelConfig(BaseModelConfig):
|
||||
convolution_dimensions: int = 3
|
||||
in_channels: int = 3
|
||||
out_channels: int = 128
|
||||
patch_size: int = 4
|
||||
norm_layer: Enum = None
|
||||
latent_log_var: Enum = None
|
||||
encoder_spatial_padding_mode: Enum = None
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2})
|
||||
])
|
||||
|
||||
def __post_init__(self):
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
|
||||
if self.norm_layer is None:
|
||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||
if self.latent_log_var is None:
|
||||
self.latent_log_var = LogVarianceType.UNIFORM
|
||||
if self.encoder_spatial_padding_mode is None:
|
||||
self.encoder_spatial_padding_mode = PaddingModeType.ZEROS
|
||||
|
||||
if isinstance(self.norm_layer, str):
|
||||
self.norm_layer = NormLayerType(self.norm_layer)
|
||||
if isinstance(self.latent_log_var, str):
|
||||
self.latent_log_var = LogVarianceType(self.latent_log_var)
|
||||
if isinstance(self.encoder_spatial_padding_mode, str):
|
||||
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if self.encoder_blocks is not None:
|
||||
result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
|
||||
return result
|
||||
785
mlx_video/models/ltx_2/convert.py
Normal file
785
mlx_video/models/ltx_2/convert.py
Normal file
@@ -0,0 +1,785 @@
|
||||
"""Convert LTX-2/2.3 safetensors to MLX directory layout.
|
||||
|
||||
Converts from the single-file format (e.g. Lightricks/LTX-2/ltx-2-19b-distilled.safetensors
|
||||
or Lightricks/LTX-2.3/ltx-2.3-22b-distilled.safetensors) to the modular directory structure:
|
||||
|
||||
output/
|
||||
├── transformer/ # DiT transformer weights (sharded)
|
||||
│ ├── config.json
|
||||
│ ├── model-00001-of-N.safetensors
|
||||
│ └── model.safetensors.index.json
|
||||
├── vae/
|
||||
│ ├── decoder/ # Video VAE decoder
|
||||
│ │ ├── config.json
|
||||
│ │ └── model.safetensors
|
||||
│ └── encoder/ # Video VAE encoder
|
||||
│ ├── config.json
|
||||
│ └── model.safetensors
|
||||
├── audio_vae/ # Audio VAE decoder
|
||||
│ ├── config.json
|
||||
│ └── model.safetensors
|
||||
├── vocoder/ # Audio vocoder
|
||||
│ ├── config.json
|
||||
│ └── model.safetensors
|
||||
└── text_projections/ # Text projection connectors
|
||||
└── model.safetensors
|
||||
|
||||
Usage:
|
||||
# From HF repo ID
|
||||
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2 --output LTX-2-distilled --variant distilled
|
||||
python -m mlx_video.models.ltx_2.convert --source Lightricks/LTX-2.3 --output LTX-2.3-distilled --variant distilled
|
||||
|
||||
# From local folder containing the monolithic safetensors
|
||||
python -m mlx_video.models.ltx_2.convert --source ./Lightricks-LTX-2/ --output LTX-2-distilled --variant distilled
|
||||
|
||||
# From a direct safetensors file path
|
||||
python -m mlx_video.models.ltx_2.convert --source ./ltx-2-19b-distilled.safetensors --output LTX-2-distilled --variant distilled
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ─── Key prefix routing ──────────────────────────────────────────────────────
|
||||
|
||||
TRANSFORMER_PREFIX = "model.diffusion_model."
|
||||
VAE_DECODER_PREFIX = "vae.decoder."
|
||||
VAE_ENCODER_PREFIX = "vae.encoder."
|
||||
VAE_STATS_PREFIX = "vae.per_channel_statistics."
|
||||
AUDIO_DECODER_PREFIX = "audio_vae.decoder."
|
||||
AUDIO_ENCODER_PREFIX = "audio_vae.encoder."
|
||||
AUDIO_STATS_PREFIX = "audio_vae.per_channel_statistics."
|
||||
VOCODER_PREFIX = "vocoder."
|
||||
TEXT_PROJ_PREFIX = "text_embedding_projection."
|
||||
VIDEO_CONNECTOR_PREFIX = "model.diffusion_model.video_embeddings_connector."
|
||||
AUDIO_CONNECTOR_PREFIX = "model.diffusion_model.audio_embeddings_connector."
|
||||
|
||||
|
||||
# ─── Sanitization functions ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize transformer keys: strip prefix, rename layers, cast to bfloat16."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(TRANSFORMER_PREFIX):
|
||||
continue
|
||||
# Skip connector weights (they go to text_projections)
|
||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
continue
|
||||
|
||||
new_key = key[len(TRANSFORMER_PREFIX):]
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
# Cast all weights to bfloat16 (matches MLX model loading behavior)
|
||||
if value.dtype != mx.bfloat16:
|
||||
value = value.astype(mx.bfloat16)
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize VAE decoder keys: strip prefix, transpose Conv3d, wrap .conv."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = None
|
||||
|
||||
if key.startswith(VAE_STATS_PREFIX):
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
continue
|
||||
elif key.startswith(VAE_DECODER_PREFIX):
|
||||
new_key = key[len(VAE_DECODER_PREFIX):]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Conv3d weight transpose: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Wrap .conv.weight -> .conv.conv.weight (CausalConv3d wrapper)
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize VAE encoder keys: strip prefix, transpose Conv3d/Conv2d."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = None
|
||||
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
if key.startswith(VAE_STATS_PREFIX):
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
continue
|
||||
# Per-channel statistics must stay float32 for precision
|
||||
if value.dtype != mx.float32:
|
||||
value = value.astype(mx.float32)
|
||||
elif key.startswith(VAE_ENCODER_PREFIX):
|
||||
new_key = key[len(VAE_ENCODER_PREFIX):]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio VAE decoder keys: strip prefix, transpose Conv2d."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = None
|
||||
|
||||
if key.startswith(AUDIO_DECODER_PREFIX):
|
||||
new_key = key[len(AUDIO_DECODER_PREFIX):]
|
||||
elif key.startswith(AUDIO_STATS_PREFIX):
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize vocoder keys: strip prefix, transpose Conv1d/ConvTranspose1d."""
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(VOCODER_PREFIX):
|
||||
continue
|
||||
|
||||
new_key = key[len(VOCODER_PREFIX):]
|
||||
|
||||
# Handle Conv1d/ConvTranspose1d weight shape conversion
|
||||
if "weight" in new_key and value.ndim == 3:
|
||||
if "ups" in new_key:
|
||||
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||
value = mx.transpose(value, (1, 2, 0))
|
||||
else:
|
||||
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||
value = mx.transpose(value, (0, 2, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def sanitize_connector_key(key: str) -> str:
|
||||
"""Sanitize connector sub-key names."""
|
||||
key = key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
key = key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
key = key.replace(".to_out.0.", ".to_out.")
|
||||
return key
|
||||
|
||||
|
||||
def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Extract text projection weights (aggregate_embed + connectors).
|
||||
|
||||
Handles both LTX-2 (aggregate_embed.weight) and LTX-2.3
|
||||
(video_aggregate_embed.*, audio_aggregate_embed.*) formats.
|
||||
"""
|
||||
extracted = {}
|
||||
|
||||
# aggregate_embed weights (text_embedding_projection.*)
|
||||
for key, value in weights.items():
|
||||
if key.startswith(TEXT_PROJ_PREFIX):
|
||||
new_key = key[len(TEXT_PROJ_PREFIX):]
|
||||
extracted[new_key] = value
|
||||
|
||||
# video_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
if key.startswith(VIDEO_CONNECTOR_PREFIX):
|
||||
suffix = key[len(VIDEO_CONNECTOR_PREFIX):]
|
||||
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
|
||||
extracted[new_key] = value
|
||||
|
||||
# audio_embeddings_connector
|
||||
for key, value in weights.items():
|
||||
if key.startswith(AUDIO_CONNECTOR_PREFIX):
|
||||
suffix = key[len(AUDIO_CONNECTOR_PREFIX):]
|
||||
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
|
||||
extracted[new_key] = value
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
# ─── Saving utilities ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def save_sharded(
|
||||
weights: Dict[str, mx.array],
|
||||
output_dir: Path,
|
||||
max_shard_size_bytes: int = 5 * 1024 * 1024 * 1024, # 5GB per shard
|
||||
):
|
||||
"""Save weights as sharded safetensors with an index file."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Sort keys for deterministic output
|
||||
sorted_keys = sorted(weights.keys())
|
||||
|
||||
# Calculate total size
|
||||
total_size = sum(weights[k].nbytes for k in sorted_keys)
|
||||
|
||||
# Determine sharding
|
||||
shards = []
|
||||
current_shard = {}
|
||||
current_size = 0
|
||||
|
||||
for key in sorted_keys:
|
||||
tensor = weights[key]
|
||||
tensor_size = tensor.nbytes
|
||||
|
||||
if current_size + tensor_size > max_shard_size_bytes and current_shard:
|
||||
shards.append(current_shard)
|
||||
current_shard = {}
|
||||
current_size = 0
|
||||
|
||||
current_shard[key] = tensor
|
||||
current_size += tensor_size
|
||||
|
||||
if current_shard:
|
||||
shards.append(current_shard)
|
||||
|
||||
num_shards = len(shards)
|
||||
weight_map = {}
|
||||
|
||||
for i, shard in enumerate(shards):
|
||||
if num_shards == 1:
|
||||
filename = "model.safetensors"
|
||||
else:
|
||||
filename = f"model-{i+1:05d}-of-{num_shards:05d}.safetensors"
|
||||
|
||||
mx.save_safetensors(str(output_dir / filename), shard)
|
||||
|
||||
for key in shard:
|
||||
weight_map[key] = filename
|
||||
|
||||
# Write index
|
||||
index = {
|
||||
"metadata": {"total_size": total_size},
|
||||
"weight_map": weight_map,
|
||||
}
|
||||
with open(output_dir / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
|
||||
return num_shards
|
||||
|
||||
|
||||
def save_single(weights: Dict[str, mx.array], output_dir: Path):
|
||||
"""Save weights as a single safetensors file with an index."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
mx.save_safetensors(str(output_dir / "model.safetensors"), weights)
|
||||
|
||||
# Also write index for consistency
|
||||
total_size = sum(v.nbytes for v in weights.values())
|
||||
weight_map = {k: "model.safetensors" for k in sorted(weights.keys())}
|
||||
index = {
|
||||
"metadata": {"total_size": total_size},
|
||||
"weight_map": weight_map,
|
||||
}
|
||||
with open(output_dir / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
|
||||
|
||||
def save_config(config: dict, output_dir: Path):
|
||||
"""Save config.json to a directory."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / "config.json", "w") as f:
|
||||
json.dump(config, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
# ─── Source resolution ─────────────────────────────────────────────────────────
|
||||
|
||||
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
|
||||
MONOLITHIC_PATTERN = re.compile(r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$")
|
||||
|
||||
# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
|
||||
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
|
||||
UPSCALER_PATTERN = re.compile(r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$")
|
||||
|
||||
|
||||
def resolve_source(source: str, variant: str) -> Path:
|
||||
"""Resolve source to a monolithic safetensors file path.
|
||||
|
||||
Args:
|
||||
source: HF repo ID (e.g. "Lightricks/LTX-2"), local directory, or direct file path.
|
||||
variant: Model variant ("distilled" or "dev") to select the right file.
|
||||
|
||||
Returns:
|
||||
Path to the monolithic safetensors file.
|
||||
"""
|
||||
source_path = Path(source)
|
||||
|
||||
# Direct file path
|
||||
if source_path.is_file():
|
||||
return source_path
|
||||
|
||||
# Local directory — find the variant's safetensors file
|
||||
if source_path.is_dir():
|
||||
matches = []
|
||||
for f in sorted(source_path.glob("ltx-*b-*.safetensors")):
|
||||
m = MONOLITHIC_PATTERN.match(f.name)
|
||||
if m and m.group("variant") == variant:
|
||||
matches.append(f)
|
||||
|
||||
if matches:
|
||||
return matches[0]
|
||||
|
||||
# Broader fallback
|
||||
all_mono = sorted(source_path.glob("ltx-*.safetensors"))
|
||||
for f in all_mono:
|
||||
if variant in f.name and MONOLITHIC_PATTERN.match(f.name):
|
||||
return f
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"No monolithic *-{variant}.safetensors found in {source_path}. "
|
||||
f"Files found: {[f.name for f in all_mono]}"
|
||||
)
|
||||
|
||||
# HF repo ID — download via huggingface_hub
|
||||
if "/" in source and not source_path.exists():
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
|
||||
# Find the right file in the repo
|
||||
repo_files = list_repo_files(source)
|
||||
target = None
|
||||
for f in repo_files:
|
||||
m = MONOLITHIC_PATTERN.match(f)
|
||||
if m and m.group("variant") == variant:
|
||||
target = f
|
||||
break
|
||||
|
||||
if not target:
|
||||
raise FileNotFoundError(
|
||||
f"No *-{variant}.safetensors found in {source}. "
|
||||
f"Available: {[f for f in repo_files if f.endswith('.safetensors')]}"
|
||||
)
|
||||
|
||||
print(f"Downloading {target} from {source}...")
|
||||
local_path = hf_hub_download(repo_id=source, filename=target)
|
||||
return Path(local_path)
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Source not found: {source}. Provide an HF repo ID, local directory, or file path."
|
||||
)
|
||||
|
||||
|
||||
# ─── Config inference ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Infer transformer config from weight shapes."""
|
||||
# Count transformer layers
|
||||
max_layer = -1
|
||||
for key in weights:
|
||||
if "transformer_blocks." in key:
|
||||
parts = key.split(".")
|
||||
try:
|
||||
idx = parts.index("transformer_blocks") + 1
|
||||
if idx < len(parts) and parts[idx].isdigit():
|
||||
max_layer = max(max_layer, int(parts[idx]))
|
||||
except ValueError:
|
||||
pass
|
||||
num_layers = max_layer + 1 if max_layer >= 0 else 48
|
||||
|
||||
# Detect cross_attention_dim from attn2.to_k (cross-attention input dim)
|
||||
cross_attention_dim = 4096
|
||||
for key, value in weights.items():
|
||||
if "transformer_blocks.0.attn2.to_k.weight" in key:
|
||||
cross_attention_dim = value.shape[-1]
|
||||
break
|
||||
|
||||
# Check for prompt_adaln_single (LTX-2.3 feature)
|
||||
has_prompt_adaln = any("prompt_adaln_single" in k for k in weights)
|
||||
|
||||
config = {
|
||||
"attention_head_dim": 128,
|
||||
"attention_type": "default",
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_caption_channels": 3840,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_in_channels": 128,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_out_channels": 128,
|
||||
"audio_positional_embedding_max_pos": [20],
|
||||
"av_ca_timestep_scale_multiplier": 1000,
|
||||
"caption_channels": 3840,
|
||||
"cross_attention_dim": cross_attention_dim,
|
||||
"double_precision_rope": True,
|
||||
"in_channels": 128,
|
||||
"model_type": "ltx av model",
|
||||
"norm_eps": 1e-06,
|
||||
"num_attention_heads": 32,
|
||||
"num_layers": num_layers,
|
||||
"out_channels": 128,
|
||||
"positional_embedding_max_pos": [20, 2048, 2048],
|
||||
"positional_embedding_theta": 10000.0,
|
||||
"rope_type": "split",
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"use_middle_indices_grid": True,
|
||||
}
|
||||
|
||||
if has_prompt_adaln:
|
||||
config["has_prompt_adaln"] = True
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
|
||||
"""Infer VAE decoder config from weights."""
|
||||
# Check for timestep conditioning keys
|
||||
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights)
|
||||
|
||||
# Count channel multipliers from up_blocks
|
||||
max_block = -1
|
||||
for key in weights:
|
||||
if "up_blocks." in key:
|
||||
parts = key.split(".")
|
||||
try:
|
||||
idx = parts.index("up_blocks") + 1
|
||||
if idx < len(parts) and parts[idx].isdigit():
|
||||
max_block = max(max_block, int(parts[idx]))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Default config
|
||||
config = {
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"dropout": 0.0,
|
||||
"num_res_blocks": 2,
|
||||
"out_ch": 2,
|
||||
"resolution": 256,
|
||||
"timestep_conditioning": has_timestep,
|
||||
"z_channels": 8,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def infer_vae_encoder_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Return VAE encoder config (architecture is consistent across versions)."""
|
||||
return {
|
||||
"convolution_dimensions": 3,
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_space_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_time_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 6}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
["compress_all_res", {"multiplier": 2}],
|
||||
["res_x", {"num_layers": 2}],
|
||||
],
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"in_channels": 3,
|
||||
"latent_log_var": "uniform",
|
||||
"norm_layer": "pixel_norm",
|
||||
"out_channels": 128,
|
||||
"patch_size": 4,
|
||||
}
|
||||
|
||||
|
||||
def infer_audio_vae_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Return audio VAE config."""
|
||||
return {
|
||||
"attn_resolutions": [],
|
||||
"attn_type": "vanilla",
|
||||
"causality_axis": "height",
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"dropout": 0.0,
|
||||
"give_pre_end": False,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"mel_hop_length": 160,
|
||||
"mid_block_add_attention": False,
|
||||
"norm_type": "pixel",
|
||||
"num_res_blocks": 2,
|
||||
"out_ch": 2,
|
||||
"resamp_with_conv": True,
|
||||
"resolution": 256,
|
||||
"sample_rate": 16000,
|
||||
"tanh_out": False,
|
||||
"z_channels": 8,
|
||||
}
|
||||
|
||||
|
||||
def infer_vocoder_config(weights: Dict[str, mx.array]) -> dict:
|
||||
"""Infer vocoder config from weights."""
|
||||
# Check for bwe_generator (LTX-2.3 BigVGAN vocoder)
|
||||
has_bwe = any(k.startswith("bwe_generator") for k in weights)
|
||||
|
||||
if has_bwe:
|
||||
return {
|
||||
"type": "bigvgan",
|
||||
"has_bwe_generator": True,
|
||||
}
|
||||
|
||||
return {
|
||||
"output_sample_rate": 24000,
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"stereo": True,
|
||||
"upsample_initial_channel": 1024,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
}
|
||||
|
||||
|
||||
# ─── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def convert(source: str, output_path: Path, variant: str = "distilled"):
|
||||
"""Convert monolithic safetensors to modular directory layout.
|
||||
|
||||
Args:
|
||||
source: HF repo ID (e.g. "Lightricks/LTX-2"), local directory, or file path.
|
||||
output_path: Output directory for the modular layout.
|
||||
variant: "distilled" or "dev".
|
||||
"""
|
||||
source_path = resolve_source(source, variant)
|
||||
|
||||
print(f"Loading monolithic weights from {source_path.name}...")
|
||||
all_weights = mx.load(str(source_path))
|
||||
total_keys = len(all_weights)
|
||||
print(f" Loaded {total_keys} keys")
|
||||
|
||||
# Route keys to components
|
||||
print("\nExtracting components...")
|
||||
|
||||
# 1. Transformer
|
||||
print(" [1/6] Transformer...")
|
||||
transformer_weights = sanitize_transformer(all_weights)
|
||||
num_shards = save_sharded(transformer_weights, output_path / "transformer")
|
||||
config = infer_transformer_config(transformer_weights)
|
||||
save_config(config, output_path / "transformer")
|
||||
t_params = sum(v.size for v in transformer_weights.values())
|
||||
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
|
||||
|
||||
# 2. VAE Decoder
|
||||
print(" [2/6] VAE Decoder...")
|
||||
vae_decoder_weights = sanitize_vae_decoder(all_weights)
|
||||
save_single(vae_decoder_weights, output_path / "vae" / "decoder")
|
||||
config = infer_vae_decoder_config(vae_decoder_weights, variant)
|
||||
save_config(config, output_path / "vae" / "decoder")
|
||||
d_params = sum(v.size for v in vae_decoder_weights.values())
|
||||
print(f" {len(vae_decoder_weights)} keys, {d_params:,} params")
|
||||
|
||||
# 3. VAE Encoder
|
||||
print(" [3/6] VAE Encoder...")
|
||||
vae_encoder_weights = sanitize_vae_encoder(all_weights)
|
||||
save_single(vae_encoder_weights, output_path / "vae" / "encoder")
|
||||
config = infer_vae_encoder_config(vae_encoder_weights)
|
||||
save_config(config, output_path / "vae" / "encoder")
|
||||
e_params = sum(v.size for v in vae_encoder_weights.values())
|
||||
print(f" {len(vae_encoder_weights)} keys, {e_params:,} params")
|
||||
|
||||
# 4. Audio VAE Decoder
|
||||
print(" [4/6] Audio VAE Decoder...")
|
||||
audio_decoder_weights = sanitize_audio_decoder(all_weights)
|
||||
save_single(audio_decoder_weights, output_path / "audio_vae")
|
||||
config = infer_audio_vae_config(audio_decoder_weights)
|
||||
save_config(config, output_path / "audio_vae")
|
||||
a_params = sum(v.size for v in audio_decoder_weights.values())
|
||||
print(f" {len(audio_decoder_weights)} keys, {a_params:,} params")
|
||||
|
||||
# 5. Vocoder
|
||||
print(" [5/6] Vocoder...")
|
||||
vocoder_weights = sanitize_vocoder(all_weights)
|
||||
save_single(vocoder_weights, output_path / "vocoder")
|
||||
config = infer_vocoder_config(vocoder_weights)
|
||||
save_config(config, output_path / "vocoder")
|
||||
v_params = sum(v.size for v in vocoder_weights.values())
|
||||
print(f" {len(vocoder_weights)} keys, {v_params:,} params")
|
||||
|
||||
# 6. Text Projections
|
||||
print(" [6/6] Text Projections...")
|
||||
text_proj_weights = extract_text_projections(all_weights)
|
||||
tp_dir = output_path / "text_projections"
|
||||
tp_dir.mkdir(parents=True, exist_ok=True)
|
||||
mx.save_safetensors(str(tp_dir / "model.safetensors"), text_proj_weights)
|
||||
tp_params = sum(v.size for v in text_proj_weights.values())
|
||||
print(f" {len(text_proj_weights)} keys, {tp_params:,} params")
|
||||
|
||||
# 7. Copy upscaler files
|
||||
print("\nCopying upscaler files...")
|
||||
source_dir = source_path.parent
|
||||
is_hf_repo = "/" in source and not Path(source).exists()
|
||||
upscaler_files = []
|
||||
|
||||
if is_hf_repo:
|
||||
from huggingface_hub import list_repo_files
|
||||
|
||||
upscaler_files = [
|
||||
f for f in list_repo_files(source) if UPSCALER_PATTERN.match(f)
|
||||
]
|
||||
else:
|
||||
upscaler_files = [
|
||||
f.name for f in source_dir.iterdir()
|
||||
if f.is_file() and UPSCALER_PATTERN.match(f.name)
|
||||
]
|
||||
|
||||
if not upscaler_files:
|
||||
print(" No upscaler files found")
|
||||
|
||||
for upscaler_file in sorted(upscaler_files):
|
||||
dest = output_path / upscaler_file
|
||||
if dest.exists():
|
||||
print(f" {upscaler_file}: already exists, skipping")
|
||||
continue
|
||||
|
||||
local_candidate = source_dir / upscaler_file
|
||||
if local_candidate.is_file():
|
||||
shutil.copy2(str(local_candidate), str(dest))
|
||||
print(f" {upscaler_file}: copied")
|
||||
elif is_hf_repo:
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
print(f" {upscaler_file}: downloading from {source}...")
|
||||
downloaded = hf_hub_download(repo_id=source, filename=upscaler_file)
|
||||
shutil.copy2(downloaded, str(dest))
|
||||
print(f" {upscaler_file}: done")
|
||||
else:
|
||||
print(f" {upscaler_file}: not found, skipping")
|
||||
|
||||
# 8. Link text_encoder and tokenizer directories
|
||||
print("\nLinking text encoder & tokenizer...")
|
||||
for subdir in ["text_encoder", "tokenizer"]:
|
||||
dest = output_path / subdir
|
||||
if dest.exists():
|
||||
print(f" {subdir}/: already exists, skipping")
|
||||
continue
|
||||
|
||||
local_candidate = source_dir / subdir
|
||||
if local_candidate.is_dir():
|
||||
# Resolve through symlinks to get the real directory
|
||||
real_path = local_candidate.resolve()
|
||||
dest.symlink_to(real_path)
|
||||
print(f" {subdir}/: symlinked to {real_path}")
|
||||
elif is_hf_repo:
|
||||
from huggingface_hub import list_repo_files, snapshot_download
|
||||
|
||||
# Only download if the subdir exists in the repo
|
||||
repo_files = list_repo_files(source)
|
||||
if any(f.startswith(f"{subdir}/") for f in repo_files):
|
||||
print(f" {subdir}/: downloading from {source}...")
|
||||
snapshot_download(
|
||||
repo_id=source,
|
||||
allow_patterns=f"{subdir}/*",
|
||||
local_dir=str(output_path),
|
||||
)
|
||||
print(f" {subdir}/: done")
|
||||
else:
|
||||
print(f" {subdir}/: not in repo, skipping")
|
||||
else:
|
||||
print(f" {subdir}/: not found in source, skipping")
|
||||
|
||||
# Summary
|
||||
all_converted = (
|
||||
len(transformer_weights)
|
||||
+ len(vae_decoder_weights)
|
||||
+ len(vae_encoder_weights)
|
||||
+ len(audio_decoder_weights)
|
||||
+ len(vocoder_weights)
|
||||
+ len(text_proj_weights)
|
||||
)
|
||||
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
|
||||
if all_converted < total_keys:
|
||||
# Find unconverted keys
|
||||
converted_prefixes = set()
|
||||
for key in all_weights:
|
||||
if key.startswith(TRANSFORMER_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(VAE_DECODER_PREFIX) or key.startswith(VAE_STATS_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(VAE_ENCODER_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(AUDIO_DECODER_PREFIX) or key.startswith(AUDIO_STATS_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(AUDIO_ENCODER_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(VOCODER_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(TEXT_PROJ_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
elif key.startswith(VIDEO_CONNECTOR_PREFIX) or key.startswith(AUDIO_CONNECTOR_PREFIX):
|
||||
converted_prefixes.add(key)
|
||||
|
||||
skipped = set(all_weights.keys()) - converted_prefixes
|
||||
if skipped:
|
||||
print(f" Skipped {len(skipped)} keys:")
|
||||
for k in sorted(skipped)[:20]:
|
||||
print(f" {k}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert monolithic LTX-2/2.3 safetensors to modular MLX layout"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--source",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HF repo ID (e.g. Lightricks/LTX-2, Lightricks/LTX-2.3), local directory, or direct safetensors file path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output directory for modular layout",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
choices=["distilled", "dev"],
|
||||
default="distilled",
|
||||
help="Model variant (affects VAE decoder config and which file to download)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert(args.source, Path(args.output), variant=args.variant)
|
||||
40
mlx_video/models/ltx_2/feed_forward.py
Normal file
40
mlx_video/models/ltx_2/feed_forward.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, approximate: str = "tanh"):
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.approximate == "tanh":
|
||||
return nn.gelu_approx(x)
|
||||
else:
|
||||
return nn.gelu(x)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: int | None = None,
|
||||
mult: int = 4,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dim_out = dim_out or dim
|
||||
inner_dim = int(dim * mult)
|
||||
|
||||
self.proj_in = nn.Linear(dim, inner_dim, bias=bias)
|
||||
self.act = GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(inner_dim, dim_out, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
|
||||
x = self.proj_in(x)
|
||||
x = self.act(x)
|
||||
x = self.proj_out(x)
|
||||
return x
|
||||
2566
mlx_video/models/ltx_2/generate.py
Normal file
2566
mlx_video/models/ltx_2/generate.py
Normal file
File diff suppressed because it is too large
Load Diff
651
mlx_video/models/ltx_2/ltx.py
Normal file
651
mlx_video/models/ltx_2/ltx.py
Normal file
@@ -0,0 +1,651 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from pathlib import Path
|
||||
from mlx_video.models.ltx_2.config import (
|
||||
LTXModelConfig,
|
||||
LTXModelType,
|
||||
LTXRopeType,
|
||||
TransformerConfig,
|
||||
)
|
||||
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
|
||||
from mlx_video.models.ltx_2.transformer import (
|
||||
BasicAVTransformerBlock,
|
||||
Modality,
|
||||
TransformerArgs,
|
||||
)
|
||||
from mlx_video.utils import to_denoised
|
||||
|
||||
|
||||
class TransformerArgsPreprocessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patchify_proj: nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: Optional[PixArtAlphaTextProjection],
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
num_attention_heads: int,
|
||||
use_middle_indices_grid: bool,
|
||||
timestep_scale_multiplier: int,
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
double_precision_rope: bool = False,
|
||||
prompt_adaln: Optional[AdaLayerNormSingle] = None,
|
||||
):
|
||||
self.patchify_proj = patchify_proj
|
||||
self.adaln = adaln
|
||||
self.caption_projection = caption_projection
|
||||
self.prompt_adaln = prompt_adaln
|
||||
self.inner_dim = inner_dim
|
||||
self.max_pos = max_pos
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_middle_indices_grid = use_middle_indices_grid
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
|
||||
def _prepare_timestep(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
|
||||
# Reshape to (batch, tokens, dim)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_timestep_with_adaln(
|
||||
self,
|
||||
adaln: AdaLayerNormSingle,
|
||||
timestep: mx.array,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
|
||||
return timestep_emb, embedded_timestep
|
||||
|
||||
def _prepare_context(
|
||||
self,
|
||||
context: mx.array,
|
||||
x: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
if self.caption_projection is not None:
|
||||
context = self.caption_projection(context)
|
||||
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||
return context, attention_mask
|
||||
|
||||
def _prepare_attention_mask(
|
||||
self,
|
||||
attention_mask: Optional[mx.array],
|
||||
x_dtype: mx.Dtype,
|
||||
) -> Optional[mx.array]:
|
||||
if attention_mask is None:
|
||||
return None
|
||||
|
||||
# Check if already float
|
||||
if attention_mask.dtype in [mx.float16, mx.float32, mx.bfloat16]:
|
||||
return attention_mask
|
||||
|
||||
# Convert boolean/int mask to float mask
|
||||
# 0 -> -inf (masked), 1 -> 0 (not masked)
|
||||
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
|
||||
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
|
||||
return mask
|
||||
|
||||
def _prepare_positional_embeddings(
|
||||
self,
|
||||
positions: mx.array,
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
use_middle_indices_grid: bool,
|
||||
num_attention_heads: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
pe = precompute_freqs_cis(
|
||||
positions,
|
||||
dim=inner_dim,
|
||||
theta=self.positional_embedding_theta,
|
||||
max_pos=max_pos,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
num_attention_heads=num_attention_heads,
|
||||
rope_type=self.rope_type,
|
||||
double_precision=self.double_precision_rope,
|
||||
)
|
||||
return pe
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
|
||||
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||
if modality.positional_embeddings is not None:
|
||||
pe = modality.positional_embeddings
|
||||
else:
|
||||
pe = self._prepare_positional_embeddings(
|
||||
positions=modality.positions,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=self.max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
|
||||
# Prompt-conditioned timestep (LTX-2.3) - uses raw sigma, not per-token timesteps
|
||||
prompt_timestep = None
|
||||
prompt_embedded_timestep = None
|
||||
if self.prompt_adaln is not None and modality.sigma is not None:
|
||||
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
|
||||
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
|
||||
)
|
||||
|
||||
return TransformerArgs(
|
||||
x=x,
|
||||
context=context,
|
||||
context_mask=attention_mask,
|
||||
timesteps=timestep,
|
||||
embedded_timestep=embedded_timestep,
|
||||
positional_embeddings=pe,
|
||||
cross_positional_embeddings=None,
|
||||
cross_scale_shift_timestep=None,
|
||||
cross_gate_timestep=None,
|
||||
enabled=modality.enabled,
|
||||
prompt_timesteps=prompt_timestep,
|
||||
prompt_embedded_timestep=prompt_embedded_timestep,
|
||||
)
|
||||
|
||||
|
||||
class MultiModalTransformerArgsPreprocessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patchify_proj: nn.Linear,
|
||||
adaln: AdaLayerNormSingle,
|
||||
caption_projection: Optional[PixArtAlphaTextProjection],
|
||||
cross_scale_shift_adaln: AdaLayerNormSingle,
|
||||
cross_gate_adaln: AdaLayerNormSingle,
|
||||
inner_dim: int,
|
||||
max_pos: List[int],
|
||||
num_attention_heads: int,
|
||||
cross_pe_max_pos: int,
|
||||
use_middle_indices_grid: bool,
|
||||
audio_cross_attention_dim: int,
|
||||
timestep_scale_multiplier: int,
|
||||
positional_embedding_theta: float,
|
||||
rope_type: LTXRopeType,
|
||||
av_ca_timestep_scale_multiplier: int,
|
||||
double_precision_rope: bool = False,
|
||||
prompt_adaln: Optional[AdaLayerNormSingle] = None,
|
||||
):
|
||||
self.simple_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=patchify_proj,
|
||||
adaln=adaln,
|
||||
caption_projection=caption_projection,
|
||||
inner_dim=inner_dim,
|
||||
max_pos=max_pos,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
positional_embedding_theta=positional_embedding_theta,
|
||||
rope_type=rope_type,
|
||||
double_precision_rope=double_precision_rope,
|
||||
prompt_adaln=prompt_adaln,
|
||||
)
|
||||
self.cross_scale_shift_adaln = cross_scale_shift_adaln
|
||||
self.cross_gate_adaln = cross_gate_adaln
|
||||
self.cross_pe_max_pos = cross_pe_max_pos
|
||||
self.audio_cross_attention_dim = audio_cross_attention_dim
|
||||
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
from dataclasses import replace
|
||||
|
||||
transformer_args = self.simple_preprocessor.prepare(modality)
|
||||
|
||||
# Prepare cross-modal positional embeddings
|
||||
cross_pe = self.simple_preprocessor._prepare_positional_embeddings(
|
||||
positions=modality.positions[:, 0:1, :],
|
||||
inner_dim=self.audio_cross_attention_dim,
|
||||
max_pos=[self.cross_pe_max_pos],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=self.simple_preprocessor.num_attention_heads,
|
||||
)
|
||||
|
||||
# Prepare cross-attention timestep embeddings
|
||||
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=transformer_args.x.dtype,
|
||||
)
|
||||
|
||||
return replace(
|
||||
transformer_args,
|
||||
cross_positional_embeddings=cross_pe,
|
||||
cross_scale_shift_timestep=cross_scale_shift_timestep,
|
||||
cross_gate_timestep=cross_gate_timestep,
|
||||
)
|
||||
|
||||
def _prepare_cross_attention_timestep(
|
||||
self,
|
||||
timestep: mx.array,
|
||||
timestep_scale_multiplier: int,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * timestep_scale_multiplier
|
||||
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||
|
||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||
|
||||
return scale_shift_timestep, gate_timestep
|
||||
|
||||
|
||||
class LTXModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LTXModelConfig):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.model_type = config.model_type
|
||||
self.use_middle_indices_grid = config.use_middle_indices_grid
|
||||
self.rope_type = config.rope_type
|
||||
self.timestep_scale_multiplier = config.timestep_scale_multiplier
|
||||
self.positional_embedding_theta = config.positional_embedding_theta
|
||||
|
||||
cross_pe_max_pos = None
|
||||
|
||||
if config.model_type.is_video_enabled():
|
||||
self.positional_embedding_max_pos = config.positional_embedding_max_pos
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.inner_dim = config.inner_dim
|
||||
self._init_video(config)
|
||||
|
||||
if config.model_type.is_audio_enabled():
|
||||
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
|
||||
self.audio_num_attention_heads = config.audio_num_attention_heads
|
||||
self.audio_inner_dim = config.audio_inner_dim
|
||||
self._init_audio(config)
|
||||
|
||||
# Initialize cross-modal components
|
||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||
cross_pe_max_pos = max(
|
||||
config.positional_embedding_max_pos[0],
|
||||
config.audio_positional_embedding_max_pos[0],
|
||||
)
|
||||
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
|
||||
self.audio_cross_attention_dim = config.audio_cross_attention_dim
|
||||
self._init_audio_video(config)
|
||||
|
||||
self._init_preprocessors(config, cross_pe_max_pos)
|
||||
|
||||
self._init_transformer_blocks(config)
|
||||
|
||||
def _init_video(self, config: LTXModelConfig) -> None:
|
||||
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
|
||||
|
||||
adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
)
|
||||
|
||||
self.scale_shift_table = mx.zeros((2, self.inner_dim))
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=config.norm_eps, affine=False)
|
||||
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
|
||||
|
||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||
|
||||
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
|
||||
|
||||
if config.has_prompt_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.audio_caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
)
|
||||
|
||||
# Output components
|
||||
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
|
||||
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
|
||||
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
|
||||
|
||||
def _init_audio_video(self, config: LTXModelConfig) -> None:
|
||||
num_scale_shift_values = 4
|
||||
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
)
|
||||
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=num_scale_shift_values,
|
||||
)
|
||||
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
embedding_coefficient=1,
|
||||
)
|
||||
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=1,
|
||||
)
|
||||
|
||||
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
|
||||
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
|
||||
# Multi-modal preprocessors
|
||||
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
cross_scale_shift_adaln=self.av_ca_video_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_a2v_gate_adaln_single,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=config.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
cross_pe_max_pos=cross_pe_max_pos,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
audio_cross_attention_dim=config.audio_cross_attention_dim,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
self.audio_args_preprocessor = MultiModalTransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
cross_scale_shift_adaln=self.av_ca_audio_scale_shift_adaln_single,
|
||||
cross_gate_adaln=self.av_ca_v2a_gate_adaln_single,
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=config.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
cross_pe_max_pos=cross_pe_max_pos,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
audio_cross_attention_dim=config.audio_cross_attention_dim,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
av_ca_timestep_scale_multiplier=config.av_ca_timestep_scale_multiplier,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
elif config.model_type.is_video_enabled():
|
||||
self.video_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.patchify_proj,
|
||||
adaln=self.adaln_single,
|
||||
caption_projection=getattr(self, "caption_projection", None),
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=config.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "prompt_adaln_single", None),
|
||||
)
|
||||
elif config.model_type.is_audio_enabled():
|
||||
self.audio_args_preprocessor = TransformerArgsPreprocessor(
|
||||
patchify_proj=self.audio_patchify_proj,
|
||||
adaln=self.audio_adaln_single,
|
||||
caption_projection=getattr(self, "audio_caption_projection", None),
|
||||
inner_dim=self.audio_inner_dim,
|
||||
max_pos=config.audio_positional_embedding_max_pos,
|
||||
num_attention_heads=self.audio_num_attention_heads,
|
||||
use_middle_indices_grid=config.use_middle_indices_grid,
|
||||
timestep_scale_multiplier=config.timestep_scale_multiplier,
|
||||
positional_embedding_theta=config.positional_embedding_theta,
|
||||
rope_type=config.rope_type,
|
||||
double_precision_rope=config.double_precision_rope,
|
||||
prompt_adaln=getattr(self, "audio_prompt_adaln_single", None),
|
||||
)
|
||||
|
||||
def _init_transformer_blocks(self, config: LTXModelConfig) -> None:
|
||||
video_config = config.get_video_config()
|
||||
audio_config = config.get_audio_config()
|
||||
|
||||
self.transformer_blocks = {
|
||||
idx: BasicAVTransformerBlock(
|
||||
idx=idx,
|
||||
video=video_config,
|
||||
audio=audio_config,
|
||||
rope_type=config.rope_type,
|
||||
norm_eps=config.norm_eps,
|
||||
has_prompt_adaln=config.has_prompt_adaln,
|
||||
)
|
||||
for idx in range(config.num_layers)
|
||||
}
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self,
|
||||
video: Optional[TransformerArgs],
|
||||
audio: Optional[TransformerArgs],
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Process through all transformer blocks.
|
||||
|
||||
Args:
|
||||
stg_video_blocks: Block indices where video self-attention is skipped (STG).
|
||||
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
|
||||
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
|
||||
"""
|
||||
stg_v_set = set(stg_video_blocks) if stg_video_blocks else set()
|
||||
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
|
||||
for idx, block in self.transformer_blocks.items():
|
||||
video, audio = block(
|
||||
video=video, audio=audio,
|
||||
skip_video_self_attn=(idx in stg_v_set),
|
||||
skip_audio_self_attn=(idx in stg_a_set),
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
return video, audio
|
||||
|
||||
def _process_output(
|
||||
self,
|
||||
scale_shift_table: mx.array,
|
||||
norm_out: nn.LayerNorm,
|
||||
proj_out: nn.Linear,
|
||||
x: mx.array,
|
||||
embedded_timestep: mx.array,
|
||||
) -> mx.array:
|
||||
|
||||
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
|
||||
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
|
||||
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
|
||||
timestep_expanded = embedded_timestep[:, :, None, :] # (B, 1, 1, dim)
|
||||
|
||||
# Combine: (1, 1, 2, dim) + (B, 1, 1, dim) broadcasts to (B, 1, 2, dim)
|
||||
scale_shift_values = table_expanded + timestep_expanded
|
||||
|
||||
# Extract shift and scale (first index is shift, second is scale)
|
||||
shift = scale_shift_values[:, :, 0, :] # (B, 1, dim)
|
||||
scale = scale_shift_values[:, :, 1, :] # (B, 1, dim)
|
||||
|
||||
x = norm_out(x)
|
||||
x = x * (1 + scale) + shift # Broadcasts (B, 1, dim) to (B, seq, dim)
|
||||
x = proj_out(x)
|
||||
|
||||
return x
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[Modality] = None,
|
||||
audio: Optional[Modality] = None,
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
video: Video modality input.
|
||||
audio: Audio modality input.
|
||||
stg_video_blocks: Block indices where video self-attention is skipped (STG).
|
||||
stg_audio_blocks: Block indices where audio self-attention is skipped (STG).
|
||||
skip_cross_modal: Skip all A2V/V2A cross-attention (modality isolation).
|
||||
"""
|
||||
# Validate inputs
|
||||
if not self.model_type.is_video_enabled() and video is not None:
|
||||
raise ValueError("Video is not enabled for this model")
|
||||
if not self.model_type.is_audio_enabled() and audio is not None:
|
||||
raise ValueError("Audio is not enabled for this model")
|
||||
|
||||
# Preprocess arguments
|
||||
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
|
||||
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
|
||||
|
||||
# Process transformer blocks
|
||||
video_out, audio_out = self._process_transformer_blocks(
|
||||
video=video_args,
|
||||
audio=audio_args,
|
||||
stg_video_blocks=stg_video_blocks,
|
||||
stg_audio_blocks=stg_audio_blocks,
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
|
||||
# Process outputs
|
||||
vx = (
|
||||
self._process_output(
|
||||
self.scale_shift_table,
|
||||
self.norm_out,
|
||||
self.proj_out,
|
||||
video_out.x,
|
||||
video_out.embedded_timestep,
|
||||
)
|
||||
if video_out is not None
|
||||
else None
|
||||
)
|
||||
|
||||
ax = (
|
||||
self._process_output(
|
||||
self.audio_scale_shift_table,
|
||||
self.audio_norm_out,
|
||||
self.audio_proj_out,
|
||||
audio_out.x,
|
||||
audio_out.embedded_timestep,
|
||||
)
|
||||
if audio_out is not None
|
||||
else None
|
||||
)
|
||||
|
||||
return vx, ax
|
||||
|
||||
def sanitize(self, weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
|
||||
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
|
||||
if not has_raw_prefix:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
if not key.startswith("model.diffusion_model."):
|
||||
continue
|
||||
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
new_key = new_key.replace("model.diffusion_model.", "")
|
||||
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTXModel":
|
||||
import json
|
||||
|
||||
config_dict = {}
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config_dict = json.load(f)
|
||||
config = LTXModelConfig(**config_dict)
|
||||
model = cls(config)
|
||||
|
||||
weights = {}
|
||||
|
||||
for weight_file in model_path.glob("*.safetensors"):
|
||||
weights.update(mx.load(str(weight_file)))
|
||||
|
||||
|
||||
sanitized = model.sanitize(weights)
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
|
||||
model.load_weights(list(sanitized.items()), strict=strict)
|
||||
mx.eval(model.parameters())
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class X0Model(nn.Module):
|
||||
|
||||
def __init__(self, velocity_model: LTXModel):
|
||||
|
||||
super().__init__()
|
||||
self.velocity_model = velocity_model
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[Modality] = None,
|
||||
audio: Optional[Modality] = None,
|
||||
stg_video_blocks: Optional[List[int]] = None,
|
||||
stg_audio_blocks: Optional[List[int]] = None,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
|
||||
|
||||
vx, ax = self.velocity_model(
|
||||
video, audio,
|
||||
stg_video_blocks=stg_video_blocks,
|
||||
stg_audio_blocks=stg_audio_blocks,
|
||||
skip_cross_modal=skip_cross_modal,
|
||||
)
|
||||
|
||||
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
|
||||
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
|
||||
|
||||
return denoised_video, denoised_audio
|
||||
165
mlx_video/models/ltx_2/postprocess.py
Normal file
165
mlx_video/models/ltx_2/postprocess.py
Normal file
@@ -0,0 +1,165 @@
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
||||
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
||||
|
||||
Args:
|
||||
image: Input image as uint8 numpy array (H, W, C)
|
||||
d: Diameter of each pixel neighborhood
|
||||
sigma_color: Filter sigma in the color space
|
||||
sigma_space: Filter sigma in the coordinate space
|
||||
|
||||
Returns:
|
||||
Filtered image
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
|
||||
except ImportError:
|
||||
# Fallback to simple Gaussian blur if cv2 not available
|
||||
return gaussian_blur(image, kernel_size=3)
|
||||
|
||||
|
||||
def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
|
||||
"""Apply Gaussian blur.
|
||||
|
||||
Args:
|
||||
image: Input image as uint8 numpy array (H, W, C)
|
||||
kernel_size: Size of the Gaussian kernel (must be odd)
|
||||
|
||||
Returns:
|
||||
Blurred image
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
|
||||
except ImportError:
|
||||
# Simple box blur fallback
|
||||
from scipy.ndimage import uniform_filter
|
||||
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
|
||||
|
||||
|
||||
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
|
||||
"""Apply unsharp masking to enhance edges after blur.
|
||||
|
||||
Args:
|
||||
image: Input image as uint8 numpy array
|
||||
kernel_size: Size of the Gaussian kernel
|
||||
sigma: Gaussian sigma
|
||||
amount: Strength of sharpening
|
||||
|
||||
Returns:
|
||||
Sharpened image
|
||||
"""
|
||||
try:
|
||||
import cv2
|
||||
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
|
||||
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
|
||||
return np.clip(sharpened, 0, 255).astype(np.uint8)
|
||||
except ImportError:
|
||||
return image
|
||||
|
||||
|
||||
def reduce_grid_artifacts(
|
||||
video: np.ndarray,
|
||||
method: str = "bilateral",
|
||||
strength: float = 1.0,
|
||||
) -> np.ndarray:
|
||||
"""Reduce grid artifacts in video frames.
|
||||
|
||||
Args:
|
||||
video: Video as numpy array (F, H, W, C) uint8
|
||||
method: "bilateral", "gaussian", or "frequency"
|
||||
strength: How strong to apply the filter (0-1)
|
||||
|
||||
Returns:
|
||||
Processed video
|
||||
"""
|
||||
if method == "bilateral":
|
||||
d = max(3, int(5 * strength))
|
||||
sigma = 50 + 50 * strength
|
||||
processed = np.stack([
|
||||
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
|
||||
for frame in video
|
||||
])
|
||||
elif method == "gaussian":
|
||||
kernel_size = max(3, int(3 + 4 * strength))
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size += 1
|
||||
processed = np.stack([
|
||||
gaussian_blur(frame, kernel_size=kernel_size)
|
||||
for frame in video
|
||||
])
|
||||
elif method == "frequency":
|
||||
processed = np.stack([
|
||||
remove_grid_frequency(frame, grid_size=8)
|
||||
for frame in video
|
||||
])
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {method}")
|
||||
|
||||
# Optionally sharpen to recover some detail
|
||||
if strength < 1.0:
|
||||
# Blend with original based on strength
|
||||
alpha = strength
|
||||
processed = (alpha * processed + (1 - alpha) * video).astype(np.uint8)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
|
||||
"""Remove grid-frequency components using FFT.
|
||||
|
||||
Args:
|
||||
frame: Input frame (H, W, C) uint8
|
||||
grid_size: Expected grid periodicity in pixels
|
||||
|
||||
Returns:
|
||||
Filtered frame
|
||||
"""
|
||||
result = np.zeros_like(frame)
|
||||
|
||||
for c in range(frame.shape[2]):
|
||||
channel = frame[:, :, c].astype(np.float32)
|
||||
h, w = channel.shape
|
||||
|
||||
# FFT
|
||||
fft = np.fft.fft2(channel)
|
||||
fft_shifted = np.fft.fftshift(fft)
|
||||
|
||||
# Create notch filter at grid frequencies
|
||||
cy, cx = h // 2, w // 2
|
||||
mask = np.ones((h, w), dtype=np.float32)
|
||||
|
||||
# Attenuate frequencies at grid periodicity
|
||||
freq_y = h // grid_size
|
||||
freq_x = w // grid_size
|
||||
|
||||
for fy in range(-2, 3):
|
||||
for fx in range(-2, 3):
|
||||
if fy == 0 and fx == 0:
|
||||
continue
|
||||
y_pos = cy + fy * freq_y
|
||||
x_pos = cx + fx * freq_x
|
||||
if 0 <= y_pos < h and 0 <= x_pos < w:
|
||||
# Gaussian attenuation around the frequency
|
||||
for dy in range(-2, 3):
|
||||
for dx in range(-2, 3):
|
||||
yy, xx = y_pos + dy, x_pos + dx
|
||||
if 0 <= yy < h and 0 <= xx < w:
|
||||
dist = np.sqrt(dy**2 + dx**2)
|
||||
mask[yy, xx] *= min(1.0, dist / 3.0)
|
||||
|
||||
# Apply mask and inverse FFT
|
||||
fft_filtered = fft_shifted * mask
|
||||
channel_filtered = np.fft.ifft2(np.fft.ifftshift(fft_filtered)).real
|
||||
|
||||
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
30
mlx_video/models/ltx_2/prompts/gemma_i2v_system_prompt.txt
Normal file
30
mlx_video/models/ltx_2/prompts/gemma_i2v_system_prompt.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
|
||||
|
||||
#### Guidelines:
|
||||
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
|
||||
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
|
||||
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
|
||||
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||
- Chronological flow: Use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
|
||||
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
|
||||
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
|
||||
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
|
||||
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
|
||||
|
||||
#### Important notes:
|
||||
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
|
||||
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
|
||||
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
|
||||
- Format: Never start output with punctuation marks or special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
#### Example output:
|
||||
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
|
||||
40
mlx_video/models/ltx_2/prompts/gemma_t2v_system_prompt.txt
Normal file
40
mlx_video/models/ltx_2/prompts/gemma_t2v_system_prompt.txt
Normal file
@@ -0,0 +1,40 @@
|
||||
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
|
||||
|
||||
#### Guidelines
|
||||
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
|
||||
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
|
||||
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
|
||||
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
|
||||
- Speech (only when requested):
|
||||
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
|
||||
- Specify language if not English and accent if relevant.
|
||||
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
|
||||
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
|
||||
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
|
||||
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
|
||||
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
|
||||
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
|
||||
|
||||
#### Important notes:
|
||||
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
|
||||
- Camera motion: DO NOT invent camera motion unless requested by the user.
|
||||
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
|
||||
- Format: DO NOT start your response with special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single continuous paragraph in natural language (English).
|
||||
- NO titles, headings, prefaces, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
|
||||
|
||||
#### Example
|
||||
Input: "A woman at a coffee shop talking on the phone"
|
||||
Output:
|
||||
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.
|
||||
540
mlx_video/models/ltx_2/rope.py
Normal file
540
mlx_video/models/ltx_2/rope.py
Normal file
@@ -0,0 +1,540 @@
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
input_tensor: mx.array,
|
||||
freqs_cis: Tuple[mx.array, mx.array],
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
) -> mx.array:
|
||||
"""Apply rotary position embeddings to input tensor.
|
||||
|
||||
Args:
|
||||
input_tensor: Input tensor to apply RoPE to
|
||||
freqs_cis: Tuple of (cos_freqs, sin_freqs)
|
||||
rope_type: Type of RoPE to apply (INTERLEAVED or SPLIT)
|
||||
|
||||
Returns:
|
||||
Tensor with rotary embeddings applied
|
||||
"""
|
||||
if rope_type == LTXRopeType.INTERLEAVED:
|
||||
return apply_interleaved_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
|
||||
elif rope_type == LTXRopeType.SPLIT:
|
||||
return apply_split_rotary_emb(input_tensor, freqs_cis[0], freqs_cis[1])
|
||||
else:
|
||||
raise ValueError(f"Invalid rope type: {rope_type}")
|
||||
|
||||
|
||||
def apply_interleaved_rotary_emb(
|
||||
input_tensor: mx.array,
|
||||
cos_freqs: mx.array,
|
||||
sin_freqs: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply interleaved rotary embeddings.
|
||||
|
||||
Pairs adjacent dimensions and applies rotation.
|
||||
Pattern: [x0, x1, x2, x3, ...] -> rotate pairs (x0,x1), (x2,x3), ...
|
||||
|
||||
Args:
|
||||
input_tensor: Input tensor of shape (..., dim)
|
||||
cos_freqs: Cosine frequencies
|
||||
sin_freqs: Sine frequencies
|
||||
|
||||
Returns:
|
||||
Tensor with interleaved rotary embeddings applied
|
||||
"""
|
||||
# Compute in float32 for better precision
|
||||
input_dtype = input_tensor.dtype
|
||||
input_tensor = input_tensor.astype(mx.float32)
|
||||
cos_freqs = cos_freqs.astype(mx.float32)
|
||||
sin_freqs = sin_freqs.astype(mx.float32)
|
||||
|
||||
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
|
||||
shape = input_tensor.shape
|
||||
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
|
||||
|
||||
# Extract pairs
|
||||
t1 = input_tensor[..., 0] # Even indices
|
||||
t2 = input_tensor[..., 1] # Odd indices
|
||||
|
||||
# Apply rotation: (-t2, t1) pattern
|
||||
t_rot = mx.stack([-t2, t1], axis=-1)
|
||||
|
||||
# Flatten back: (..., dim/2, 2) -> (..., dim)
|
||||
input_tensor = mx.reshape(input_tensor, shape)
|
||||
t_rot = mx.reshape(t_rot, shape)
|
||||
|
||||
# Apply rotary embeddings
|
||||
out = input_tensor * cos_freqs + t_rot * sin_freqs
|
||||
|
||||
return out.astype(input_dtype)
|
||||
|
||||
|
||||
def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||
"""Rotate for interleaved RoPE: [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].
|
||||
|
||||
PyTorch equivalent:
|
||||
t_dup = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
return rearrange(t_dup, "... d r -> ... (d r)")
|
||||
"""
|
||||
# x: (..., dim) where dim is even
|
||||
x_even = x[..., 0::2] # [x0, x2, x4, ...]
|
||||
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
|
||||
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
|
||||
rotated = mx.stack([-x_odd, x_even], axis=-1)
|
||||
return mx.reshape(rotated, x.shape)
|
||||
|
||||
def apply_rotary_emb_1d(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
freqs_cis: mx.array,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Apply 1D rotary embeddings using precomputed frequencies (interleaved)."""
|
||||
# freqs_cis: (1, seq_len, num_heads, head_dim, 2) where [..., 0] = cos, [..., 1] = sin
|
||||
cos = freqs_cis[..., 0] # (1, seq_len, num_heads, head_dim)
|
||||
sin = freqs_cis[..., 1]
|
||||
|
||||
# q, k: (batch, seq_len, num_heads, head_dim)
|
||||
# Interleaved RoPE: pairs of adjacent dims rotate together
|
||||
q_r = q * cos + rotate_half_interleaved(q) * sin
|
||||
k_r = k * cos + rotate_half_interleaved(k) * sin
|
||||
|
||||
return q_r, k_r
|
||||
|
||||
|
||||
def apply_split_rotary_emb(
|
||||
input_tensor: mx.array,
|
||||
cos_freqs: mx.array,
|
||||
sin_freqs: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply split rotary embeddings.
|
||||
|
||||
Splits dimensions into two halves and applies rotation.
|
||||
Pattern: split into first half and second half
|
||||
|
||||
Args:
|
||||
input_tensor: Input tensor
|
||||
cos_freqs: Cosine frequencies of shape (B, H, T, D//2)
|
||||
sin_freqs: Sine frequencies of shape (B, H, T, D//2)
|
||||
|
||||
Returns:
|
||||
Tensor with split rotary embeddings applied
|
||||
"""
|
||||
input_dtype = input_tensor.dtype
|
||||
needs_reshape = False
|
||||
original_shape = input_tensor.shape
|
||||
|
||||
# Handle dimension mismatch
|
||||
if input_tensor.ndim != 4 and cos_freqs.ndim == 4:
|
||||
b, h, t, _ = cos_freqs.shape
|
||||
# Reshape from (B, T, H*D) to (B, H, T, D)
|
||||
input_tensor = mx.reshape(input_tensor, (b, t, h, -1))
|
||||
input_tensor = mx.swapaxes(input_tensor, 1, 2)
|
||||
needs_reshape = True
|
||||
|
||||
# Cast to float32 for computation precision
|
||||
input_tensor = input_tensor.astype(mx.float32)
|
||||
cos_freqs = cos_freqs.astype(mx.float32)
|
||||
sin_freqs = sin_freqs.astype(mx.float32)
|
||||
|
||||
# Split into two halves: (..., dim) -> (..., 2, dim//2)
|
||||
dim = input_tensor.shape[-1]
|
||||
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
|
||||
|
||||
# Get first and second halves
|
||||
first_half = split_input[..., 0, :] # (..., dim//2)
|
||||
second_half = split_input[..., 1, :] # (..., dim//2)
|
||||
|
||||
# Apply cosine to both halves
|
||||
output_first = first_half * cos_freqs
|
||||
output_second = second_half * cos_freqs
|
||||
|
||||
# Apply sine cross-terms (addcmul pattern)
|
||||
output_first = output_first - sin_freqs * second_half
|
||||
output_second = output_second + sin_freqs * first_half
|
||||
|
||||
# Stack back together
|
||||
output = mx.stack([output_first, output_second], axis=-2)
|
||||
|
||||
# Flatten: (..., 2, dim//2) -> (..., dim)
|
||||
output = mx.reshape(output, input_tensor.shape)
|
||||
|
||||
if needs_reshape:
|
||||
# Reshape back: (B, H, T, D) -> (B, T, H*D)
|
||||
b, h, t, d = output.shape
|
||||
output = mx.swapaxes(output, 1, 2)
|
||||
output = mx.reshape(output, (b, t, h * d))
|
||||
|
||||
return output.astype(input_dtype)
|
||||
|
||||
|
||||
def generate_freq_grid(
|
||||
positional_embedding_theta: float,
|
||||
positional_embedding_max_pos_count: int,
|
||||
inner_dim: int,
|
||||
) -> mx.array:
|
||||
"""Generate frequency grid for RoPE.
|
||||
|
||||
Args:
|
||||
positional_embedding_theta: Base theta value
|
||||
positional_embedding_max_pos_count: Number of position dimensions
|
||||
inner_dim: Inner dimension of the model
|
||||
|
||||
Returns:
|
||||
Frequency indices tensor
|
||||
"""
|
||||
theta = positional_embedding_theta
|
||||
start = 1.0
|
||||
end = theta
|
||||
|
||||
n_elem = 2 * positional_embedding_max_pos_count
|
||||
|
||||
# Compute logarithmic spacing
|
||||
log_start = math.log(start) / math.log(theta)
|
||||
log_end = math.log(end) / math.log(theta)
|
||||
|
||||
num_indices = inner_dim // n_elem
|
||||
if num_indices == 0:
|
||||
num_indices = 1
|
||||
|
||||
# Create linearly spaced values in log space
|
||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||
|
||||
# Compute power indices
|
||||
pow_indices = mx.power(theta, lin_space)
|
||||
|
||||
# Scale by pi/2
|
||||
return pow_indices * (math.pi / 2)
|
||||
|
||||
|
||||
def get_fractional_positions(
|
||||
indices_grid: mx.array,
|
||||
max_pos: List[int],
|
||||
) -> mx.array:
|
||||
"""Convert indices to fractional positions.
|
||||
|
||||
Args:
|
||||
indices_grid: Grid of position indices of shape (B, n_pos_dims, ...)
|
||||
max_pos: Maximum position for each dimension
|
||||
|
||||
Returns:
|
||||
Fractional positions in range [-1, 1] after scaling
|
||||
"""
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), (
|
||||
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
|
||||
)
|
||||
|
||||
# Divide each dimension by its max position
|
||||
fractional_positions = []
|
||||
for i in range(n_pos_dims):
|
||||
frac = indices_grid[:, i] / max_pos[i]
|
||||
fractional_positions.append(frac)
|
||||
|
||||
return mx.stack(fractional_positions, axis=-1)
|
||||
|
||||
|
||||
def generate_freqs(
|
||||
indices: mx.array,
|
||||
indices_grid: mx.array,
|
||||
max_pos: List[int],
|
||||
use_middle_indices_grid: bool,
|
||||
) -> mx.array:
|
||||
"""Generate frequencies from indices and position grid.
|
||||
|
||||
Args:
|
||||
indices: Frequency indices
|
||||
indices_grid: Position indices grid
|
||||
max_pos: Maximum positions per dimension
|
||||
use_middle_indices_grid: Whether to use middle of index ranges
|
||||
|
||||
Returns:
|
||||
Frequency tensor
|
||||
"""
|
||||
# Handle middle indices grid
|
||||
if use_middle_indices_grid:
|
||||
# indices_grid shape: (B, n_dims, T, 2) where last dim is [start, end]
|
||||
assert len(indices_grid.shape) == 4
|
||||
assert indices_grid.shape[-1] == 2
|
||||
indices_grid_start = indices_grid[..., 0]
|
||||
indices_grid_end = indices_grid[..., 1]
|
||||
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid.shape) == 4:
|
||||
indices_grid = indices_grid[..., 0]
|
||||
|
||||
# Get fractional positions
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
|
||||
# Compute frequencies
|
||||
# fractional_positions: (B, T, n_dims)
|
||||
# indices: (inner_dim // n_elem,)
|
||||
# Result: (B, T, inner_dim // n_elem * n_dims)
|
||||
|
||||
# Scale fractional positions to [-1, 1]
|
||||
scaled_positions = fractional_positions * 2 - 1 # (B, T, n_dims)
|
||||
|
||||
# Outer product with indices
|
||||
# (B, T, n_dims, 1) * (1, 1, 1, n_indices) -> (B, T, n_dims, n_indices)
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.expand_dims(
|
||||
mx.expand_dims(mx.expand_dims(indices, axis=0), axis=0), axis=0
|
||||
)
|
||||
|
||||
# Transpose and flatten: (B, T, n_dims, n_indices) -> (B, T, n_indices * n_dims)
|
||||
freqs = mx.swapaxes(freqs, -1, -2) # (B, T, n_indices, n_dims)
|
||||
freqs = mx.reshape(freqs, freqs.shape[:-2] + (-1,))
|
||||
|
||||
return freqs
|
||||
|
||||
|
||||
def split_freqs_cis(
|
||||
freqs: mx.array,
|
||||
pad_size: int,
|
||||
num_attention_heads: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Prepare cos/sin frequencies for split RoPE.
|
||||
|
||||
Args:
|
||||
freqs: Frequency tensor
|
||||
pad_size: Padding size for dimension alignment
|
||||
num_attention_heads: Number of attention heads
|
||||
|
||||
Returns:
|
||||
Tuple of (cos_freq, sin_freq) with shape (B, H, T, D//2)
|
||||
"""
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
# Add padding if needed
|
||||
if pad_size != 0:
|
||||
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
|
||||
|
||||
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape for multi-head attention
|
||||
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
||||
|
||||
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
|
||||
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
|
||||
|
||||
# Swap axes: (B, T, H, D//2) -> (B, H, T, D//2)
|
||||
cos_freq = mx.swapaxes(cos_freq, 1, 2)
|
||||
sin_freq = mx.swapaxes(sin_freq, 1, 2)
|
||||
|
||||
return cos_freq, sin_freq
|
||||
|
||||
|
||||
def interleaved_freqs_cis(
|
||||
freqs: mx.array,
|
||||
pad_size: int,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Prepare cos/sin frequencies for interleaved RoPE.
|
||||
|
||||
Args:
|
||||
freqs: Frequency tensor of shape (B, T, dim//2)
|
||||
pad_size: Padding size for dimension alignment
|
||||
|
||||
Returns:
|
||||
Tuple of (cos_freq, sin_freq) with shape (B, T, dim)
|
||||
"""
|
||||
# Compute cos and sin
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
# Repeat interleave: each element repeated twice
|
||||
# (B, T, D) -> (B, T, 2*D) with pattern [c0, c0, c1, c1, ...]
|
||||
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
|
||||
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
|
||||
|
||||
# Add padding if needed
|
||||
if pad_size != 0:
|
||||
cos_padding = mx.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = mx.zeros_like(sin_freq[:, :, :pad_size])
|
||||
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
return cos_freq, sin_freq
|
||||
|
||||
|
||||
def precompute_freqs_cis(
|
||||
indices_grid: mx.array,
|
||||
dim: int,
|
||||
theta: float = 10000.0,
|
||||
max_pos: Optional[List[int]] = None,
|
||||
use_middle_indices_grid: bool = False,
|
||||
num_attention_heads: int = 32,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
double_precision: bool = False,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Precompute RoPE frequencies.
|
||||
|
||||
Args:
|
||||
indices_grid: Position indices grid
|
||||
dim: Dimension for RoPE
|
||||
theta: Base theta value for frequency computation
|
||||
max_pos: Maximum position per dimension
|
||||
use_middle_indices_grid: Whether to use middle indices
|
||||
num_attention_heads: Number of attention heads
|
||||
rope_type: Type of RoPE (INTERLEAVED or SPLIT)
|
||||
double_precision: If True, compute frequencies in float64 for higher precision
|
||||
|
||||
Returns:
|
||||
Tuple of (cos_freq, sin_freq) tensors
|
||||
"""
|
||||
if max_pos is None:
|
||||
max_pos = [20, 2048, 2048]
|
||||
|
||||
|
||||
if double_precision:
|
||||
return _precompute_freqs_cis_double_precision(
|
||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||
num_attention_heads, rope_type
|
||||
)
|
||||
|
||||
# Keep positions in float32 for RoPE computation.
|
||||
# Even though PyTorch nominally casts positions to model dtype (bfloat16),
|
||||
# empirical comparison shows float32 positions produce RoPE values matching
|
||||
# PyTorch exactly (cosine=1.000). BFloat16 loses precision in fractional
|
||||
# position computation that gets amplified by high-frequency indices
|
||||
# (up to 15708), causing cos/sin sign flips and cosine sim of only 0.88.
|
||||
indices_grid = indices_grid.astype(mx.float32)
|
||||
|
||||
# Generate frequency indices
|
||||
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
||||
|
||||
# Generate frequencies
|
||||
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
|
||||
|
||||
# Prepare cos/sin based on rope type
|
||||
if rope_type == LTXRopeType.SPLIT:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
|
||||
else:
|
||||
# Interleaved
|
||||
n_elem = 2 * indices_grid.shape[1]
|
||||
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
|
||||
|
||||
return cos_freq, sin_freq
|
||||
|
||||
|
||||
def _precompute_freqs_cis_double_precision(
|
||||
indices_grid: mx.array,
|
||||
dim: int,
|
||||
theta: float,
|
||||
max_pos: List[int],
|
||||
use_middle_indices_grid: bool,
|
||||
num_attention_heads: int,
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
|
||||
|
||||
Matches PyTorch's generate_freq_grid_np: uses NumPy float64 for the critical
|
||||
frequency grid computation (log-spaced values), then converts to float32.
|
||||
Position grid stays in bfloat16 to match PyTorch behavior (positions are in
|
||||
model dtype throughout generate_freqs).
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Keep positions in float32 — same reasoning as the non-double-precision path.
|
||||
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||
|
||||
n_pos_dims = indices_grid_f32.shape[1]
|
||||
n_elem = 2 * n_pos_dims
|
||||
|
||||
# Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
|
||||
# This is the critical precision step - PyTorch uses np.float64 here
|
||||
log_start = np.log(1.0) / np.log(theta)
|
||||
log_end = np.log(theta) / np.log(theta) # = 1.0
|
||||
num_indices = dim // n_elem
|
||||
if num_indices == 0:
|
||||
num_indices = 1
|
||||
|
||||
# Use numpy float64 for the linspace computation (matches PyTorch)
|
||||
pow_indices = np.power(
|
||||
theta,
|
||||
np.linspace(log_start, log_end, num_indices, dtype=np.float64),
|
||||
)
|
||||
# Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
|
||||
freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
|
||||
|
||||
# Handle middle indices grid
|
||||
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
||||
if use_middle_indices_grid:
|
||||
assert len(indices_grid_f32.shape) == 4
|
||||
assert indices_grid_f32.shape[-1] == 2
|
||||
indices_grid_start = indices_grid_f32[..., 0]
|
||||
indices_grid_end = indices_grid_f32[..., 1]
|
||||
indices_grid_f32 = (indices_grid_start + indices_grid_end) / 2.0
|
||||
elif len(indices_grid_f32.shape) == 4:
|
||||
indices_grid_f32 = indices_grid_f32[..., 0]
|
||||
# After handling: indices_grid_f32 shape is (B, n_dims, T)
|
||||
|
||||
# Get fractional positions: (B, n_dims, T) -> (B, T, n_dims)
|
||||
# Compute fractional positions for each dimension
|
||||
fractional_list = []
|
||||
for i in range(n_pos_dims):
|
||||
frac = indices_grid_f32[:, i, :] / max_pos[i] # (B, T)
|
||||
fractional_list.append(frac)
|
||||
|
||||
# Stack: (B, T, n_dims)
|
||||
fractional_positions = mx.stack(fractional_list, axis=-1)
|
||||
|
||||
# Scale to [-1, 1]
|
||||
scaled_positions = fractional_positions * 2 - 1
|
||||
|
||||
# Compute frequencies: outer product
|
||||
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
|
||||
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
|
||||
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
|
||||
# freqs: (B, T, n_dims, num_indices)
|
||||
|
||||
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)
|
||||
freqs = mx.swapaxes(freqs, -1, -2)
|
||||
freqs = mx.reshape(freqs, (freqs.shape[0], freqs.shape[1], -1))
|
||||
|
||||
# Compute cos/sin
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
# Prepare based on rope type
|
||||
if rope_type == LTXRopeType.SPLIT:
|
||||
expected_freqs = dim // 2
|
||||
current_freqs = cos_freq.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
|
||||
# Add padding
|
||||
if pad_size > 0:
|
||||
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape for multi-head attention: (B, T, dim//2) -> (B, H, T, dim//2//H)
|
||||
b, t = cos_freq.shape[0], cos_freq.shape[1]
|
||||
cos_freq = mx.reshape(cos_freq, (b, t, num_attention_heads, -1))
|
||||
sin_freq = mx.reshape(sin_freq, (b, t, num_attention_heads, -1))
|
||||
cos_freq = mx.swapaxes(cos_freq, 1, 2)
|
||||
sin_freq = mx.swapaxes(sin_freq, 1, 2)
|
||||
else:
|
||||
# Interleaved
|
||||
cos_freq = mx.repeat(cos_freq, 2, axis=-1)
|
||||
sin_freq = mx.repeat(sin_freq, 2, axis=-1)
|
||||
|
||||
pad_size = dim % n_elem
|
||||
if pad_size > 0:
|
||||
cos_padding = mx.ones((*cos_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||
sin_padding = mx.zeros((*sin_freq.shape[:-1], pad_size), dtype=mx.float32)
|
||||
cos_freq = mx.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = mx.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
return cos_freq, sin_freq
|
||||
181
mlx_video/models/ltx_2/samplers.py
Normal file
181
mlx_video/models/ltx_2/samplers.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Second-order res_2s sampler for diffusion models.
|
||||
|
||||
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
|
||||
noise injection, ported from the LTX-2 PyTorch implementation.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def phi(j: int, neg_h: float) -> float:
|
||||
"""Compute phi_j(z) where z = -h (negative step size in log-space).
|
||||
|
||||
phi_1(z) = (e^z - 1) / z
|
||||
phi_2(z) = (e^z - 1 - z) / z^2
|
||||
phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j
|
||||
"""
|
||||
if abs(neg_h) < 1e-10:
|
||||
return 1.0 / math.factorial(j)
|
||||
|
||||
remainder = sum(neg_h**k / math.factorial(k) for k in range(j))
|
||||
return (math.exp(neg_h) - remainder) / (neg_h**j)
|
||||
|
||||
|
||||
def get_res2s_coefficients(
|
||||
h: float,
|
||||
phi_cache: dict,
|
||||
c2: float = 0.5,
|
||||
) -> tuple[float, float, float]:
|
||||
"""Compute res_2s Runge-Kutta coefficients for a given step size.
|
||||
|
||||
Args:
|
||||
h: Step size in log-space = log(sigma / sigma_next)
|
||||
phi_cache: Dictionary to cache phi function results.
|
||||
c2: Substep position (default 0.5 = midpoint)
|
||||
|
||||
Returns:
|
||||
(a21, b1, b2): RK coefficients.
|
||||
"""
|
||||
def get_phi(j: int, neg_h: float) -> float:
|
||||
cache_key = (j, neg_h)
|
||||
if cache_key in phi_cache:
|
||||
return phi_cache[cache_key]
|
||||
result = phi(j, neg_h)
|
||||
phi_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
neg_h_c2 = -h * c2
|
||||
phi_1_c2 = get_phi(1, neg_h_c2)
|
||||
a21 = c2 * phi_1_c2
|
||||
|
||||
neg_h_full = -h
|
||||
phi_2_full = get_phi(2, neg_h_full)
|
||||
b2 = phi_2_full / c2
|
||||
|
||||
phi_1_full = get_phi(1, neg_h_full)
|
||||
b1 = phi_1_full - b2
|
||||
|
||||
return a21, b1, b2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDE noise injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_sde_coeff(
|
||||
sigma_next: float,
|
||||
) -> tuple[float, float, float]:
|
||||
"""Compute SDE coefficients for variance-preserving noise injection.
|
||||
|
||||
Uses sigma_up = sigma_next * 0.5 (hardcoded in PyTorch Res2sDiffusionStep).
|
||||
|
||||
Returns:
|
||||
(alpha_ratio, sigma_down, sigma_up)
|
||||
"""
|
||||
sigma_up = sigma_next * 0.5
|
||||
# Clamp sigma_up to avoid sqrt(negative)
|
||||
sigma_up = min(sigma_up, sigma_next * 0.9999)
|
||||
|
||||
sigma_signal = 1.0 - sigma_next # sigma_max=1
|
||||
sigma_residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0))
|
||||
alpha_ratio = sigma_signal + sigma_residual
|
||||
|
||||
if alpha_ratio == 0:
|
||||
sigma_down = sigma_next
|
||||
else:
|
||||
sigma_down = sigma_residual / alpha_ratio
|
||||
|
||||
# Handle NaN edge cases
|
||||
if math.isnan(sigma_up):
|
||||
sigma_up = 0.0
|
||||
if math.isnan(sigma_down):
|
||||
sigma_down = sigma_next
|
||||
if math.isnan(alpha_ratio):
|
||||
alpha_ratio = 1.0
|
||||
|
||||
return alpha_ratio, sigma_down, sigma_up
|
||||
|
||||
|
||||
def sde_noise_step(
|
||||
sample: mx.array,
|
||||
denoised_sample: mx.array,
|
||||
sigma: float,
|
||||
sigma_next: float,
|
||||
noise: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply SDE noise injection step.
|
||||
|
||||
Advances sample from sigma to sigma_next with stochastic noise injection.
|
||||
|
||||
Args:
|
||||
sample: Current sample (anchor point)
|
||||
denoised_sample: Denoised prediction at this step
|
||||
sigma: Current noise level
|
||||
sigma_next: Next noise level
|
||||
noise: Pre-generated noise tensor (channel-wise normalized)
|
||||
|
||||
Returns:
|
||||
Noised sample at sigma_next
|
||||
"""
|
||||
alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)
|
||||
|
||||
if sigma_up == 0 or sigma_next == 0:
|
||||
return denoised_sample
|
||||
|
||||
# Float32 arithmetic
|
||||
sample_f32 = sample.astype(mx.float32)
|
||||
denoised_f32 = denoised_sample.astype(mx.float32)
|
||||
noise_f32 = noise.astype(mx.float32)
|
||||
|
||||
# Extract epsilon prediction
|
||||
eps_next = (sample_f32 - denoised_f32) / (sigma - sigma_next)
|
||||
denoised_next = sample_f32 - sigma * eps_next
|
||||
|
||||
# Mix deterministic and stochastic components
|
||||
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||
|
||||
return x_noised
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Noise generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def channelwise_normalize(x: mx.array) -> mx.array:
|
||||
"""Normalize each channel to zero mean and unit variance over spatial dims.
|
||||
|
||||
Operates on the last 2 dimensions (spatial H, W or time, freq).
|
||||
"""
|
||||
mean = mx.mean(x, axis=(-2, -1), keepdims=True)
|
||||
x = x - mean
|
||||
std = mx.sqrt(mx.mean(x * x, axis=(-2, -1), keepdims=True) + 1e-8)
|
||||
x = x / std
|
||||
return x
|
||||
|
||||
|
||||
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
|
||||
"""Generate channel-wise normalized Gaussian noise.
|
||||
|
||||
PyTorch uses float64; we use float32 (MLX doesn't support float64).
|
||||
The channel-wise normalization is the key quality-affecting step.
|
||||
|
||||
Args:
|
||||
shape: Shape of the noise tensor
|
||||
key: MLX random key for deterministic generation
|
||||
|
||||
Returns:
|
||||
Channel-wise normalized noise in float32
|
||||
"""
|
||||
noise = mx.random.normal(shape, dtype=mx.float32, key=key)
|
||||
# Global normalization
|
||||
noise = (noise - mx.mean(noise)) / (mx.sqrt(mx.mean(noise * noise)) + 1e-8)
|
||||
# Channel-wise normalization
|
||||
noise = channelwise_normalize(noise)
|
||||
return noise
|
||||
1138
mlx_video/models/ltx_2/text_encoder.py
Normal file
1138
mlx_video/models/ltx_2/text_encoder.py
Normal file
File diff suppressed because it is too large
Load Diff
26
mlx_video/models/ltx_2/text_projection.py
Normal file
26
mlx_video/models/ltx_2/text_projection.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_size: int,
|
||||
out_features: int | None = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||
self.act = nn.GELU(approx="tanh") # Must match PyTorch's approximate="tanh"
|
||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
x = self.linear1(x)
|
||||
x = self.act(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
403
mlx_video/models/ltx_2/transformer.py
Normal file
403
mlx_video/models/ltx_2/transformer.py
Normal file
@@ -0,0 +1,403 @@
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
|
||||
from mlx_video.models.ltx_2.attention import Attention
|
||||
from mlx_video.models.ltx_2.feed_forward import FeedForward
|
||||
from mlx_video.utils import rms_norm
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Modality:
|
||||
latent: mx.array
|
||||
timesteps: mx.array
|
||||
positions: mx.array
|
||||
context: mx.array
|
||||
enabled: bool = True
|
||||
context_mask: Optional[mx.array] = None
|
||||
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
|
||||
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
|
||||
# Raw sigma value (scalar per batch) for prompt adaln (LTX-2.3)
|
||||
sigma: Optional[mx.array] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TransformerArgs:
|
||||
x: mx.array
|
||||
context: mx.array
|
||||
context_mask: Optional[mx.array]
|
||||
timesteps: mx.array
|
||||
embedded_timestep: mx.array
|
||||
positional_embeddings: Tuple[mx.array, mx.array]
|
||||
cross_positional_embeddings: Optional[Tuple[mx.array, mx.array]]
|
||||
cross_scale_shift_timestep: Optional[mx.array]
|
||||
cross_gate_timestep: Optional[mx.array]
|
||||
enabled: bool
|
||||
# LTX-2.3: prompt-conditioned timestep embeddings for cross-attention
|
||||
prompt_timesteps: Optional[mx.array] = None
|
||||
prompt_embedded_timestep: Optional[mx.array] = None
|
||||
|
||||
|
||||
class BasicAVTransformerBlock(nn.Module):
|
||||
"""Audio-Video transformer block with cross-modal attention.
|
||||
|
||||
Supports video-only, audio-only, or combined audio-video processing
|
||||
with bidirectional cross-attention between modalities.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idx: int,
|
||||
video: Optional[TransformerConfig] = None,
|
||||
audio: Optional[TransformerConfig] = None,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
norm_eps: float = 1e-6,
|
||||
has_prompt_adaln: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.idx = idx
|
||||
self.norm_eps = norm_eps
|
||||
self.has_prompt_adaln = has_prompt_adaln
|
||||
|
||||
# Video components
|
||||
if video is not None:
|
||||
self.attn1 = Attention(
|
||||
query_dim=video.dim,
|
||||
heads=video.heads,
|
||||
dim_head=video.d_head,
|
||||
context_dim=None, # Self-attention
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
self.attn2 = Attention(
|
||||
query_dim=video.dim,
|
||||
context_dim=video.context_dim,
|
||||
heads=video.heads,
|
||||
dim_head=video.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
self.ff = FeedForward(video.dim, dim_out=video.dim)
|
||||
# 9 params for LTX-2.3 (self-attn + cross-attn + FFN), 6 for LTX-2
|
||||
num_ada_params = 9 if has_prompt_adaln else 6
|
||||
self.scale_shift_table = mx.zeros((num_ada_params, video.dim))
|
||||
|
||||
if has_prompt_adaln:
|
||||
self.prompt_scale_shift_table = mx.zeros((2, video.dim))
|
||||
|
||||
# Audio components
|
||||
if audio is not None:
|
||||
self.audio_attn1 = Attention(
|
||||
query_dim=audio.dim,
|
||||
heads=audio.heads,
|
||||
dim_head=audio.d_head,
|
||||
context_dim=None,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
self.audio_attn2 = Attention(
|
||||
query_dim=audio.dim,
|
||||
context_dim=audio.context_dim,
|
||||
heads=audio.heads,
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
self.audio_ff = FeedForward(audio.dim, dim_out=audio.dim)
|
||||
num_audio_ada_params = 9 if has_prompt_adaln else 6
|
||||
self.audio_scale_shift_table = mx.zeros((num_audio_ada_params, audio.dim))
|
||||
|
||||
if has_prompt_adaln:
|
||||
self.audio_prompt_scale_shift_table = mx.zeros((2, audio.dim))
|
||||
|
||||
# Cross-modal attention (when both video and audio are enabled)
|
||||
if audio is not None and video is not None:
|
||||
# Audio-to-Video: Q from video, K/V from audio
|
||||
self.audio_to_video_attn = Attention(
|
||||
query_dim=video.dim,
|
||||
context_dim=audio.dim,
|
||||
heads=audio.heads,
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
# Video-to-Audio: Q from audio, K/V from video
|
||||
self.video_to_audio_attn = Attention(
|
||||
query_dim=audio.dim,
|
||||
context_dim=video.dim,
|
||||
heads=audio.heads,
|
||||
dim_head=audio.d_head,
|
||||
rope_type=rope_type,
|
||||
norm_eps=norm_eps,
|
||||
has_gate_logits=has_prompt_adaln,
|
||||
)
|
||||
# Scale-shift tables for cross-attention
|
||||
self.scale_shift_table_a2v_ca_audio = mx.zeros((5, audio.dim))
|
||||
self.scale_shift_table_a2v_ca_video = mx.zeros((5, video.dim))
|
||||
|
||||
def get_ada_values(
|
||||
self,
|
||||
scale_shift_table: mx.array,
|
||||
batch_size: int,
|
||||
timestep: mx.array,
|
||||
indices: slice,
|
||||
) -> Tuple[mx.array, ...]:
|
||||
"""Get adaptive normalization values from scale-shift table.
|
||||
|
||||
Args:
|
||||
scale_shift_table: Table of shape (num_params, dim)
|
||||
batch_size: Batch size
|
||||
timestep: Timestep embeddings of shape (B, 1, num_params * dim) or similar
|
||||
indices: Slice for which parameters to extract
|
||||
|
||||
Returns:
|
||||
Tuple of scale-shift values
|
||||
"""
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
|
||||
# scale_shift_table[indices]: (num_selected, dim)
|
||||
# Add batch and sequence dimensions: (1, 1, num_selected, dim)
|
||||
table_slice = scale_shift_table[indices]
|
||||
table_expanded = mx.expand_dims(mx.expand_dims(table_slice, axis=0), axis=0)
|
||||
|
||||
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
|
||||
timestep_reshaped = mx.reshape(
|
||||
timestep,
|
||||
(batch_size, timestep.shape[1], num_ada_params, -1)
|
||||
)
|
||||
|
||||
# Extract the relevant indices
|
||||
timestep_slice = timestep_reshaped[:, :, indices, :]
|
||||
|
||||
# Add table values to timestep
|
||||
ada_values = table_expanded + timestep_slice
|
||||
|
||||
# Unbind along the parameter dimension
|
||||
# Result: tuple of tensors, each of shape (B, seq, dim)
|
||||
num_sliced = ada_values.shape[2]
|
||||
result = tuple(ada_values[:, :, i, :] for i in range(num_sliced))
|
||||
|
||||
return result
|
||||
|
||||
def get_av_ca_ada_values(
|
||||
self,
|
||||
scale_shift_table: mx.array,
|
||||
batch_size: int,
|
||||
scale_shift_timestep: mx.array,
|
||||
gate_timestep: mx.array,
|
||||
num_scale_shift_values: int = 4,
|
||||
) -> Tuple[mx.array, mx.array, mx.array, mx.array, mx.array]:
|
||||
"""Get adaptive values for cross-modal attention.
|
||||
|
||||
Args:
|
||||
scale_shift_table: Table with 5 parameters (4 scale-shift + 1 gate)
|
||||
batch_size: Batch size
|
||||
scale_shift_timestep: Timestep for scale-shift
|
||||
gate_timestep: Timestep for gating
|
||||
num_scale_shift_values: Number of scale-shift values (default 4)
|
||||
|
||||
Returns:
|
||||
Tuple of 5 tensors: (scale1, shift1, scale2, shift2, gate)
|
||||
"""
|
||||
# Get scale-shift values
|
||||
scale_shift_ada = self.get_ada_values(
|
||||
scale_shift_table[:num_scale_shift_values, :],
|
||||
batch_size,
|
||||
scale_shift_timestep,
|
||||
slice(None, None),
|
||||
)
|
||||
|
||||
# Get gate values
|
||||
gate_ada = self.get_ada_values(
|
||||
scale_shift_table[num_scale_shift_values:, :],
|
||||
batch_size,
|
||||
gate_timestep,
|
||||
slice(None, None),
|
||||
)
|
||||
|
||||
# Squeeze the sequence dimension if it's 1
|
||||
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
|
||||
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
|
||||
|
||||
return (*scale_shift_squeezed, *gate_squeezed)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[TransformerArgs] = None,
|
||||
audio: Optional[TransformerArgs] = None,
|
||||
skip_video_self_attn: bool = False,
|
||||
skip_audio_self_attn: bool = False,
|
||||
skip_cross_modal: bool = False,
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Forward pass through transformer block.
|
||||
|
||||
Args:
|
||||
video: Video modality arguments
|
||||
audio: Audio modality arguments
|
||||
skip_video_self_attn: Skip video self-attention (for STG perturbation)
|
||||
skip_audio_self_attn: Skip audio self-attention (for STG perturbation)
|
||||
skip_cross_modal: Skip all cross-modal attention (for modality isolation)
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_video, updated_audio) TransformerArgs
|
||||
"""
|
||||
batch_size = video.x.shape[0] if video is not None else audio.x.shape[0]
|
||||
|
||||
vx = video.x if video is not None else None
|
||||
ax = audio.x if audio is not None else None
|
||||
|
||||
# Check which modalities to run
|
||||
run_vx = video is not None and video.enabled and vx.size > 0
|
||||
run_ax = audio is not None and audio.enabled and ax.size > 0
|
||||
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
|
||||
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
|
||||
|
||||
# Process video self-attention and cross-attention with text
|
||||
if run_vx:
|
||||
vshift_msa, vscale_msa, vgate_msa = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
|
||||
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
||||
vshift_q, vscale_q, vgate_q = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
|
||||
)
|
||||
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
|
||||
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
|
||||
)
|
||||
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
|
||||
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
|
||||
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
|
||||
else:
|
||||
vx = vx + self.attn2(
|
||||
rms_norm(vx, eps=self.norm_eps),
|
||||
context=video.context,
|
||||
mask=video.context_mask,
|
||||
)
|
||||
|
||||
# Process audio self-attention and cross-attention with text
|
||||
if run_ax:
|
||||
ashift_msa, ascale_msa, agate_msa = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(0, 3)
|
||||
)
|
||||
|
||||
# Self-attention with RoPE (skip_attention=True for STG perturbation)
|
||||
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
|
||||
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
|
||||
|
||||
# Cross-attention with text context
|
||||
if self.has_prompt_adaln:
|
||||
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
|
||||
ashift_q, ascale_q, agate_q = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
|
||||
)
|
||||
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
|
||||
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
|
||||
)
|
||||
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
|
||||
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
|
||||
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
|
||||
else:
|
||||
ax = ax + self.audio_attn2(
|
||||
rms_norm(ax, eps=self.norm_eps),
|
||||
context=audio.context,
|
||||
mask=audio.context_mask,
|
||||
)
|
||||
|
||||
# Audio-Video cross-modal attention
|
||||
if run_a2v or run_v2a:
|
||||
vx_norm3 = rms_norm(vx, eps=self.norm_eps)
|
||||
ax_norm3 = rms_norm(ax, eps=self.norm_eps)
|
||||
|
||||
# Get adaptive values for audio cross-attention
|
||||
(
|
||||
scale_ca_audio_a2v,
|
||||
shift_ca_audio_a2v,
|
||||
scale_ca_audio_v2a,
|
||||
shift_ca_audio_v2a,
|
||||
gate_out_v2a,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_audio,
|
||||
ax.shape[0],
|
||||
audio.cross_scale_shift_timestep,
|
||||
audio.cross_gate_timestep,
|
||||
)
|
||||
|
||||
# Get adaptive values for video cross-attention
|
||||
(
|
||||
scale_ca_video_a2v,
|
||||
shift_ca_video_a2v,
|
||||
scale_ca_video_v2a,
|
||||
shift_ca_video_v2a,
|
||||
gate_out_a2v,
|
||||
) = self.get_av_ca_ada_values(
|
||||
self.scale_shift_table_a2v_ca_video,
|
||||
vx.shape[0],
|
||||
video.cross_scale_shift_timestep,
|
||||
video.cross_gate_timestep,
|
||||
)
|
||||
|
||||
# Audio-to-Video cross-attention
|
||||
if run_a2v:
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_a2v) + shift_ca_video_a2v
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_a2v) + shift_ca_audio_a2v
|
||||
vx = vx + (
|
||||
self.audio_to_video_attn(
|
||||
vx_scaled,
|
||||
context=ax_scaled,
|
||||
pe=video.cross_positional_embeddings,
|
||||
k_pe=audio.cross_positional_embeddings,
|
||||
)
|
||||
* gate_out_a2v
|
||||
)
|
||||
|
||||
# Video-to-Audio cross-attention
|
||||
if run_v2a:
|
||||
ax_scaled = ax_norm3 * (1 + scale_ca_audio_v2a) + shift_ca_audio_v2a
|
||||
vx_scaled = vx_norm3 * (1 + scale_ca_video_v2a) + shift_ca_video_v2a
|
||||
ax = ax + (
|
||||
self.video_to_audio_attn(
|
||||
ax_scaled,
|
||||
context=vx_scaled,
|
||||
pe=audio.cross_positional_embeddings,
|
||||
k_pe=video.cross_positional_embeddings,
|
||||
)
|
||||
* gate_out_v2a
|
||||
)
|
||||
|
||||
# Process video feed-forward
|
||||
if run_vx:
|
||||
vshift_mlp, vscale_mlp, vgate_mlp = self.get_ada_values(
|
||||
self.scale_shift_table, vx.shape[0], video.timesteps, slice(3, 6)
|
||||
)
|
||||
vx_scaled = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_mlp) + vshift_mlp
|
||||
vx = vx + self.ff(vx_scaled) * vgate_mlp
|
||||
|
||||
# Process audio feed-forward
|
||||
if run_ax:
|
||||
ashift_mlp, ascale_mlp, agate_mlp = self.get_ada_values(
|
||||
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(3, 6)
|
||||
)
|
||||
ax_scaled = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_mlp) + ashift_mlp
|
||||
ax = ax + self.audio_ff(ax_scaled) * agate_mlp
|
||||
|
||||
# Return updated TransformerArgs
|
||||
video_out = replace(video, x=vx) if video is not None else None
|
||||
audio_out = replace(audio, x=ax) if audio is not None else None
|
||||
|
||||
return video_out, audio_out
|
||||
371
mlx_video/models/ltx_2/upsampler.py
Normal file
371
mlx_video/models/ltx_2/upsampler.py
Normal file
@@ -0,0 +1,371 @@
|
||||
from typing import Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class Conv3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]] = 3,
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
dilation: Union[int, Tuple[int, int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
if isinstance(padding, int):
|
||||
padding = (padding, padding, padding)
|
||||
if isinstance(dilation, int):
|
||||
dilation = (dilation, dilation, dilation)
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
|
||||
# Weight shape: (C_out, KD, KH, KW, C_in)
|
||||
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = mx.zeros((out_channels,))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (N, D, H, W, C_in)
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (N, D', H', W', C_out)
|
||||
"""
|
||||
y = mx.conv3d(
|
||||
x,
|
||||
self.weight,
|
||||
stride=self.stride,
|
||||
padding=self.padding,
|
||||
dilation=self.dilation,
|
||||
groups=self.groups,
|
||||
)
|
||||
|
||||
if self.bias is not None:
|
||||
y = y + self.bias
|
||||
|
||||
return y
|
||||
|
||||
|
||||
class GroupNorm3d(nn.Module):
|
||||
|
||||
def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.num_groups = num_groups
|
||||
self.num_channels = num_channels
|
||||
self.eps = eps
|
||||
self.weight = mx.ones((num_channels,))
|
||||
self.bias = mx.zeros((num_channels,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, D, H, W, C)
|
||||
n, d, h, w, c = x.shape
|
||||
input_dtype = x.dtype
|
||||
|
||||
|
||||
x = x.astype(mx.float32)
|
||||
|
||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
|
||||
|
||||
# Compute mean and var over spatial and channel group dims
|
||||
mean = mx.mean(x, axis=(1, 3), keepdims=True)
|
||||
var = mx.var(x, axis=(1, 3), keepdims=True)
|
||||
|
||||
# Normalize
|
||||
x = (x - mean) / mx.sqrt(var + self.eps)
|
||||
|
||||
# Reshape back
|
||||
x = mx.reshape(x, (n, d, h, w, c))
|
||||
|
||||
# Apply weight and bias
|
||||
weight = self.weight.astype(mx.float32)
|
||||
bias = self.bias.astype(mx.float32)
|
||||
x = x * weight + bias
|
||||
|
||||
# Convert back to input dtype
|
||||
x = x.astype(input_dtype)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PixelShuffle2D(nn.Module):
|
||||
"""Pixel shuffle for 2D spatial upsampling."""
|
||||
|
||||
def __init__(self, upscale_factor: int = 2):
|
||||
super().__init__()
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, H, W, C) where C = out_channels * upscale_factor^2
|
||||
n, h, w, c = x.shape
|
||||
r = self.upscale_factor
|
||||
out_c = c // (r * r)
|
||||
|
||||
# Reshape: (N, H, W, out_c, r, r)
|
||||
x = mx.reshape(x, (n, h, w, out_c, r, r))
|
||||
|
||||
# Permute: (N, H, r, W, r, out_c)
|
||||
x = mx.transpose(x, (0, 1, 4, 2, 5, 3))
|
||||
|
||||
# Reshape: (N, H*r, W*r, out_c)
|
||||
x = mx.reshape(x, (n, h * r, w * r, out_c))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(nn.Module):
|
||||
|
||||
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
# 2D conv: mid_channels -> 4*mid_channels for pixel shuffle
|
||||
self.conv = nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1)
|
||||
|
||||
# Blur kernel for antialiasing
|
||||
self.blur_down_kernel = mx.ones((1, 1, 5, 5)) / 25.0
|
||||
|
||||
self.pixel_shuffle = PixelShuffle2D(2)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (N, D, H, W, C) - channels last 3D format
|
||||
|
||||
n, d, h, w, c = x.shape
|
||||
|
||||
# Process frame by frame
|
||||
# Reshape to (N*D, H, W, C) for 2D operations
|
||||
x = mx.reshape(x, (n * d, h, w, c))
|
||||
|
||||
# Apply 2D conv
|
||||
x = self.conv(x)
|
||||
|
||||
# Pixel shuffle for 2x upscaling
|
||||
x = self.pixel_shuffle(x)
|
||||
|
||||
# Reshape back to (N, D, H*2, W*2, C)
|
||||
x = mx.reshape(x, (n, d, h * 2, w * 2, c))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock3D(nn.Module):
|
||||
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.conv1 = Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.norm1 = GroupNorm3d(32, channels)
|
||||
self.conv2 = Conv3d(channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = GroupNorm3d(32, channels)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
residual = x
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = nn.silu(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
|
||||
# Activation AFTER residual addition
|
||||
x = nn.silu(x + residual)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LatentUpsampler(nn.Module):
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 1024,
|
||||
num_blocks_per_stage: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
|
||||
# Initial projection
|
||||
self.initial_conv = Conv3d(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = GroupNorm3d(32, mid_channels)
|
||||
|
||||
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
|
||||
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=2.0)
|
||||
|
||||
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
|
||||
# Final projection
|
||||
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def __call__(self, latent: mx.array, debug: bool = False) -> mx.array:
|
||||
"""Upsample latents by 2x spatially.
|
||||
|
||||
Args:
|
||||
latent: Input tensor of shape (B, C, F, H, W) - channels first
|
||||
debug: If True, print intermediate values for debugging
|
||||
|
||||
Returns:
|
||||
Upsampled tensor of shape (B, C, F, H*2, W*2) - channels first
|
||||
"""
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
|
||||
if debug:
|
||||
print(" [DEBUG] LatentUpsampler forward pass:")
|
||||
debug_stats("Input (channels first)", latent)
|
||||
|
||||
# Convert from channels first (B, C, F, H, W) to channels last (B, F, H, W, C)
|
||||
x = mx.transpose(latent, (0, 2, 3, 4, 1))
|
||||
if debug:
|
||||
debug_stats("After transpose to channels-last", x)
|
||||
|
||||
# Initial conv
|
||||
x = self.initial_conv(x)
|
||||
if debug:
|
||||
debug_stats("After initial_conv", x)
|
||||
x = self.initial_norm(x)
|
||||
if debug:
|
||||
debug_stats("After initial_norm", x)
|
||||
x = nn.silu(x)
|
||||
if debug:
|
||||
debug_stats("After silu", x)
|
||||
|
||||
# Pre-upsample blocks
|
||||
for i in sorted(self.res_blocks.keys()):
|
||||
x = self.res_blocks[i](x)
|
||||
if debug:
|
||||
debug_stats(f"After res_blocks[{i}]", x)
|
||||
|
||||
# Upsample (2D spatial, frame-by-frame)
|
||||
x = self.upsampler(x)
|
||||
if debug:
|
||||
debug_stats("After upsampler (spatial 2x)", x)
|
||||
|
||||
# Post-upsample blocks
|
||||
for i in sorted(self.post_upsample_res_blocks.keys()):
|
||||
x = self.post_upsample_res_blocks[i](x)
|
||||
if debug:
|
||||
debug_stats(f"After post_upsample_res_blocks[{i}]", x)
|
||||
|
||||
# Final conv
|
||||
x = self.final_conv(x)
|
||||
if debug:
|
||||
debug_stats("After final_conv", x)
|
||||
|
||||
# Convert back to channels first (B, C, F, H, W)
|
||||
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||
if debug:
|
||||
debug_stats("Output (channels first)", x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def upsample_latents(
|
||||
latent: mx.array,
|
||||
upsampler: LatentUpsampler,
|
||||
latent_mean: mx.array,
|
||||
latent_std: mx.array,
|
||||
debug: bool = False,
|
||||
) -> mx.array:
|
||||
# Un-normalize: latent * std + mean
|
||||
latent_mean = latent_mean.reshape(1, -1, 1, 1, 1)
|
||||
latent_std = latent_std.reshape(1, -1, 1, 1, 1)
|
||||
latent = latent * latent_std + latent_mean
|
||||
|
||||
# Upsample
|
||||
latent = upsampler(latent, debug=debug)
|
||||
|
||||
# Re-normalize: (latent - mean) / std
|
||||
latent = (latent - latent_mean) / latent_std
|
||||
|
||||
return latent
|
||||
|
||||
|
||||
def load_upsampler(weights_path: str) -> LatentUpsampler:
|
||||
"""Load upsampler from safetensors weights.
|
||||
|
||||
Args:
|
||||
weights_path: Path to upsampler weights file
|
||||
|
||||
Returns:
|
||||
Loaded LatentUpsampler model
|
||||
"""
|
||||
print(f"Loading spatial upsampler from {weights_path}...")
|
||||
raw_weights = mx.load(weights_path)
|
||||
|
||||
# Check weight shapes to determine mid_channels
|
||||
# res_blocks.0.conv1.weight should be (mid_channels, mid_channels, 3, 3, 3)
|
||||
sample_key = "res_blocks.0.conv1.weight"
|
||||
if sample_key in raw_weights:
|
||||
mid_channels = raw_weights[sample_key].shape[0]
|
||||
else:
|
||||
mid_channels = 1024 # default
|
||||
|
||||
print(f" Detected mid_channels: {mid_channels}")
|
||||
|
||||
# Create model
|
||||
upsampler = LatentUpsampler(
|
||||
in_channels=128,
|
||||
mid_channels=mid_channels,
|
||||
num_blocks_per_stage=4,
|
||||
)
|
||||
|
||||
# Sanitize weights - convert from PyTorch to MLX format
|
||||
sanitized = {}
|
||||
for key, value in raw_weights.items():
|
||||
new_key = key
|
||||
|
||||
# LTX-2.3 upsampler uses sequential indexing: upsampler.0.* -> upsampler.conv.*
|
||||
if key.startswith("upsampler.0."):
|
||||
new_key = key.replace("upsampler.0.", "upsampler.conv.")
|
||||
|
||||
# Conv3d weights: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
|
||||
if "weight" in new_key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Conv2d weights: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||
if "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
# Load weights
|
||||
upsampler.load_weights(list(sanitized.items()), strict=False)
|
||||
|
||||
print(f" Loaded {len(sanitized)} weights")
|
||||
|
||||
return upsampler
|
||||
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
8
mlx_video/models/ltx_2/video_vae/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
TemporalTilingConfig,
|
||||
)
|
||||
294
mlx_video/models/ltx_2/video_vae/convolution.py
Normal file
294
mlx_video/models/ltx_2/video_vae/convolution.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class PaddingModeType(Enum):
|
||||
ZEROS = "zeros"
|
||||
REFLECT = "reflect"
|
||||
|
||||
|
||||
def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
|
||||
"""Apply reflect padding to spatial dimensions of a 5D tensor.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, D, H, W, C) - channels last
|
||||
pad_h: Padding for height dimension
|
||||
pad_w: Padding for width dimension
|
||||
|
||||
Returns:
|
||||
Padded tensor
|
||||
"""
|
||||
if pad_h == 0 and pad_w == 0:
|
||||
return x
|
||||
|
||||
# Height padding (axis 2)
|
||||
if pad_h > 0:
|
||||
# Get reflection indices - exclude boundary
|
||||
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
|
||||
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
|
||||
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
|
||||
|
||||
# Width padding (axis 3)
|
||||
if pad_w > 0:
|
||||
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
|
||||
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
|
||||
x = mx.concatenate([left_pad, x, right_pad], axis=3)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def make_conv_nd(
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, ...]],
|
||||
stride: Union[int, Tuple[int, ...]] = 1,
|
||||
padding: Union[int, Tuple[int, ...], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
) -> nn.Module:
|
||||
|
||||
if dims == 2:
|
||||
return CausalConv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
causal=causal,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif dims == 3:
|
||||
return CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
causal=causal,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported number of dimensions: {dims}")
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.spatial_padding_mode = spatial_padding_mode
|
||||
|
||||
# Normalize kernel_size and stride to tuples
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.time_kernel_size = kernel_size[0]
|
||||
|
||||
# Calculate spatial padding (temporal is handled separately via frame replication)
|
||||
height_pad = kernel_size[1] // 2
|
||||
width_pad = kernel_size[2] // 2
|
||||
self.spatial_padding = (height_pad, width_pad)
|
||||
|
||||
# Create the base convolution (without padding, we'll handle it manually)
|
||||
self.conv = nn.Conv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0, # We handle padding manually
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
|
||||
use_causal = causal if causal is not None else self.causal
|
||||
|
||||
# Apply temporal padding via frame replication
|
||||
# Only apply if kernel_size > 1
|
||||
if self.time_kernel_size > 1:
|
||||
if use_causal:
|
||||
# Causal: replicate first frame kernel_size-1 times at the beginning
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
|
||||
x = mx.concatenate([first_frame_pad, x], axis=2)
|
||||
else:
|
||||
# Non-causal: replicate first frame at start, last frame at end
|
||||
pad_size = (self.time_kernel_size - 1) // 2
|
||||
if pad_size > 0:
|
||||
first_frame_pad = mx.repeat(x[:, :, :1, :, :], pad_size, axis=2)
|
||||
last_frame_pad = mx.repeat(x[:, :, -1:, :, :], pad_size, axis=2)
|
||||
x = mx.concatenate([first_frame_pad, x, last_frame_pad], axis=2)
|
||||
|
||||
# Transpose to channels last: (B, C, D, H, W) -> (B, D, H, W, C)
|
||||
x = mx.transpose(x, (0, 2, 3, 4, 1))
|
||||
|
||||
# Apply spatial padding
|
||||
pad_h, pad_w = self.spatial_padding
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
if self.spatial_padding_mode == PaddingModeType.REFLECT:
|
||||
# Use reflect padding for spatial dimensions
|
||||
x = reflect_pad_2d(x, pad_h, pad_w)
|
||||
else:
|
||||
# Use zero padding for spatial dimensions
|
||||
pad_width = [
|
||||
(0, 0), # Batch
|
||||
(0, 0), # D (temporal - already padded)
|
||||
(pad_h, pad_h), # H
|
||||
(pad_w, pad_w), # W
|
||||
(0, 0), # C
|
||||
]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
# Apply convolution with chunking for large tensors
|
||||
# Note: We choose to use chunking because MLX conv3d fails around 33 frames with 192x192 spatial
|
||||
x = self._chunked_conv3d(x)
|
||||
|
||||
# Transpose back to channels first: (B, D, H, W, C) -> (B, C, D, H, W)
|
||||
x = mx.transpose(x, (0, 4, 1, 2, 3))
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv3d(self, x: mx.array) -> mx.array:
|
||||
"""Apply conv3d in temporal chunks to work around MLX bug with large tensors.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, D, H, W, C) in channels-last format
|
||||
|
||||
Returns:
|
||||
Output tensor after conv3d
|
||||
"""
|
||||
b, d, h, w, c = x.shape
|
||||
|
||||
|
||||
total_elements = d * h * w * c
|
||||
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
|
||||
|
||||
if total_elements <= max_safe_elements:
|
||||
return self.conv(x)
|
||||
|
||||
elements_per_frame = h * w * c
|
||||
max_frames_per_chunk = max(1, max_safe_elements // elements_per_frame)
|
||||
chunk_size = min(max_frames_per_chunk, 24) # Cap at 24 frames per chunk
|
||||
|
||||
kernel_t = self.time_kernel_size
|
||||
|
||||
overlap = kernel_t - 1
|
||||
|
||||
|
||||
expected_output_frames = d - overlap
|
||||
|
||||
outputs = []
|
||||
out_idx = 0
|
||||
|
||||
# Process chunks
|
||||
in_start = 0
|
||||
while out_idx < expected_output_frames:
|
||||
remaining = expected_output_frames - out_idx
|
||||
out_frames_this_chunk = min(chunk_size, remaining)
|
||||
|
||||
in_frames_needed = out_frames_this_chunk + overlap
|
||||
in_end = min(in_start + in_frames_needed, d)
|
||||
|
||||
chunk = x[:, in_start:in_end, :, :, :]
|
||||
|
||||
chunk_out = self.conv(chunk)
|
||||
mx.eval(chunk_out)
|
||||
|
||||
outputs.append(chunk_out)
|
||||
|
||||
out_idx += chunk_out.shape[1]
|
||||
in_start += chunk_out.shape[1]
|
||||
|
||||
# Concatenate all chunks
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return mx.concatenate(outputs, axis=1)
|
||||
|
||||
|
||||
class CausalConv2d(nn.Module):
|
||||
"""2D convolution with optional causal padding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: Union[int, Tuple[int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int], str] = 0,
|
||||
causal: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
"""Initialize CausalConv2d."""
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.spatial_padding_mode = spatial_padding_mode
|
||||
|
||||
# Normalize kernel_size and stride
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size)
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride)
|
||||
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
|
||||
# Calculate padding
|
||||
if isinstance(padding, str) and padding == "same":
|
||||
self.padding = (
|
||||
(kernel_size[0] - 1) // 2,
|
||||
(kernel_size[1] - 1) // 2,
|
||||
)
|
||||
elif isinstance(padding, int):
|
||||
self.padding = (padding, padding)
|
||||
else:
|
||||
self.padding = padding
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
|
||||
"""Forward pass."""
|
||||
# Transpose to channels last: (B, C, H, W) -> (B, H, W, C)
|
||||
x = mx.transpose(x, (0, 2, 3, 1))
|
||||
|
||||
# Apply padding
|
||||
pad_h, pad_w = self.padding
|
||||
if pad_h != 0 or pad_w != 0:
|
||||
pad_width = [
|
||||
(0, 0), # Batch
|
||||
(pad_h, pad_h), # H
|
||||
(pad_w, pad_w), # W
|
||||
(0, 0), # C
|
||||
]
|
||||
x = mx.pad(x, pad_width)
|
||||
|
||||
x = self.conv(x)
|
||||
|
||||
# Transpose back: (B, H, W, C) -> (B, C, H, W)
|
||||
x = mx.transpose(x, (0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
692
mlx_video/models/ltx_2/video_vae/decoder.py
Normal file
692
mlx_video/models/ltx_2/video_vae/decoder.py
Normal file
@@ -0,0 +1,692 @@
|
||||
"""Video VAE Decoder for LTX-2 with timestep conditioning.
|
||||
|
||||
Architecture (from PyTorch weights):
|
||||
- conv_in: 128 -> 1024
|
||||
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||
- up_blocks.1: Conv 1024 -> 4096, depth2space -> 512, upscale 2x
|
||||
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||
- up_blocks.3: Conv 512 -> 2048, depth2space -> 256, upscale 2x
|
||||
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||
- up_blocks.5: Conv 256 -> 1024, depth2space -> 128, upscale 2x
|
||||
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||
- pixel_norm + timestep modulation (last_scale_shift_table)
|
||||
- conv_out: 128 -> 48
|
||||
- unpatchify: 48 -> 3 with patch_size=4
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: mx.array,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = True,
|
||||
downscale_freq_shift: float = 0,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
) -> mx.array:
|
||||
"""Create sinusoidal timestep embeddings."""
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = mx.exp(exponent)
|
||||
emb = timesteps[:, None].astype(mx.float32) * emb[None, :]
|
||||
emb = scale * emb
|
||||
|
||||
emb = mx.concatenate([mx.sin(emb), mx.cos(emb)], axis=-1)
|
||||
|
||||
if flip_sin_to_cos:
|
||||
emb = mx.concatenate([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
|
||||
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
"""MLP for timestep embedding."""
|
||||
|
||||
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class PixArtAlphaTimestepEmbedder(nn.Module):
|
||||
"""Combined timestep embedding (sinusoidal + MLP)."""
|
||||
|
||||
def __init__(self, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=256,
|
||||
time_embed_dim=embedding_dim
|
||||
)
|
||||
|
||||
def __call__(self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32) -> mx.array:
|
||||
timesteps_proj = get_timestep_embedding(
|
||||
timestep,
|
||||
embedding_dim=256,
|
||||
flip_sin_to_cos=True,
|
||||
downscale_freq_shift=0
|
||||
)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
|
||||
return timesteps_emb
|
||||
|
||||
|
||||
class ResnetBlock3DSimple(nn.Module):
|
||||
"""ResNet block with optional timestep conditioning.
|
||||
|
||||
Weight keys: conv1.conv, conv2.conv, scale_shift_table
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Nested conv structure to match PyTorch naming: conv1.conv.weight
|
||||
self.conv1 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
|
||||
self.conv2 = self._make_conv_wrapper(channels, channels, spatial_padding_mode)
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
# Scale-shift table for timestep conditioning: [shift1, scale1, shift2, scale2]
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = mx.zeros((4, channels))
|
||||
|
||||
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
|
||||
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
|
||||
class ConvWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=in_ch,
|
||||
out_channels=out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
return ConvWrapper()
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = False,
|
||||
timestep_embed: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# Block 1 with optional timestep conditioning
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
|
||||
# Combine table with timestep embedding
|
||||
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
|
||||
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
|
||||
channels = self.scale_shift_table.shape[1]
|
||||
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
shift1 = ada_values[:, 0] # (B, C, 1, 1, 1)
|
||||
scale1 = ada_values[:, 1]
|
||||
shift2 = ada_values[:, 2]
|
||||
scale2 = ada_values[:, 3]
|
||||
|
||||
x = x * (1 + scale1) + shift1
|
||||
|
||||
x = self.act(x)
|
||||
x = self.conv1(x, causal=causal)
|
||||
|
||||
# Block 2 with optional timestep conditioning
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
if self.timestep_conditioning and timestep_embed is not None:
|
||||
x = x * (1 + scale2) + shift2
|
||||
|
||||
x = self.act(x)
|
||||
x = self.conv2(x, causal=causal)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
class ResBlockGroup(nn.Module):
|
||||
"""Group of ResNet blocks with shared timestep embedding.
|
||||
|
||||
PyTorch naming: res_blocks.0, res_blocks.1, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_layers: int = 5,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Time embedder for this block group: embed_dim = 4 * channels
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=channels * 4
|
||||
)
|
||||
|
||||
# Use dict with int keys for MLX to track parameters properly
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3DSimple(
|
||||
channels,
|
||||
spatial_padding_mode,
|
||||
timestep_conditioning=timestep_conditioning
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
timestep_embed = None
|
||||
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
batch_size = x.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
)
|
||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
for res_block in self.res_blocks.values():
|
||||
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2VideoDecoder(nn.Module):
|
||||
"""LTX-2 Video VAE Decoder with timestep conditioning.
|
||||
|
||||
Architecture:
|
||||
- conv_in: 128 -> 1024
|
||||
- up_blocks.0: 5 ResBlocks at 1024 (with timestep)
|
||||
- up_blocks.1: Upsampler 1024 -> 512
|
||||
- up_blocks.2: 5 ResBlocks at 512 (with timestep)
|
||||
- up_blocks.3: Upsampler 512 -> 256
|
||||
- up_blocks.4: 5 ResBlocks at 256 (with timestep)
|
||||
- up_blocks.5: Upsampler 256 -> 128
|
||||
- up_blocks.6: 5 ResBlocks at 128 (with timestep)
|
||||
- conv_out: 128 -> 48 (3 * 4^2 for patch_size=4)
|
||||
"""
|
||||
|
||||
# Block definitions: ("res", channels, num_layers) or ("d2s", in_channels, reduction, stride)
|
||||
# stride is (D, H, W) tuple
|
||||
DEFAULT_BLOCKS = [
|
||||
("res", 1024, 5),
|
||||
("d2s", 1024, 2, (2, 2, 2)),
|
||||
("res", 512, 5),
|
||||
("d2s", 512, 2, (2, 2, 2)),
|
||||
("res", 256, 5),
|
||||
("d2s", 256, 2, (2, 2, 2)),
|
||||
("res", 128, 5),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
patch_size: int = 4,
|
||||
num_layers_per_block: int = 5,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
timestep_conditioning: bool = True,
|
||||
decoder_blocks: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
# Decode parameters (configurable via constructor)
|
||||
self.decode_noise_scale = 0.025 # Set to 0.0 to disable noise
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Per-channel statistics for denormalization (loaded from weights)
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||
|
||||
blocks = decoder_blocks or self.DEFAULT_BLOCKS
|
||||
first_ch = blocks[0][1]
|
||||
last_ch = blocks[-1][1]
|
||||
|
||||
# Initial conv: in_channels -> first block channels
|
||||
class ConvInWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=first_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Build up blocks from config
|
||||
self.up_blocks = {}
|
||||
for idx, block_def in enumerate(blocks):
|
||||
block_type = block_def[0]
|
||||
ch = block_def[1]
|
||||
if block_type == "res":
|
||||
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
|
||||
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
|
||||
elif block_type == "d2s":
|
||||
reduction = block_def[2] if len(block_def) > 2 else 2
|
||||
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
|
||||
residual = block_def[4] if len(block_def) > 4 else True
|
||||
self.up_blocks[idx] = DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=ch,
|
||||
stride=stride,
|
||||
residual=residual,
|
||||
out_channels_reduction_factor=reduction,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
class ConvOutWrapper(nn.Module):
|
||||
def __init__(self_inner):
|
||||
super().__init__()
|
||||
self_inner.conv = CausalConv3d(
|
||||
in_channels=last_ch,
|
||||
out_channels=final_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
def __call__(self_inner, x, causal=False):
|
||||
return self_inner.conv(x, causal=causal)
|
||||
self.conv_out = ConvOutWrapper()
|
||||
|
||||
self.act = nn.SiLU()
|
||||
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = mx.array(1000.0)
|
||||
self.last_time_embedder = PixArtAlphaTimestepEmbedder(
|
||||
embedding_dim=last_ch * 2
|
||||
)
|
||||
self.last_scale_shift_table = mx.zeros((2, last_ch))
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
# Build decoder weights dict with key remapping
|
||||
sanitized = {}
|
||||
if "per_channel_statistics.mean" in weights:
|
||||
return weights
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
if not key.startswith("vae.") or key.startswith("vae.encoder."):
|
||||
continue
|
||||
|
||||
if key.startswith("vae.per_channel_statistics."):
|
||||
# Map per-channel statistics (use exact key matching)
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
|
||||
if key.startswith("vae.decoder."):
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
"""Load a pretrained decoder from a directory with config.json and weights.
|
||||
|
||||
Args:
|
||||
model_path: Path to directory containing config.json and safetensors files,
|
||||
or path to a single safetensors file.
|
||||
strict: Whether to require all weight keys to match.
|
||||
|
||||
Returns:
|
||||
Loaded LTX2VideoDecoder instance
|
||||
"""
|
||||
import json
|
||||
|
||||
model_path = Path(model_path)
|
||||
config_dict = {}
|
||||
|
||||
# Load config from directory
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Load weights from directory
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
# Infer block structure from weights
|
||||
decoder_blocks = cls._infer_blocks(weights)
|
||||
|
||||
# Determine spatial padding mode from config
|
||||
spatial_padding_mode_str = config_dict.get("spatial_padding_mode", "reflect")
|
||||
spatial_padding_mode = PaddingModeType(spatial_padding_mode_str)
|
||||
|
||||
model = cls(
|
||||
timestep_conditioning=config_dict.get("timestep_conditioning", False),
|
||||
decoder_blocks=decoder_blocks,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _infer_blocks(weights: dict) -> list:
|
||||
"""Infer decoder block structure from weight keys."""
|
||||
block_indices = set()
|
||||
for k in weights:
|
||||
if "up_blocks." in k:
|
||||
idx_str = k.split("up_blocks.")[1].split(".")[0]
|
||||
if idx_str.isdigit():
|
||||
block_indices.add(int(idx_str))
|
||||
|
||||
if not block_indices:
|
||||
return None
|
||||
|
||||
# First pass: collect block info
|
||||
raw_blocks = []
|
||||
for idx in sorted(block_indices):
|
||||
has_conv = any(f"up_blocks.{idx}.conv." in k for k in weights)
|
||||
res_indices = set()
|
||||
for k in weights:
|
||||
prefix = f"up_blocks.{idx}.res_blocks."
|
||||
if prefix in k:
|
||||
res_idx = k.split(prefix)[1].split(".")[0]
|
||||
if res_idx.isdigit():
|
||||
res_indices.add(int(res_idx))
|
||||
|
||||
if has_conv and not res_indices:
|
||||
# D2S block - get conv shape
|
||||
for k, v in weights.items():
|
||||
if f"up_blocks.{idx}.conv." in k and "weight" in k:
|
||||
in_ch = v.shape[-1] if v.ndim == 5 else v.shape[1]
|
||||
conv_out_ch = v.shape[0]
|
||||
raw_blocks.append(("d2s", in_ch, conv_out_ch))
|
||||
break
|
||||
elif res_indices:
|
||||
num_res = max(res_indices) + 1
|
||||
for k, v in weights.items():
|
||||
if f"up_blocks.{idx}.res_blocks.0.conv1" in k and "weight" in k:
|
||||
ch = v.shape[0]
|
||||
raw_blocks.append(("res", ch, num_res))
|
||||
break
|
||||
|
||||
# Second pass: determine d2s strides using the channel progression
|
||||
# For each d2s block, the next res block tells us the expected output channels
|
||||
blocks = []
|
||||
d2s_strides = []
|
||||
for i, block in enumerate(raw_blocks):
|
||||
if block[0] == "res":
|
||||
blocks.append(block)
|
||||
elif block[0] == "d2s":
|
||||
in_ch, conv_out_ch = block[1], block[2]
|
||||
# Find next res block's channels
|
||||
next_ch = None
|
||||
for j in range(i + 1, len(raw_blocks)):
|
||||
if raw_blocks[j][0] == "res":
|
||||
next_ch = raw_blocks[j][1]
|
||||
break
|
||||
|
||||
if next_ch is None:
|
||||
next_ch = in_ch // 2 # fallback
|
||||
|
||||
# out_ch = in_ch // reduction
|
||||
reduction = in_ch // next_ch if next_ch > 0 else 2
|
||||
|
||||
# conv_out = next_ch * multiplier → multiplier = conv_out / next_ch
|
||||
multiplier = conv_out_ch // next_ch if next_ch > 0 else 8
|
||||
|
||||
# Determine stride from multiplier
|
||||
if multiplier == 8:
|
||||
stride = (2, 2, 2)
|
||||
elif multiplier == 4:
|
||||
stride = (1, 2, 2)
|
||||
elif multiplier == 2:
|
||||
stride = (2, 1, 1)
|
||||
else:
|
||||
stride = (2, 2, 2)
|
||||
|
||||
d2s_strides.append(stride)
|
||||
blocks.append(("d2s", in_ch, reduction, stride))
|
||||
|
||||
if not blocks:
|
||||
return None
|
||||
|
||||
# Determine residual flag: LTX-2 has uniform (2,2,2) strides with reduction=2 → residual=True
|
||||
# LTX-2.3 has mixed strides or reduction=1 → residual=False
|
||||
has_mixed_strides = len(set(d2s_strides)) > 1
|
||||
has_non_standard_reduction = any(b[2] != 2 for b in blocks if b[0] == "d2s")
|
||||
use_residual = not has_mixed_strides and not has_non_standard_reduction
|
||||
|
||||
# Apply residual flag to all d2s blocks
|
||||
final_blocks = []
|
||||
for block in blocks:
|
||||
if block[0] == "d2s":
|
||||
final_blocks.append(("d2s", block[1], block[2], block[3], use_residual))
|
||||
else:
|
||||
final_blocks.append(block)
|
||||
|
||||
return final_blocks
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample: mx.array,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
|
||||
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning and timestep is not None:
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
|
||||
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||
elif isinstance(block, DepthToSpaceUpsample):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
scaled_timestep.flatten(),
|
||||
hidden_dtype=x.dtype
|
||||
)
|
||||
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
|
||||
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
|
||||
ada_values = ada_values + ts_reshaped
|
||||
|
||||
shift = ada_values[:, 0] # (B, 128, 1, 1, 1)
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
|
||||
|
||||
x = self.act(x)
|
||||
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
def decode_tiled(
|
||||
self,
|
||||
sample: mx.array,
|
||||
tiling_config: Optional[TilingConfig] = None,
|
||||
tiling_mode: str = "auto",
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
debug: bool = False,
|
||||
on_frames_ready: Optional[callable] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
This method is useful for decoding large videos that would otherwise
|
||||
cause out-of-memory errors. It divides the latents into tiles,
|
||||
decodes each tile separately, and blends them together.
|
||||
|
||||
Args:
|
||||
sample: Input latents of shape (B, C, F, H, W).
|
||||
tiling_config: Tiling configuration. If None, uses TilingConfig.default().
|
||||
causal: Whether to use causal convolutions.
|
||||
timestep: Optional timestep for conditioning.
|
||||
debug: Whether to print debug info.
|
||||
|
||||
Returns:
|
||||
Decoded video of shape (B, 3, F*8, H*8, W*8).
|
||||
"""
|
||||
if tiling_config is None:
|
||||
tiling_config = TilingConfig.default()
|
||||
|
||||
# Check if tiling is actually needed
|
||||
_, _, f, h, w = sample.shape
|
||||
needs_spatial_tiling = False
|
||||
needs_temporal_tiling = False
|
||||
|
||||
# Spatial scale is 32 (8x VAE upsample + 4x unpatchify)
|
||||
# Temporal scale is 8
|
||||
spatial_scale = 32
|
||||
temporal_scale = 8
|
||||
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_cfg = tiling_config.spatial_config
|
||||
tile_size_latent = s_cfg.tile_size_in_pixels // spatial_scale
|
||||
if h > tile_size_latent or w > tile_size_latent:
|
||||
needs_spatial_tiling = True
|
||||
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_cfg = tiling_config.temporal_config
|
||||
tile_size_latent = t_cfg.tile_size_in_frames // temporal_scale
|
||||
if f > tile_size_latent:
|
||||
needs_temporal_tiling = True
|
||||
|
||||
# Auto-enable chunked conv for modes where it helps (larger tiles)
|
||||
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
|
||||
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
|
||||
|
||||
if not needs_spatial_tiling and not needs_temporal_tiling:
|
||||
# No tiling needed, use regular decode
|
||||
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
|
||||
|
||||
return decode_with_tiling(
|
||||
decoder_fn=self,
|
||||
latents=sample,
|
||||
tiling_config=tiling_config,
|
||||
spatial_scale=32, # VAE spatial: 8x upsampling + 4x unpatchify = 32x
|
||||
temporal_scale=8, # VAE temporal upsampling factor
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
chunked_conv=use_chunked_conv,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
|
||||
|
||||
# Backward-compatible alias
|
||||
VideoDecoder = LTX2VideoDecoder
|
||||
44
mlx_video/models/ltx_2/video_vae/encoder.py
Normal file
44
mlx_video/models/ltx_2/video_vae/encoder.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Video VAE Encoder for LTX-2 Image-to-Video.
|
||||
|
||||
The encoder compresses input images/videos to latent representations.
|
||||
Used for I2V (image-to-video) conditioning by encoding the input image
|
||||
to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
|
||||
|
||||
|
||||
|
||||
def encode_image(
|
||||
image: mx.array,
|
||||
encoder: VideoEncoder,
|
||||
) -> mx.array:
|
||||
"""Encode a single image to latent space.
|
||||
|
||||
Args:
|
||||
image: Image tensor of shape (H, W, 3) in range [0, 1] or (B, H, W, 3)
|
||||
encoder: Loaded VAE encoder
|
||||
|
||||
Returns:
|
||||
Latent tensor of shape (1, 128, 1, H//32, W//32)
|
||||
"""
|
||||
# Add batch dimension if needed
|
||||
if image.ndim == 3:
|
||||
image = mx.expand_dims(image, axis=0) # (1, H, W, 3)
|
||||
|
||||
# Convert from (B, H, W, C) to (B, C, H, W)
|
||||
image = mx.transpose(image, (0, 3, 1, 2)) # (B, 3, H, W)
|
||||
|
||||
# Normalize to [-1, 1]
|
||||
if image.max() > 1.0:
|
||||
image = image / 255.0
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
|
||||
image = mx.expand_dims(image, axis=2) # (B, 3, 1, H, W)
|
||||
|
||||
# Encode
|
||||
latent = encoder(image)
|
||||
|
||||
return latent
|
||||
125
mlx_video/models/ltx_2/video_vae/ops.py
Normal file
125
mlx_video/models/ltx_2/video_vae/ops.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Operations for Video VAE."""
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
||||
"""Convert video to patches.
|
||||
|
||||
Moves spatial pixels from H, W dimensions to channel dimension.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, C, F, H, W)
|
||||
patch_size_hw: Spatial patch size
|
||||
patch_size_t: Temporal patch size
|
||||
|
||||
Returns:
|
||||
Patched tensor of shape (B, C * patch_size_hw^2, F, H/patch_size_hw, W/patch_size_hw)
|
||||
"""
|
||||
b, c, f, h, w = x.shape
|
||||
|
||||
# Check dimensions are divisible
|
||||
assert h % patch_size_hw == 0 and w % patch_size_hw == 0
|
||||
assert f % patch_size_t == 0
|
||||
|
||||
# New dimensions
|
||||
new_h = h // patch_size_hw
|
||||
new_w = w // patch_size_hw
|
||||
new_f = f // patch_size_t
|
||||
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
|
||||
|
||||
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
|
||||
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
|
||||
|
||||
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
|
||||
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
|
||||
x = mx.transpose(x, (0, 1, 3, 7, 5, 2, 4, 6))
|
||||
|
||||
# Reshape: (B, C, pt, pw, ph, F', H', W') -> (B, C*pt*pw*ph, F', H', W')
|
||||
x = mx.reshape(x, (b, new_c, new_f, new_h, new_w))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def unpatchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.array:
|
||||
"""Convert patches back to video.
|
||||
|
||||
Inverse of patchify - moves pixels from channel dimension back to spatial.
|
||||
Matches PyTorch einops: "b (c p r q) f h w -> b c (f p) (h q) (w r)"
|
||||
where p=patch_size_t, r=patch_size_hw (width), q=patch_size_hw (height)
|
||||
|
||||
Args:
|
||||
x: Patched tensor of shape (B, C * patch_size_hw^2, F, H, W)
|
||||
patch_size_hw: Spatial patch size
|
||||
patch_size_t: Temporal patch size
|
||||
|
||||
Returns:
|
||||
Video tensor of shape (B, C, F * patch_size_t, H * patch_size_hw, W * patch_size_hw)
|
||||
"""
|
||||
b, c_packed, f, h, w = x.shape
|
||||
|
||||
# Calculate original channel count
|
||||
c = c_packed // (patch_size_hw * patch_size_hw * patch_size_t)
|
||||
|
||||
# Reshape: (B, C*pt*pr*pq, F, H, W) -> (B, C, pt, pr, pq, F, H, W)
|
||||
# where pt=temporal, pr=width_patch (r), pq=height_patch (q)
|
||||
# Channel layout from PyTorch is (c, p, r, q) = (c, temporal, width, height)
|
||||
x = mx.reshape(x, (b, c, patch_size_t, patch_size_hw, patch_size_hw, f, h, w))
|
||||
|
||||
# Permute to interleave patches with spatial dims:
|
||||
# (B, C, pt, pr, pq, F, H, W) -> (B, C, F, pt, H, pq, W, pr)
|
||||
|
||||
x = mx.transpose(x, (0, 1, 5, 2, 6, 4, 7, 3))
|
||||
|
||||
# Reshape: (B, C, F, pt, H, pq, W, pr) -> (B, C, F*pt, H*pq, W*pr)
|
||||
x = mx.reshape(x, (b, c, f * patch_size_t, h * patch_size_hw, w * patch_size_hw))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PerChannelStatistics(nn.Module):
|
||||
|
||||
def __init__(self, latent_channels: int = 128):
|
||||
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
|
||||
# Learnable per-channel mean and std
|
||||
self.mean = mx.zeros((latent_channels,))
|
||||
self.std = mx.ones((latent_channels,))
|
||||
|
||||
def normalize(self, x: mx.array) -> mx.array:
|
||||
"""Normalize latents using per-channel statistics.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, C, ...)
|
||||
|
||||
Returns:
|
||||
Normalized tensor
|
||||
"""
|
||||
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return ((x - mean) / std).astype(dtype)
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics.
|
||||
|
||||
Args:
|
||||
x: Normalized tensor of shape (B, C, ...)
|
||||
|
||||
Returns:
|
||||
Denormalized tensor
|
||||
"""
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision
|
||||
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
|
||||
return (x * std + mean).astype(dtype)
|
||||
172
mlx_video/models/ltx_2/video_vae/resnet.py
Normal file
172
mlx_video/models/ltx_2/video_vae/resnet.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""ResNet blocks for Video VAE."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
class NormLayerType(Enum):
|
||||
GROUP_NORM = "group_norm"
|
||||
PIXEL_NORM = "pixel_norm"
|
||||
|
||||
|
||||
def get_norm_layer(
|
||||
norm_type: NormLayerType,
|
||||
num_channels: int,
|
||||
num_groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
) -> nn.Module:
|
||||
|
||||
if norm_type == NormLayerType.GROUP_NORM:
|
||||
return nn.GroupNorm(num_groups=num_groups, dims=num_channels, eps=eps)
|
||||
elif norm_type == NormLayerType.PIXEL_NORM:
|
||||
return PixelNorm(eps=eps)
|
||||
else:
|
||||
raise ValueError(f"Unknown norm type: {norm_type}")
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
eps: float = 1e-6,
|
||||
groups: int = 32,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.inject_noise = inject_noise
|
||||
|
||||
# First normalization and convolution
|
||||
self.norm1 = get_norm_layer(norm_layer, in_channels, groups, eps)
|
||||
self.conv1 = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Second normalization and convolution
|
||||
self.norm2 = get_norm_layer(norm_layer, out_channels, groups, eps)
|
||||
self.conv2 = CausalConv3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Shortcut connection if channels change
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
self.shortcut = None
|
||||
|
||||
# Activation
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = True,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
residual = x
|
||||
|
||||
# First block
|
||||
x = self.norm1(x)
|
||||
x = self.act(x)
|
||||
x = self.conv1(x, causal=causal)
|
||||
|
||||
# Inject noise if enabled
|
||||
if self.inject_noise and generator is not None:
|
||||
noise = mx.random.normal(x.shape)
|
||||
x = x + noise * 0.01
|
||||
|
||||
# Second block
|
||||
x = self.norm2(x)
|
||||
x = self.act(x)
|
||||
x = self.conv2(x, causal=causal)
|
||||
|
||||
# Shortcut
|
||||
if self.shortcut is not None:
|
||||
residual = self.shortcut(residual, causal=causal)
|
||||
|
||||
return x + residual
|
||||
|
||||
|
||||
class UNetMidBlock3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Create ResNet blocks - use dict for MLX parameter tracking
|
||||
# Named res_blocks to match PyTorch weight keys
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3D(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
causal: bool = True,
|
||||
timestep: Optional[mx.array] = None,
|
||||
generator: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
|
||||
for resnet in self.res_blocks.values():
|
||||
x = resnet(x, causal=causal, generator=generator)
|
||||
|
||||
return x
|
||||
275
mlx_video/models/ltx_2/video_vae/sampling.py
Normal file
275
mlx_video/models/ltx_2/video_vae/sampling.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""Sampling operations for Video VAE (upsampling/downsampling)."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
|
||||
|
||||
class SpaceToDepthDownsample(nn.Module):
|
||||
"""Space-to-depth downsampling with 3x3 conv and skip connection.
|
||||
|
||||
PyTorch-compatible implementation:
|
||||
1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
|
||||
2. Space-to-depth on conv output: channels * prod(stride)
|
||||
3. Space-to-depth on input with group averaging for skip connection
|
||||
4. Add skip connection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.stride = stride
|
||||
self.dims = dims
|
||||
self.out_channels = out_channels
|
||||
|
||||
# Calculate channels
|
||||
multiplier = stride[0] * stride[1] * stride[2]
|
||||
self.group_size = in_channels * multiplier // out_channels
|
||||
conv_out_channels = out_channels // multiplier
|
||||
|
||||
# 3x3 convolution (not 1x1)
|
||||
self.conv = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=conv_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def _space_to_depth(self, x: mx.array) -> mx.array:
|
||||
"""Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w"""
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Reshape to group spatial elements
|
||||
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
|
||||
|
||||
# Permute: (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
|
||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||
|
||||
# Reshape to combine channels
|
||||
new_c = c * st * sh * sw
|
||||
new_d = d // st
|
||||
new_h = h // sh
|
||||
new_w = w // sw
|
||||
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Temporal padding for causal mode
|
||||
if st == 2:
|
||||
# Duplicate first frame for padding
|
||||
x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)
|
||||
d = d + 1
|
||||
|
||||
# Pad if necessary to make dimensions divisible by stride
|
||||
pad_d = (st - d % st) % st
|
||||
pad_h = (sh - h % sh) % sh
|
||||
pad_w = (sw - w % sw) % sw
|
||||
|
||||
if pad_d > 0 or pad_h > 0 or pad_w > 0:
|
||||
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
|
||||
|
||||
# Skip connection: space-to-depth on input, then group mean
|
||||
x_in = self._space_to_depth(x)
|
||||
# Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
|
||||
b2, c2, d2, h2, w2 = x_in.shape
|
||||
x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
|
||||
x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w)
|
||||
|
||||
# Conv branch: apply conv then space-to-depth
|
||||
x_conv = self.conv(x, causal=causal)
|
||||
x_conv = self._space_to_depth(x_conv)
|
||||
|
||||
# Add skip connection
|
||||
return x_conv + x_in
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
in_channels: int,
|
||||
stride: Union[int, Tuple[int, int, int]],
|
||||
residual: bool = False,
|
||||
out_channels_reduction_factor: int = 1,
|
||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if isinstance(stride, int):
|
||||
stride = (stride, stride, stride)
|
||||
|
||||
self.stride = stride
|
||||
self.dims = dims
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
|
||||
# Calculate output channels
|
||||
multiplier = stride[0] * stride[1] * stride[2]
|
||||
out_channels = in_channels // out_channels_reduction_factor
|
||||
self.out_channels = out_channels
|
||||
|
||||
# 3x3x3 convolution to prepare channels for unpacking (matches PyTorch)
|
||||
self.conv = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * multiplier,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
|
||||
def _depth_to_space(self, x: mx.array) -> mx.array:
|
||||
b, c_packed, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
c = c_packed // (st * sh * sw)
|
||||
|
||||
# (B, C*st*sh*sw, D, H, W) -> (B, C, st, sh, sw, D, H, W)
|
||||
x = mx.reshape(x, (b, c, st, sh, sw, d, h, w))
|
||||
|
||||
# (B, C, st, sh, sw, D, H, W) -> (B, C, D, st, H, sh, W, sw)
|
||||
x = mx.transpose(x, (0, 1, 5, 2, 6, 3, 7, 4))
|
||||
|
||||
# (B, C, D, st, H, sh, W, sw) -> (B, C, D*st, H*sh, W*sw)
|
||||
x = mx.reshape(x, (b, c, d * st, h * sh, w * sw))
|
||||
|
||||
return x
|
||||
|
||||
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
|
||||
# Compute residual path if enabled
|
||||
x_residual = None
|
||||
if self.residual:
|
||||
# Reshape input: treat channels as spatial factors
|
||||
# "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)"
|
||||
x_residual = self._depth_to_space(x)
|
||||
|
||||
# Tile channels to match output (PyTorch .repeat() tiles, not element-repeat!)
|
||||
# num_repeat = prod(stride) / out_channels_reduction_factor
|
||||
num_repeat = (st * sh * sw) // self.out_channels_reduction_factor
|
||||
x_residual = mx.tile(x_residual, (1, num_repeat, 1, 1, 1))
|
||||
|
||||
# Remove first temporal frame if temporal upsampling
|
||||
if st > 1:
|
||||
x_residual = x_residual[:, :, 1:, :, :]
|
||||
|
||||
# Use chunked mode for large tensors to reduce peak memory
|
||||
if chunked_conv and d > 4:
|
||||
x = self._chunked_conv_depth_to_space(x, causal)
|
||||
else:
|
||||
# Apply conv
|
||||
x = self.conv(x, causal=causal)
|
||||
# Depth to space rearrangement
|
||||
x = self._depth_to_space(x)
|
||||
|
||||
# Remove first frame for causal temporal upsampling
|
||||
if st > 1:
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
# Add residual
|
||||
if self.residual and x_residual is not None:
|
||||
x = x + x_residual
|
||||
|
||||
return x
|
||||
|
||||
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||
"""Chunked conv + depth_to_space that processes in temporal chunks.
|
||||
|
||||
This reduces peak memory by avoiding the full high-channel intermediate tensor.
|
||||
Instead of materializing (B, 4096, D, H, W), we process temporal chunks and
|
||||
immediately apply depth_to_space.
|
||||
|
||||
Args:
|
||||
x: Input tensor of shape (B, C, D, H, W)
|
||||
causal: Whether to use causal convolutions
|
||||
|
||||
Returns:
|
||||
Output tensor after conv + depth_to_space
|
||||
"""
|
||||
b, c, d, h, w = x.shape
|
||||
st, sh, sw = self.stride
|
||||
out_c = self.out_channels
|
||||
|
||||
# Output dimensions
|
||||
out_d = d * st
|
||||
out_h = h * sh
|
||||
out_w = w * sw
|
||||
|
||||
# Chunk size in temporal dimension (process 4 frames at a time)
|
||||
chunk_size = 4
|
||||
kernel_t = 3 # Temporal kernel size
|
||||
|
||||
# For causal conv, we need (kernel_t - 1) frames of padding at the start
|
||||
# For non-causal, we need (kernel_t - 1) // 2 on each side
|
||||
if causal:
|
||||
# Pad start with first frame repeated
|
||||
pad_start = kernel_t - 1
|
||||
pad_end = 0
|
||||
else:
|
||||
pad_start = (kernel_t - 1) // 2
|
||||
pad_end = (kernel_t - 1) // 2
|
||||
|
||||
# Allocate output
|
||||
outputs = []
|
||||
|
||||
# Process in chunks with overlap for conv kernel
|
||||
t_pos = 0
|
||||
while t_pos < d:
|
||||
t_end = min(t_pos + chunk_size, d)
|
||||
|
||||
# Calculate input range with padding for kernel
|
||||
in_start = max(0, t_pos - pad_start)
|
||||
in_end = min(d, t_end + pad_end)
|
||||
|
||||
# Extract chunk
|
||||
chunk = x[:, :, in_start:in_end, :, :]
|
||||
|
||||
# Apply conv to chunk
|
||||
chunk_conv = self.conv(chunk, causal=causal)
|
||||
|
||||
# Apply depth_to_space
|
||||
chunk_out = self._depth_to_space(chunk_conv)
|
||||
|
||||
# Calculate valid output range (excluding padding effects)
|
||||
# Each input frame produces st output frames
|
||||
out_start = (t_pos - in_start) * st
|
||||
out_end = out_start + (t_end - t_pos) * st
|
||||
|
||||
# Extract valid portion
|
||||
chunk_out = chunk_out[:, :, out_start:out_end, :, :]
|
||||
|
||||
outputs.append(chunk_out)
|
||||
|
||||
# Evaluate to free intermediate memory
|
||||
mx.eval(outputs[-1])
|
||||
|
||||
t_pos = t_end
|
||||
|
||||
# Concatenate all chunks
|
||||
if len(outputs) == 1:
|
||||
return outputs[0]
|
||||
return mx.concatenate(outputs, axis=2)
|
||||
492
mlx_video/models/ltx_2/video_vae/tiling.py
Normal file
492
mlx_video/models/ltx_2/video_vae/tiling.py
Normal file
@@ -0,0 +1,492 @@
|
||||
"""VAE Tiling Configuration for decoding large videos.
|
||||
|
||||
Implements spatial and temporal tiling with trapezoidal blending masks
|
||||
to decode large videos without running out of memory.
|
||||
|
||||
Default configuration (from PyTorch):
|
||||
- Spatial: 512px tiles with 64px overlap
|
||||
- Temporal: 64 frames with 24 frame overlap
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def compute_trapezoidal_mask_1d(
|
||||
length: int,
|
||||
ramp_left: int,
|
||||
ramp_right: int,
|
||||
left_starts_from_0: bool = False,
|
||||
) -> mx.array:
|
||||
"""Generate a 1D trapezoidal blending mask with linear ramps.
|
||||
|
||||
Args:
|
||||
length: Output length of the mask.
|
||||
ramp_left: Fade-in length on the left.
|
||||
ramp_right: Fade-out length on the right.
|
||||
left_starts_from_0: Whether the ramp starts from 0 or first non-zero value.
|
||||
Useful for temporal tiles where the first tile is causal.
|
||||
|
||||
Returns:
|
||||
A 1D array of shape (length,) with values in [0, 1].
|
||||
"""
|
||||
if length <= 0:
|
||||
raise ValueError("Mask length must be positive.")
|
||||
|
||||
ramp_left = max(0, min(ramp_left, length))
|
||||
ramp_right = max(0, min(ramp_right, length))
|
||||
|
||||
# Start with ones
|
||||
mask = [1.0] * length
|
||||
|
||||
# Apply left ramp (fade in)
|
||||
if ramp_left > 0:
|
||||
interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2
|
||||
# Create fade_in values using linspace logic
|
||||
fade_in_full = [i / (interval_length - 1) for i in range(interval_length)]
|
||||
fade_in = fade_in_full[:-1] # Remove last element
|
||||
if not left_starts_from_0:
|
||||
fade_in = fade_in[1:] # Remove first element too
|
||||
for i in range(min(ramp_left, len(fade_in))):
|
||||
mask[i] *= fade_in[i]
|
||||
|
||||
# Apply right ramp (fade out)
|
||||
if ramp_right > 0:
|
||||
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
|
||||
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
|
||||
for i in range(ramp_right):
|
||||
mask[length - ramp_right + i] *= fade_out[i]
|
||||
|
||||
return mx.clip(mx.array(mask), 0, 1)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpatialTilingConfig:
|
||||
"""Configuration for dividing each frame into spatial tiles with optional overlap."""
|
||||
|
||||
tile_size_in_pixels: int
|
||||
tile_overlap_in_pixels: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_pixels < 64:
|
||||
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
|
||||
if self.tile_size_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
|
||||
if self.tile_overlap_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
|
||||
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TemporalTilingConfig:
|
||||
"""Configuration for dividing a video into temporal tiles."""
|
||||
|
||||
tile_size_in_frames: int
|
||||
tile_overlap_in_frames: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_frames < 16:
|
||||
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
|
||||
if self.tile_size_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
|
||||
if self.tile_overlap_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
|
||||
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TilingConfig:
|
||||
"""Configuration for splitting video into tiles with optional overlap."""
|
||||
|
||||
spatial_config: Optional[SpatialTilingConfig] = None
|
||||
temporal_config: Optional[TemporalTilingConfig] = None
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> "TilingConfig":
|
||||
"""Default tiling: 512px spatial, 64 frame temporal."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
|
||||
"""Spatial tiling only (for short videos with large resolution)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
|
||||
temporal_config=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig":
|
||||
"""Temporal tiling only (for long videos with small resolution)."""
|
||||
return cls(
|
||||
spatial_config=None,
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def aggressive(cls) -> "TilingConfig":
|
||||
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def conservative(cls) -> "TilingConfig":
|
||||
"""Conservative tiling (larger tiles, less memory savings but faster)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def auto(
|
||||
cls,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
spatial_threshold: int = 512,
|
||||
temporal_threshold: int = 65,
|
||||
) -> Optional["TilingConfig"]:
|
||||
"""Automatically determine tiling config based on video dimensions.
|
||||
|
||||
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
|
||||
enough context for CausalConv3d and sufficient overlap for clean blending.
|
||||
|
||||
Args:
|
||||
height: Video height in pixels
|
||||
width: Video width in pixels
|
||||
num_frames: Number of video frames
|
||||
spatial_threshold: Enable spatial tiling if either dimension exceeds this
|
||||
temporal_threshold: Enable temporal tiling if frames exceed this
|
||||
|
||||
Returns:
|
||||
TilingConfig if tiling is needed, None otherwise
|
||||
"""
|
||||
needs_spatial = height > spatial_threshold or width > spatial_threshold
|
||||
needs_temporal = num_frames > temporal_threshold
|
||||
|
||||
if not needs_spatial and not needs_temporal:
|
||||
return None
|
||||
|
||||
# Use the same defaults as PyTorch (512px spatial, 64f temporal).
|
||||
# Smaller tiles cause quality degradation because CausalConv3d needs
|
||||
# sufficient temporal context and overlap for clean blending.
|
||||
spatial_config = None
|
||||
temporal_config = None
|
||||
|
||||
if needs_spatial:
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
|
||||
|
||||
if needs_temporal:
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
|
||||
|
||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DimensionIntervals:
|
||||
"""Intervals for splitting a single dimension."""
|
||||
starts: List[int]
|
||||
ends: List[int]
|
||||
left_ramps: List[int]
|
||||
right_ramps: List[int]
|
||||
|
||||
|
||||
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
"""Split a spatial dimension into intervals."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
|
||||
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
|
||||
starts = [i * (size - overlap) for i in range(amount)]
|
||||
ends = [start + size for start in starts]
|
||||
ends[-1] = dimension_size
|
||||
left_ramps = [0] + [overlap] * (amount - 1)
|
||||
right_ramps = [overlap] * (amount - 1) + [0]
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
|
||||
|
||||
|
||||
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
"""Split a temporal dimension into intervals with causal adjustment."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
|
||||
# Start with spatial split
|
||||
intervals = split_in_spatial(size, overlap, dimension_size)
|
||||
|
||||
# Adjust for temporal: starts[1:] -= 1, left_ramps[1:] += 1
|
||||
starts = intervals.starts.copy()
|
||||
left_ramps = intervals.left_ramps.copy()
|
||||
|
||||
for i in range(1, len(starts)):
|
||||
starts[i] = starts[i] - 1
|
||||
left_ramps[i] = left_ramps[i] + 1
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
|
||||
|
||||
|
||||
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
"""Map temporal latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = 1 + (end - 1) * scale
|
||||
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
"""Map spatial latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = end * scale
|
||||
left_ramp_scaled = left_ramp * scale
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
def decode_with_tiling(
|
||||
decoder_fn,
|
||||
latents: mx.array,
|
||||
tiling_config: TilingConfig,
|
||||
spatial_scale: int = 32,
|
||||
temporal_scale: int = 8,
|
||||
causal: bool = False,
|
||||
timestep: Optional[mx.array] = None,
|
||||
chunked_conv: bool = False,
|
||||
on_frames_ready: Optional[Callable[[mx.array, int], None]] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latents using tiling to reduce memory usage.
|
||||
|
||||
Args:
|
||||
decoder_fn: Decoder function to call for each tile.
|
||||
latents: Input latents of shape (B, C, F, H, W).
|
||||
tiling_config: Tiling configuration.
|
||||
spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify).
|
||||
temporal_scale: Temporal scale factor (8 for LTX VAE).
|
||||
causal: Whether to use causal convolutions.
|
||||
timestep: Optional timestep for conditioning.
|
||||
chunked_conv: Whether to use chunked conv mode for upsampling (reduces memory).
|
||||
on_frames_ready: Optional callback called with (frames, start_idx) when frames are finalized.
|
||||
frames: Tensor of shape (B, 3, num_frames, H, W) with finalized RGB frames.
|
||||
start_idx: Starting frame index in the full video.
|
||||
|
||||
Returns:
|
||||
Decoded video.
|
||||
"""
|
||||
import gc
|
||||
|
||||
b, c, f_latent, h_latent, w_latent = latents.shape
|
||||
|
||||
# Compute output shape
|
||||
out_f = 1 + (f_latent - 1) * temporal_scale
|
||||
out_h = h_latent * spatial_scale
|
||||
out_w = w_latent * spatial_scale
|
||||
|
||||
# Get tile size and overlap in latent space
|
||||
if tiling_config.spatial_config is not None:
|
||||
s_cfg = tiling_config.spatial_config
|
||||
spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale
|
||||
spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale
|
||||
else:
|
||||
spatial_tile_size = max(h_latent, w_latent)
|
||||
spatial_overlap = 0
|
||||
|
||||
if tiling_config.temporal_config is not None:
|
||||
t_cfg = tiling_config.temporal_config
|
||||
temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale
|
||||
temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale
|
||||
else:
|
||||
temporal_tile_size = f_latent
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
num_t_tiles = len(temporal_intervals.starts)
|
||||
num_h_tiles = len(height_intervals.starts)
|
||||
num_w_tiles = len(width_intervals.starts)
|
||||
total_tiles = num_t_tiles * num_h_tiles * num_w_tiles
|
||||
|
||||
# Initialize output and weight accumulator
|
||||
# Use float32 for accumulation to avoid precision issues
|
||||
output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32)
|
||||
weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32)
|
||||
mx.eval(output, weights)
|
||||
|
||||
tile_idx = 0
|
||||
for t_idx in range(num_t_tiles):
|
||||
t_start = temporal_intervals.starts[t_idx]
|
||||
t_end = temporal_intervals.ends[t_idx]
|
||||
t_left = temporal_intervals.left_ramps[t_idx]
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
h_end = height_intervals.ends[h_idx]
|
||||
h_left = height_intervals.left_ramps[h_idx]
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
w_end = width_intervals.ends[w_idx]
|
||||
w_left = width_intervals.left_ramps[w_idx]
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
del tile_latents
|
||||
|
||||
# Get actual decoded dimensions
|
||||
_, _, decoded_t, decoded_h, decoded_w = tile_output.shape
|
||||
expected_t = out_t_slice.stop - out_t_slice.start
|
||||
expected_h = out_h_slice.stop - out_h_slice.start
|
||||
expected_w = out_w_slice.stop - out_w_slice.start
|
||||
|
||||
# Handle potential size mismatches (use minimum)
|
||||
actual_t = min(decoded_t, expected_t)
|
||||
actual_h = min(decoded_h, expected_h)
|
||||
actual_w = min(decoded_w, expected_w)
|
||||
|
||||
# Build blend mask
|
||||
t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask
|
||||
h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
|
||||
# Compute output coordinates
|
||||
t_out_start = out_t_slice.start
|
||||
t_out_end = t_out_start + actual_t
|
||||
h_out_start = out_h_slice.start
|
||||
h_out_end = h_out_start + actual_h
|
||||
w_out_start = out_w_slice.start
|
||||
w_out_end = w_out_start + actual_w
|
||||
|
||||
# Use direct slice assignment (MLX supports this)
|
||||
# Weighted accumulation
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
||||
)
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
mx.eval(output, weights)
|
||||
|
||||
# Clean up tile-specific arrays
|
||||
del tile_output_slice, weighted_tile, blend_mask
|
||||
del t_mask_slice, h_mask_slice, w_mask_slice
|
||||
|
||||
tile_idx += 1
|
||||
|
||||
# Periodic garbage collection and cache clearing
|
||||
if tile_idx % 4 == 0:
|
||||
gc.collect()
|
||||
try:
|
||||
mx.clear_cache()
|
||||
except Exception:
|
||||
pass # May not be available on all platforms
|
||||
|
||||
# After completing all spatial tiles for this temporal tile,
|
||||
# check if any frames are now finalized (no future tiles will contribute)
|
||||
if on_frames_ready is not None and num_t_tiles > 1:
|
||||
# Determine the finalized frame boundary
|
||||
# Frames before the start of the next tile's output region are finalized
|
||||
if t_idx < num_t_tiles - 1:
|
||||
# Next tile starts at temporal_intervals.starts[t_idx + 1]
|
||||
next_tile_start_latent = temporal_intervals.starts[t_idx + 1]
|
||||
# Map to output frame index (first frame of next tile's contribution)
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
else:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
if next_tile_start_out > emitted:
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
on_frames_ready(finalized_output, emitted)
|
||||
decode_with_tiling._emitted_frames = next_tile_start_out
|
||||
|
||||
del finalized_output, finalized_weights
|
||||
gc.collect()
|
||||
|
||||
# Normalize by weights
|
||||
weights = mx.maximum(weights, 1e-8)
|
||||
output = output / weights
|
||||
mx.eval(output)
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
on_frames_ready(remaining_output, emitted)
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
del weights
|
||||
gc.collect()
|
||||
|
||||
# Convert back to original dtype if needed
|
||||
return output.astype(latents.dtype)
|
||||
597
mlx_video/models/ltx_2/video_vae/video_vae.py
Normal file
597
mlx_video/models/ltx_2/video_vae/video_vae.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""Video VAE Encoder and Decoder for LTX-2."""
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import (
|
||||
NormLayerType,
|
||||
ResnetBlock3D,
|
||||
UNetMidBlock3D,
|
||||
get_norm_layer,
|
||||
)
|
||||
from mlx_video.models.ltx_2.video_vae.sampling import (
|
||||
DepthToSpaceUpsample,
|
||||
SpaceToDepthDownsample,
|
||||
)
|
||||
from mlx_video.utils import PixelNorm
|
||||
|
||||
|
||||
class LogVarianceType(Enum):
|
||||
"""Log variance mode for VAE."""
|
||||
PER_CHANNEL = "per_channel"
|
||||
UNIFORM = "uniform"
|
||||
CONSTANT = "constant"
|
||||
NONE = "none"
|
||||
|
||||
|
||||
def _make_encoder_block(
|
||||
block_name: str,
|
||||
block_config: Dict[str, Any],
|
||||
in_channels: int,
|
||||
convolution_dimensions: int,
|
||||
norm_layer: NormLayerType,
|
||||
norm_num_groups: int,
|
||||
spatial_padding_mode: PaddingModeType,
|
||||
) -> Tuple[nn.Module, int]:
|
||||
"""Create an encoder block.
|
||||
|
||||
Args:
|
||||
block_name: Type of block
|
||||
block_config: Block configuration
|
||||
in_channels: Input channels
|
||||
convolution_dimensions: Number of dimensions
|
||||
norm_layer: Normalization layer type
|
||||
norm_num_groups: Number of groups for group norm
|
||||
spatial_padding_mode: Padding mode
|
||||
|
||||
Returns:
|
||||
Tuple of (block, output_channels)
|
||||
"""
|
||||
out_channels = in_channels
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
num_layers=block_config["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 1, 1),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(1, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_x_y":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=(2, 2, 2),
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time_res":
|
||||
out_channels = in_channels * block_config.get("multiplier", 2)
|
||||
block = SpaceToDepthDownsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder block: {block_name}")
|
||||
|
||||
return block, out_channels
|
||||
|
||||
|
||||
def _make_decoder_block(
|
||||
block_name: str,
|
||||
block_config: Dict[str, Any],
|
||||
in_channels: int,
|
||||
convolution_dimensions: int,
|
||||
norm_layer: NormLayerType,
|
||||
timestep_conditioning: bool,
|
||||
norm_num_groups: int,
|
||||
spatial_padding_mode: PaddingModeType,
|
||||
) -> Tuple[nn.Module, int]:
|
||||
"""Create a decoder block."""
|
||||
out_channels = in_channels
|
||||
|
||||
if block_name == "res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
num_layers=block_config["num_layers"],
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_config.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
out_channels = in_channels // block_config.get("multiplier", 2)
|
||||
block = ResnetBlock3D(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_config.get("inject_noise", False),
|
||||
timestep_conditioning=False,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 1, 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(1, 2, 2),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
out_channels = in_channels // block_config.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=convolution_dimensions,
|
||||
in_channels=in_channels,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_config.get("residual", False),
|
||||
out_channels_reduction_factor=block_config.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown decoder block: {block_name}")
|
||||
|
||||
return block, out_channels
|
||||
|
||||
|
||||
class VideoEncoder(nn.Module):
|
||||
|
||||
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||
|
||||
def __init__(self, config: "VideoEncoderModelConfig"):
|
||||
"""Initialize VideoEncoder from config.
|
||||
|
||||
Args:
|
||||
config: VideoEncoderModelConfig with encoder parameters
|
||||
"""
|
||||
super().__init__()
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
self.patch_size = config.patch_size
|
||||
self.norm_layer = config.norm_layer
|
||||
self.latent_channels = config.out_channels
|
||||
self.latent_log_var = config.latent_log_var
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
|
||||
encoder_blocks = config.encoder_blocks if config.encoder_blocks else []
|
||||
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
|
||||
|
||||
# Per-channel statistics for normalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
|
||||
|
||||
# After patchify, channels increase by patch_size^2
|
||||
in_channels = config.in_channels * config.patch_size ** 2
|
||||
feature_channels = config.out_channels
|
||||
|
||||
# Initial convolution
|
||||
self.conv_in = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feature_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build encoder blocks
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.down_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
block_name=block_name,
|
||||
block_config=block_config,
|
||||
in_channels=feature_channels,
|
||||
convolution_dimensions=config.convolution_dimensions,
|
||||
norm_layer=config.norm_layer,
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
self.down_blocks[idx] = block
|
||||
|
||||
# Output normalization and convolution
|
||||
if config.norm_layer == NormLayerType.GROUP_NORM:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_groups=self._norm_num_groups,
|
||||
dims=feature_channels,
|
||||
eps=1e-6,
|
||||
)
|
||||
elif config.norm_layer == NormLayerType.PIXEL_NORM:
|
||||
self.conv_norm_out = PixelNorm()
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
# Calculate output convolution channels
|
||||
conv_out_channels = config.out_channels
|
||||
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
|
||||
conv_out_channels *= 2
|
||||
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
|
||||
conv_out_channels += 1
|
||||
|
||||
self.conv_out = CausalConv3d(
|
||||
in_channels=feature_channels,
|
||||
out_channels=conv_out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""Encode video to latent representation.
|
||||
|
||||
Args:
|
||||
sample: Input video of shape (B, C, F, H, W).
|
||||
F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...)
|
||||
|
||||
Returns:
|
||||
Normalized latent means of shape (B, 128, F', H', W')
|
||||
"""
|
||||
# Validate frame count
|
||||
frames_count = sample.shape[2]
|
||||
if ((frames_count - 1) % 8) != 0:
|
||||
raise ValueError(
|
||||
"Invalid number of frames: Encode input must have 1 + 8 * x frames "
|
||||
f"(e.g., 1, 9, 17, ...). Got {frames_count} frames."
|
||||
)
|
||||
|
||||
# Initial patchify
|
||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
sample = self.conv_in(sample, causal=True)
|
||||
|
||||
# Process through encoder blocks
|
||||
for i in range(len(self.down_blocks)):
|
||||
down_block = self.down_blocks[i]
|
||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||
sample = down_block(sample, causal=True)
|
||||
else:
|
||||
sample = down_block(sample, causal=True)
|
||||
|
||||
# Output processing
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=True)
|
||||
|
||||
# Handle log variance modes
|
||||
if self.latent_log_var == LogVarianceType.UNIFORM:
|
||||
means = sample[:, :-1, ...]
|
||||
logvar = sample[:, -1:, ...]
|
||||
num_channels = means.shape[1]
|
||||
repeated_logvar = mx.tile(logvar, (1, num_channels, 1, 1, 1))
|
||||
sample = mx.concatenate([means, repeated_logvar], axis=1)
|
||||
elif self.latent_log_var == LogVarianceType.CONSTANT:
|
||||
sample = sample[:, :-1, ...]
|
||||
approx_ln_0 = -30
|
||||
sample = mx.concatenate([
|
||||
sample,
|
||||
mx.full_like(sample, approx_ln_0),
|
||||
], axis=1)
|
||||
|
||||
# Split into means and logvar, normalize means
|
||||
means = sample[:, :self.latent_channels, ...]
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize VAE encoder weights from PyTorch format to MLX format."""
|
||||
sanitized = {}
|
||||
if "per_channel_statistics.mean" in weights:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Only process VAE encoder weights
|
||||
if not key.startswith("vae."):
|
||||
continue
|
||||
|
||||
# Handle per-channel statistics
|
||||
if "vae.per_channel_statistics" in key:
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
continue
|
||||
elif key.startswith("vae.encoder."):
|
||||
new_key = key.replace("vae.encoder.", "")
|
||||
else:
|
||||
continue
|
||||
|
||||
# Conv3d: PyTorch (O, I, D, H, W) -> MLX (O, D, H, W, I)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Conv2d: PyTorch (O, I, H, W) -> MLX (O, H, W, I)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "VideoEncoder":
|
||||
"""Load a pretrained VideoEncoder from a directory with weights and config.
|
||||
|
||||
Args:
|
||||
model_path: Path to directory containing safetensors weights and config.json
|
||||
|
||||
Returns:
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
import json
|
||||
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
|
||||
|
||||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
config = VideoEncoderModelConfig(**config_dict)
|
||||
else:
|
||||
config = VideoEncoderModelConfig()
|
||||
|
||||
# Load weights
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
if model_path.is_file():
|
||||
weights = mx.load(str(model_path))
|
||||
else:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||
else:
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
# Create model, sanitize and load weights
|
||||
model = cls(config)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=False)
|
||||
return model
|
||||
|
||||
|
||||
class VideoDecoder(nn.Module):
|
||||
|
||||
_DEFAULT_NORM_NUM_GROUPS = 32
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
convolution_dimensions: int = 3,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
decoder_blocks: List[Tuple[str, Any]] = None,
|
||||
patch_size: int = 4,
|
||||
norm_layer: NormLayerType = NormLayerType.PIXEL_NORM,
|
||||
causal: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT,
|
||||
):
|
||||
"""Initialize VideoDecoder.
|
||||
|
||||
Args:
|
||||
convolution_dimensions: Number of dimensions
|
||||
in_channels: Input latent channels
|
||||
out_channels: Output channels (3 for RGB)
|
||||
decoder_blocks: List of (block_name, config) tuples
|
||||
patch_size: Spatial patch size
|
||||
norm_layer: Normalization layer type
|
||||
causal: Whether to use causal convolutions
|
||||
timestep_conditioning: Whether to use timestep conditioning
|
||||
decoder_spatial_padding_mode: Padding mode
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if decoder_blocks is None:
|
||||
decoder_blocks = []
|
||||
|
||||
self.patch_size = patch_size
|
||||
out_channels = out_channels * patch_size ** 2
|
||||
self.causal = causal
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||
|
||||
# Noise and timestep parameters
|
||||
self.decode_noise_scale = 0.025
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Compute initial feature channels
|
||||
feature_channels = in_channels
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
block_config = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
feature_channels = feature_channels * block_config.get("multiplier", 1)
|
||||
|
||||
# Initial convolution
|
||||
self.conv_in = CausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=feature_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build decoder blocks (reversed order)
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.up_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
block_name=block_name,
|
||||
block_config=block_config,
|
||||
in_channels=feature_channels,
|
||||
convolution_dimensions=convolution_dimensions,
|
||||
norm_layer=norm_layer,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
self.up_blocks[idx] = block
|
||||
|
||||
# Output normalization
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_groups=self._norm_num_groups,
|
||||
dims=feature_channels,
|
||||
eps=1e-6,
|
||||
)
|
||||
elif norm_layer == NormLayerType.PIXEL_NORM:
|
||||
self.conv_norm_out = PixelNorm()
|
||||
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(
|
||||
in_channels=feature_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
causal=True,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
sample: mx.array,
|
||||
timestep: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
"""Decode latent to video.
|
||||
|
||||
Args:
|
||||
sample: Latent tensor of shape (B, 128, F', H', W')
|
||||
timestep: Optional timestep for conditioning
|
||||
|
||||
Returns:
|
||||
Decoded video of shape (B, 3, F, H, W)
|
||||
"""
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
|
||||
# Denormalize latents
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
# Use default timestep if not provided
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
|
||||
# Initial convolution
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
# Process through decoder blocks
|
||||
for i in range(len(self.up_blocks)):
|
||||
up_block = self.up_blocks[i]
|
||||
if isinstance(up_block, UNetMidBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
elif isinstance(up_block, ResnetBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
else:
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
|
||||
# Output processing
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
# Unpatchify to restore spatial resolution
|
||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
|
||||
return sample
|
||||
Reference in New Issue
Block a user