This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,3 +1,2 @@
from mlx_video.models.ltx_2 import LTXModel, LTXModelConfig
from mlx_video.models.wan import WanModel, WanModelConfig

View File

@@ -1,8 +1,7 @@
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
TransformerConfig,
LTXModelType,
TransformerConfig,
)
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
from mlx_video.models.ltx_2.audio_vae import AudioDecoder, Vocoder, decode_audio

View File

@@ -8,7 +8,6 @@ from mlx_video.utils import get_timestep_embedding
class AdaLayerNormSingle(nn.Module):
def __init__(
self,
embedding_dim: int,
@@ -24,7 +23,9 @@ class AdaLayerNormSingle(nn.Module):
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True)
self.linear = nn.Linear(
embedding_dim, embedding_coefficient * embedding_dim, bias=True
)
def __call__(
self,
@@ -56,15 +57,19 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
use_additional_conditions: bool = False,
timestep_proj_dim: int = 256,
):
super().__init__()
self.embedding_dim = embedding_dim
self.size_emb_dim = size_emb_dim
self.use_additional_conditions = use_additional_conditions
self.time_proj = Timesteps(timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(timestep_proj_dim, embedding_dim, out_dim=embedding_dim)
self.time_proj = Timesteps(
timestep_proj_dim, flip_sin_to_cos=True, downscale_freq_shift=0
)
self.timestep_embedder = TimestepEmbedding(
timestep_proj_dim, embedding_dim, out_dim=embedding_dim
)
if use_additional_conditions and size_emb_dim > 0:
self.additional_embedder = ConditionEmbedding(size_emb_dim, embedding_dim)
@@ -87,7 +92,9 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
# Add additional conditions if enabled
if self.use_additional_conditions and self.size_emb_dim > 0:
if resolution is not None and aspect_ratio is not None:
additional_embeds = self.additional_embedder(resolution, aspect_ratio, hidden_dtype)
additional_embeds = self.additional_embedder(
resolution, aspect_ratio, hidden_dtype
)
timesteps_emb = timesteps_emb + additional_embeds
return timesteps_emb

View File

@@ -1,10 +1,10 @@
"""Audio VAE module for LTX-2 audio generation."""
from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
from .audio_processor import load_audio, ensure_stereo, waveform_to_mel
from .causal_conv_2d import CausalConv2d, make_conv2d
from ..config import CausalityAxis
from .attention import AttentionType, AttnBlock, make_attn
from .audio_processor import ensure_stereo, load_audio, waveform_to_mel
from .audio_vae import AudioDecoder, AudioEncoder, decode_audio
from .causal_conv_2d import CausalConv2d, make_conv2d
from .downsample import Downsample, build_downsampling_path
from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics

View File

@@ -32,7 +32,9 @@ class AttnBlock(nn.Module):
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def __call__(self, x: mx.array) -> mx.array:
"""
@@ -103,6 +105,8 @@ def make_attn(
elif attn_type == AttentionType.NONE:
return Identity()
elif attn_type == AttentionType.LINEAR:
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
raise NotImplementedError(
f"Attention type {attn_type.value} is not supported yet."
)
else:
raise ValueError(f"Unknown attention type: {attn_type}")

View File

@@ -4,10 +4,9 @@ Matches the PyTorch AudioProcessor from LTX-2 (torchaudio.transforms.MelSpectrog
using librosa for macOS/MLX compatibility.
"""
from pathlib import Path
import numpy as np
import mlx.core as mx
import numpy as np
def load_audio(
@@ -99,14 +98,16 @@ def waveform_to_mel(
for ch in range(channels):
# Magnitude spectrogram (power=1.0)
S = np.abs(librosa.stft(
waveform[ch],
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
center=True,
pad_mode="reflect",
))
S = np.abs(
librosa.stft(
waveform[ch],
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
center=True,
pad_mode="reflect",
)
)
# Mel filterbank with slaney normalization
mel_basis = librosa.filters.mel(

View File

@@ -1,15 +1,15 @@
"""Audio VAE encoder and decoder for LTX-2."""
from typing import Dict
from pathlib import Path
from typing import Dict
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig
from ..config import AudioDecoderModelConfig, AudioEncoderModelConfig, CausalityAxis
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
@@ -39,7 +39,9 @@ def build_mid_block(
causality_axis=causality_axis,
)
mid["attn_1"] = (
make_attn(channels, attn_type=attn_type, norm_type=norm_type) if add_attention else None
make_attn(channels, attn_type=attn_type, norm_type=norm_type)
if add_attention
else None
)
mid["block_2"] = ResnetBlock(
in_channels=channels,
@@ -93,7 +95,10 @@ class AudioEncoder(nn.Module):
self.attn_type = config.attn_type
self.conv_in = make_conv2d(
config.in_channels, self.ch, kernel_size=3, stride=1,
config.in_channels,
self.ch,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
)
@@ -125,7 +130,10 @@ class AudioEncoder(nn.Module):
self.norm_out = build_normalization_layer(block_in, normtype=self.norm_type)
out_channels = 2 * config.z_channels if config.double_z else config.z_channels
self.conv_out = make_conv2d(
block_in, out_channels, kernel_size=3, stride=1,
block_in,
out_channels,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
)
@@ -160,7 +168,11 @@ class AudioEncoder(nn.Module):
continue
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
value = (
value
if check_array_shape(value)
else mx.transpose(value, (0, 2, 3, 1))
)
sanitized[new_key] = value
return sanitized
@@ -168,11 +180,14 @@ class AudioEncoder(nn.Module):
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioEncoder":
"""Load audio encoder from pretrained weights."""
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
import json
from mlx_video.models.ltx_2.config import AudioEncoderModelConfig
model_path = Path(model_path)
config = AudioEncoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
config = AudioEncoderModelConfig.from_dict(
json.load(open(model_path / "config.json"))
)
encoder = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
encoder.load_weights(list(weights.items()), strict=True)
@@ -265,7 +280,6 @@ class AudioDecoder(nn.Module):
"""
super().__init__()
# Per-channel statistics for denormalizing latents
# Uses ch (base channel count) to match the patchified latent dimension
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
@@ -305,7 +319,11 @@ class AudioDecoder(nn.Module):
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d(
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
config.z_channels,
base_block_channels,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
)
self.mid = build_mid_block(
@@ -334,9 +352,15 @@ class AudioDecoder(nn.Module):
initial_block_channels=base_block_channels,
)
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
self.norm_out = build_normalization_layer(
final_block_channels, normtype=self.norm_type
)
self.conv_out = make_conv2d(
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
final_block_channels,
config.out_ch,
kernel_size=3,
stride=1,
causality_axis=self.causality_axis,
)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
@@ -371,7 +395,11 @@ class AudioDecoder(nn.Module):
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
value = (
value
if check_array_shape(value)
else mx.transpose(value, (0, 2, 3, 1))
)
sanitized[new_key] = value
@@ -380,17 +408,19 @@ class AudioDecoder(nn.Module):
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
"""Load audio VAE decoder from pretrained model."""
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
import json
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
from mlx_video.models.ltx_2.config import AudioDecoderModelConfig
config = AudioDecoderModelConfig.from_dict(
json.load(open(model_path / "config.json"))
)
decoder = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
# weights = decoder.sanitize(weights)
decoder.load_weights(list(weights.items()), strict=True)
return decoder
def __call__(self, sample: mx.array) -> mx.array:
"""
Decode latent features back to audio spectrograms.
@@ -414,7 +444,9 @@ class AudioDecoder(nn.Module):
return self._adjust_output_shape(h, target_shape)
def _denormalize_latents(self, sample: mx.array) -> tuple[mx.array, AudioLatentShape]:
def _denormalize_latents(
self, sample: mx.array
) -> tuple[mx.array, AudioLatentShape]:
"""Denormalize latents using per-channel statistics."""
# sample shape: (B, H, W, C) in MLX format
latent_shape = AudioLatentShape(
@@ -436,7 +468,9 @@ class AudioDecoder(nn.Module):
batch=latent_shape.batch,
channels=self.out_ch,
frames=target_frames,
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
mel_bins=(
self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins
),
)
return sample, target_shape
@@ -462,7 +496,10 @@ class AudioDecoder(nn.Module):
# Step 1: Crop first to avoid exceeding target dimensions
decoded_output = decoded_output[
:, : min(current_time, target_time), : min(current_freq, target_freq), :target_channels
:,
: min(current_time, target_time),
: min(current_freq, target_freq),
:target_channels,
]
# Step 2: Calculate padding needed for time and frequency dimensions
@@ -514,7 +551,9 @@ class AudioDecoder(nn.Module):
return mx.tanh(h) if self.tanh_out else h
def decode_audio(latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder") -> mx.array:
def decode_audio(
latent: mx.array, audio_decoder: AudioDecoder, vocoder: "Vocoder"
) -> mx.array:
"""
Decode an audio latent representation using the provided audio decoder and vocoder.
Args:

View File

@@ -53,8 +53,16 @@ class CausalConv2d(nn.Module):
# For (N, H, W, C) format: axis 1 is H (height), axis 2 is W (width)
if self.causality_axis == CausalityAxis.NONE:
# Non-causal: symmetric padding
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)
elif self.causality_axis in (CausalityAxis.WIDTH, CausalityAxis.WIDTH_COMPATIBILITY):
self.padding = (
pad_h // 2,
pad_h - pad_h // 2,
pad_w // 2,
pad_w - pad_w // 2,
)
elif self.causality_axis in (
CausalityAxis.WIDTH,
CausalityAxis.WIDTH_COMPATIBILITY,
):
# Causal on width: pad left (before width axis)
self.padding = (pad_h // 2, pad_h - pad_h // 2, pad_w, 0)
elif self.causality_axis == CausalityAxis.HEIGHT:
@@ -90,7 +98,10 @@ class CausalConv2d(nn.Module):
if any(p > 0 for p in self.padding):
# MLX pad expects: [(before_0, after_0), (before_1, after_1), ...]
# For (N, H, W, C): axis 0=N, axis 1=H, axis 2=W, axis 3=C
x = mx.pad(x, [(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)])
x = mx.pad(
x,
[(0, 0), (pad_h_top, pad_h_bottom), (pad_w_left, pad_w_right), (0, 0)],
)
return self.conv(x)
@@ -124,7 +135,14 @@ def make_conv2d(
if causality_axis is not None:
# For causal convolution, padding is handled internally by CausalConv2d
return CausalConv2d(
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
in_channels,
out_channels,
kernel_size,
stride,
dilation,
groups,
bias,
causality_axis,
)
else:
# For non-causal convolution, use symmetric padding if not specified

View File

@@ -5,8 +5,8 @@ from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .normalization import NormType
from .resnet import ResnetBlock
@@ -34,7 +34,9 @@ class Downsample(nn.Module):
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in MLX conv, must do it ourselves
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def __call__(self, x: mx.array) -> mx.array:
"""
@@ -116,10 +118,14 @@ def build_downsampling_path(
)
block_in = block_out
if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
stage["attn"][i_block] = make_attn(
block_in, attn_type=attn_type, norm_type=norm_type
)
if i_level != num_resolutions - 1:
stage["downsample"] = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
stage["downsample"] = Downsample(
block_in, resamp_with_conv, causality_axis=causality_axis
)
curr_res = curr_res // 2
down_modules[i_level] = stage

View File

@@ -51,7 +51,9 @@ def build_normalization_layer(
A normalization layer
"""
if normtype == NormType.GROUP:
return nn.GroupNorm(num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True)
return nn.GroupNorm(
num_groups=num_groups, dims=in_channels, eps=1e-6, affine=True
)
if normtype == NormType.PIXEL:
# For MLX channels-last format (B, H, W, C), normalize along channels (dim=-1)
# PyTorch uses dim=1 for channels-first format (B, C, H, W)

View File

@@ -1,12 +1,12 @@
"""ResNet blocks for audio VAE and vocoder."""
from typing import List, Tuple
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .causal_conv_2d import make_conv2d
from .normalization import NormType, build_normalization_layer
LRELU_SLOPE = 0.1
@@ -125,7 +125,11 @@ class ResnetBlock(nn.Module):
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
self.conv1 = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
if temb_channels > 0:
@@ -134,17 +138,29 @@ class ResnetBlock(nn.Module):
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
self.dropout_rate = dropout
self.conv2 = make_conv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
out_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
else:
self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
in_channels,
out_channels,
kernel_size=1,
stride=1,
causality_axis=causality_axis,
)
def __call__(
@@ -168,7 +184,9 @@ class ResnetBlock(nn.Module):
if temb is not None and self.temb_channels > 0:
# temb: (B, temb_channels) -> (B, out_channels)
# Need to add spatial dims: (B, 1, 1, out_channels) for broadcasting
h = h + mx.expand_dims(mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1)
h = h + mx.expand_dims(
mx.expand_dims(nn.silu(self.temb_proj(temb)), axis=1), axis=1
)
h = self.norm2(h)
h = nn.silu(h)

View File

@@ -5,9 +5,9 @@ from typing import Set, Tuple
import mlx.core as mx
import mlx.nn as nn
from ..config import CausalityAxis
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from ..config import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock
@@ -42,7 +42,11 @@ class Upsample(nn.Module):
self.causality_axis = causality_axis
if self.with_conv:
self.conv = make_conv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
in_channels,
in_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
def __call__(self, x: mx.array) -> mx.array:
@@ -124,10 +128,14 @@ def build_upsampling_path(
)
block_in = block_out
if curr_res in attn_resolutions:
stage["attn"][i_block] = make_attn(block_in, attn_type=attn_type, norm_type=norm_type)
stage["attn"][i_block] = make_attn(
block_in, attn_type=attn_type, norm_type=norm_type
)
if level != 0:
stage["upsample"] = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
stage["upsample"] = Upsample(
block_in, resamp_with_conv, causality_axis=causality_axis
)
curr_res *= 2
up_modules[level] = stage

View File

@@ -7,8 +7,8 @@ Supports:
"""
import math
from typing import List, Tuple
from pathlib import Path
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -32,7 +32,9 @@ class Snake(nn.Module):
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
self.alpha = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
def __call__(self, x: mx.array) -> mx.array:
# x: (N, L, C) in MLX format
@@ -48,8 +50,12 @@ class SnakeBeta(nn.Module):
def __init__(self, in_features: int, alpha_logscale: bool = True) -> None:
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
self.beta = mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
self.alpha = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
self.beta = (
mx.zeros((in_features,)) if alpha_logscale else mx.ones((in_features,))
)
def __call__(self, x: mx.array) -> mx.array:
alpha = self.alpha
@@ -73,7 +79,9 @@ def _sinc(x: mx.array) -> mx.array:
)
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> mx.array:
def kaiser_sinc_filter1d(
cutoff: float, half_width: float, kernel_size: int
) -> mx.array:
"""Compute a Kaiser-windowed sinc filter."""
even = kernel_size % 2 == 0
half_size = kernel_size // 2
@@ -88,6 +96,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
# Kaiser window - compute using scipy-compatible formula
import numpy as np
window = mx.array(np.kaiser(kernel_size, beta).astype(np.float32))
if even:
@@ -107,6 +116,7 @@ def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) ->
def hann_sinc_filter1d(ratio: int) -> Tuple[mx.array, int, int, int]:
"""Compute a Hann-windowed sinc filter for upsampling (used by BWE resampler)."""
import numpy as np
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
@@ -187,10 +197,16 @@ class UpSample1d(nn.Module):
self.kernel_size = filt.shape[2]
self.filter = filt
else:
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
self.pad_left = (
self.pad * self.stride + (self.kernel_size - self.stride) // 2
)
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
self.filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
@@ -215,10 +231,12 @@ class UpSample1d(nn.Module):
filt = self.filter.astype(x.dtype) # (1, 1, K)
filt = mx.transpose(filt, (0, 2, 1)) # (1, K, 1)
x = self.ratio * mx.conv_transpose1d(x, filt, stride=self.stride) # (N*C, L', 1)
x = self.ratio * mx.conv_transpose1d(
x, filt, stride=self.stride
) # (N*C, L', 1)
# Trim padding
x = x[:, self.pad_left:-self.pad_right, :]
x = x[:, self.pad_left : -self.pad_right, :]
x = x.reshape(n, c, -1) # (N, C, L')
x = mx.transpose(x, (0, 2, 1)) # (N, L', C)
@@ -285,16 +303,24 @@ class AMPBlock1(nn.Module):
self.convs1 = {
i: nn.Conv1d(
channels, channels, kernel_size, stride=1,
dilation=d, padding=get_padding(kernel_size, d),
channels,
channels,
kernel_size,
stride=1,
dilation=d,
padding=get_padding(kernel_size, d),
)
for i, d in enumerate(dilation)
}
self.convs2 = {
i: nn.Conv1d(
channels, channels, kernel_size, stride=1,
dilation=1, padding=get_padding(kernel_size, 1),
channels,
channels,
kernel_size,
stride=1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
for i in range(len(dilation))
}
@@ -348,7 +374,9 @@ class STFTFn(nn.Module):
y = mx.concatenate([first, y], axis=1)
# forward_basis: (514, 1, 512) PyTorch format -> (514, 512, 1) MLX
basis = mx.transpose(self.forward_basis.astype(y.dtype), (0, 2, 1)) # (514, K, 1)
basis = mx.transpose(
self.forward_basis.astype(y.dtype), (0, 2, 1)
) # (514, K, 1)
# Conv1d: (B, T, 1) * (514, K, 1) -> (B, T_frames, 514)
spec = mx.conv1d(y, basis, stride=self.hop_length)
@@ -358,8 +386,10 @@ class STFTFn(nn.Module):
real = spec[..., :n_freqs]
imag = spec[..., n_freqs:]
magnitude = mx.sqrt(real ** 2 + imag ** 2)
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(real.dtype)
magnitude = mx.sqrt(real**2 + imag**2)
phase = mx.arctan2(imag.astype(mx.float32), real.astype(mx.float32)).astype(
real.dtype
)
# Output: (B, T_frames, n_freqs) in MLX channels-last
return magnitude, phase
@@ -368,7 +398,9 @@ class STFTFn(nn.Module):
class MelSTFT(nn.Module):
"""Causal log-mel spectrogram from precomputed STFT bases."""
def __init__(self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int) -> None:
def __init__(
self, filter_length: int, hop_length: int, win_length: int, n_mel_channels: int
) -> None:
super().__init__()
self.stft_fn = STFTFn(filter_length, hop_length, win_length)
n_freqs = filter_length // 2 + 1
@@ -385,7 +417,9 @@ class MelSTFT(nn.Module):
"""
magnitude, phase = self.stft_fn(y)
# magnitude: (B, T_frames, n_freqs)
mel = magnitude @ self.mel_basis.astype(magnitude.dtype).T # (B, T_frames, n_mels)
mel = (
magnitude @ self.mel_basis.astype(magnitude.dtype).T
) # (B, T_frames, n_mels)
log_mel = mx.log(mx.clip(mel, 1e-5, None))
# Transpose to (B, n_mels, T_frames) for compatibility with vocoder input format
return mx.transpose(log_mel, (0, 2, 1))
@@ -415,8 +449,11 @@ class Vocoder(nn.Module):
in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d(
in_channels, config.upsample_initial_channel,
kernel_size=7, stride=1, padding=3,
in_channels,
config.upsample_initial_channel,
kernel_size=7,
stride=1,
padding=3,
)
# Upsampling layers
@@ -424,11 +461,13 @@ class Vocoder(nn.Module):
for i, (stride, kernel_size) in enumerate(
zip(config.upsample_rates, config.upsample_kernel_sizes)
):
in_ch = config.upsample_initial_channel // (2 ** i)
in_ch = config.upsample_initial_channel // (2**i)
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d(
in_ch, out_ch,
kernel_size=kernel_size, stride=stride,
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
@@ -442,7 +481,9 @@ class Vocoder(nn.Module):
config.resblock_kernel_sizes, config.resblock_dilation_sizes
):
self.resblocks[block_idx] = AMPBlock1(
ch, kernel_size, tuple(dilations),
ch,
kernel_size,
tuple(dilations),
activation=config.activation,
)
block_idx += 1
@@ -455,10 +496,14 @@ class Vocoder(nn.Module):
for kernel_size, dilations in zip(
config.resblock_kernel_sizes, config.resblock_dilation_sizes
):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
self.resblocks[block_idx] = resblock_class(
ch, kernel_size, tuple(dilations)
)
block_idx += 1
final_channels = config.upsample_initial_channel // (2 ** len(config.upsample_rates))
final_channels = config.upsample_initial_channel // (
2 ** len(config.upsample_rates)
)
# Post-activation
if self.is_amp:
@@ -468,8 +513,11 @@ class Vocoder(nn.Module):
# Final conv
out_channels = 2 if config.stereo else 1
self.conv_post = nn.Conv1d(
final_channels, out_channels,
kernel_size=7, stride=1, padding=3,
final_channels,
out_channels,
kernel_size=7,
stride=1,
padding=3,
bias=config.use_bias_at_final,
)
@@ -588,7 +636,9 @@ class VocoderWithBWE(nn.Module):
"""
x = self.vocoder(mel_spec) # (B, C, T) at input_sampling_rate
_, _, length_low_rate = x.shape
output_length = length_low_rate * self.output_sampling_rate // self.input_sampling_rate
output_length = (
length_low_rate * self.output_sampling_rate // self.input_sampling_rate
)
# Pad to hop_length multiple
remainder = length_low_rate % self.hop_length
@@ -685,5 +735,3 @@ def _load_vocoder_with_bwe(config_dict: dict, weights: dict) -> VocoderWithBWE:
model.load_weights(list(weights.items()), strict=False)
return model

View File

@@ -1,3 +1,6 @@
"""Conditioning modules for LTX-2 video generation."""
from mlx_video.models.ltx_2.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.models.ltx_2.conditioning.latent import (
VideoConditionByLatentIndex,
apply_conditioning,
)

View File

@@ -5,7 +5,7 @@ the video generation process at specific frame positions.
"""
from dataclasses import dataclass
from typing import Optional, List, Tuple
from typing import List, Optional, Tuple
import mlx.core as mx
@@ -22,6 +22,7 @@ class VideoConditionByLatentIndex:
frame_idx: Frame index to condition (0 = first frame)
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
"""
latent: mx.array
frame_idx: int = 0
strength: float = 1.0
@@ -41,6 +42,7 @@ class LatentState:
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
1.0 = full denoise, 0.0 = keep clean
"""
latent: mx.array
clean_latent: mx.array
denoise_mask: mx.array
@@ -130,15 +132,15 @@ def apply_conditioning(
if frame_idx <= i < end_idx:
# Use conditioning latent
cond_idx = i - frame_idx
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
latent_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
clean_list.append(cond_latent[:, :, cond_idx : cond_idx + 1])
# Set mask: 1.0 - strength means less denoising for conditioned frames
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
else:
# Keep original
latent_list.append(state.latent[:, :, i:i+1])
clean_list.append(state.clean_latent[:, :, i:i+1])
mask_list.append(state.denoise_mask[:, :, i:i+1])
latent_list.append(state.latent[:, :, i : i + 1])
clean_list.append(state.clean_latent[:, :, i : i + 1])
mask_list.append(state.denoise_mask[:, :, i : i + 1])
state.latent = mx.concatenate(latent_list, axis=2)
state.clean_latent = mx.concatenate(clean_list, axis=2)

View File

@@ -1,4 +1,3 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
@@ -22,9 +21,11 @@ class LTXRopeType(Enum):
SPLIT = "split"
TWO_D = "2d"
class AttentionType(Enum):
DEFAULT = "default"
@dataclass
class BaseModelConfig:
@@ -46,7 +47,7 @@ class BaseModelConfig:
if v is not None:
if isinstance(v, Enum):
result[k] = v.value
elif hasattr(v, 'to_dict'):
elif hasattr(v, "to_dict"):
result[k] = v.to_dict()
else:
result[k] = v
@@ -68,26 +69,30 @@ class VideoVAEConfig(BaseModelConfig):
out_channels: int = 128
latent_channels: int = 128
patch_size: int = 4
encoder_blocks: List[tuple] = field(default_factory=lambda: [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
])
decoder_blocks: List[tuple] = field(default_factory=lambda: [
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
])
encoder_blocks: List[tuple] = field(
default_factory=lambda: [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
)
decoder_blocks: List[tuple] = field(
default_factory=lambda: [
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
("compress_all", {"residual": True, "multiplier": 2}),
("res_x", {"num_layers": 5, "inject_noise": False}),
]
)
@dataclass
@@ -111,7 +116,9 @@ class LTXModelConfig(BaseModelConfig):
audio_in_channels: int = 128
audio_out_channels: int = 128
audio_cross_attention_dim: int = 2048
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video)
audio_caption_channels: int = (
3840 # Input dim for audio text embeddings (same as video)
)
# Positional embedding config
positional_embedding_theta: float = 10000.0
@@ -196,7 +203,6 @@ class LTXModelConfig(BaseModelConfig):
)
class CausalityAxis(Enum):
"""Enum for specifying the causality axis in causal convolutions."""
@@ -237,21 +243,22 @@ class AudioDecoderModelConfig(BaseModelConfig):
def __post_init__(self):
"""Convert string enum values to proper enum types."""
# Import here to avoid circular imports
from .audio_vae.normalization import NormType
from .audio_vae.attention import AttentionType
from .audio_vae.normalization import NormType
# Convert causality_axis string to enum
if isinstance(self.causality_axis, str):
self.causality_axis = CausalityAxis(self.causality_axis)
# Convert norm_type string to enum
if isinstance(self.norm_type, str):
self.norm_type = NormType(self.norm_type)
# Convert attn_type string to enum
if isinstance(self.attn_type, str):
self.attn_type = AttentionType(self.attn_type)
@dataclass
class AudioEncoderModelConfig(BaseModelConfig):
ch: int = 128
@@ -282,8 +289,8 @@ class AudioEncoderModelConfig(BaseModelConfig):
def __post_init__(self):
"""Convert string enum values to proper enum types."""
from .audio_vae.normalization import NormType
from .audio_vae.attention import AttentionType
from .audio_vae.normalization import NormType
if isinstance(self.causality_axis, str):
self.causality_axis = CausalityAxis(self.causality_axis)
@@ -334,6 +341,7 @@ class VideoDecoderModelConfig(BaseModelConfig):
dropout: float = 0.0
timestep_conditioning: bool = False
@dataclass
class VideoEncoderModelConfig(BaseModelConfig):
convolution_dimensions: int = 3
@@ -343,21 +351,24 @@ class VideoEncoderModelConfig(BaseModelConfig):
norm_layer: Enum = None
latent_log_var: Enum = None
encoder_spatial_padding_mode: Enum = None
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2})
])
encoder_blocks: List[tuple] = field(
default_factory=lambda: [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
)
def __post_init__(self):
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
if self.norm_layer is None:
self.norm_layer = NormLayerType.PIXEL_NORM
@@ -371,10 +382,12 @@ class VideoEncoderModelConfig(BaseModelConfig):
if isinstance(self.latent_log_var, str):
self.latent_log_var = LogVarianceType(self.latent_log_var)
if isinstance(self.encoder_spatial_padding_mode, str):
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
self.encoder_spatial_padding_mode = PaddingModeType(
self.encoder_spatial_padding_mode
)
def to_dict(self) -> dict[str, Any]:
    """Return the parent's dict form with ``encoder_blocks`` entries as lists.

    Starts from ``super().to_dict()`` and, when ``encoder_blocks`` is not
    ``None``, overwrites the ``"encoder_blocks"`` entry with a list in which
    every block has been passed through ``list(...)``.

    Returns:
        The dictionary produced by the parent class, with the
        ``"encoder_blocks"`` entry rewritten as described above.
    """
    result = super().to_dict()
    if self.encoder_blocks is not None:
        # NOTE(review): presumably tuples become lists so the result matches
        # a JSON-style representation -- confirm against the config writer.
        result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
    return result

View File

@@ -49,7 +49,6 @@ from typing import Dict
import mlx.core as mx
# ─── Key prefix routing ──────────────────────────────────────────────────────
TRANSFORMER_PREFIX = "model.diffusion_model."
@@ -78,7 +77,7 @@ def sanitize_transformer(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
continue
new_key = key[len(TRANSFORMER_PREFIX):]
new_key = key[len(TRANSFORMER_PREFIX) :]
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
@@ -109,7 +108,7 @@ def sanitize_vae_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
else:
continue
elif key.startswith(VAE_DECODER_PREFIX):
new_key = key[len(VAE_DECODER_PREFIX):]
new_key = key[len(VAE_DECODER_PREFIX) :]
else:
continue
@@ -147,7 +146,7 @@ def sanitize_vae_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if value.dtype != mx.float32:
value = value.astype(mx.float32)
elif key.startswith(VAE_ENCODER_PREFIX):
new_key = key[len(VAE_ENCODER_PREFIX):]
new_key = key[len(VAE_ENCODER_PREFIX) :]
else:
continue
@@ -170,7 +169,7 @@ def sanitize_audio_decoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
new_key = None
if key.startswith(AUDIO_DECODER_PREFIX):
new_key = key[len(AUDIO_DECODER_PREFIX):]
new_key = key[len(AUDIO_DECODER_PREFIX) :]
elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
@@ -196,7 +195,7 @@ def sanitize_audio_encoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
new_key = None
if key.startswith(AUDIO_ENCODER_PREFIX):
new_key = key[len(AUDIO_ENCODER_PREFIX):]
new_key = key[len(AUDIO_ENCODER_PREFIX) :]
elif key.startswith(AUDIO_STATS_PREFIX):
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
@@ -226,7 +225,7 @@ def sanitize_vocoder(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if not key.startswith(VOCODER_PREFIX):
continue
new_key = key[len(VOCODER_PREFIX):]
new_key = key[len(VOCODER_PREFIX) :]
# Handle Conv1d/ConvTranspose1d weight shape conversion
if "weight" in new_key and value.ndim == 3:
@@ -260,20 +259,20 @@ def extract_text_projections(weights: Dict[str, mx.array]) -> Dict[str, mx.array
# aggregate_embed weights (text_embedding_projection.*)
for key, value in weights.items():
if key.startswith(TEXT_PROJ_PREFIX):
new_key = key[len(TEXT_PROJ_PREFIX):]
new_key = key[len(TEXT_PROJ_PREFIX) :]
extracted[new_key] = value
# video_embeddings_connector
for key, value in weights.items():
if key.startswith(VIDEO_CONNECTOR_PREFIX):
suffix = key[len(VIDEO_CONNECTOR_PREFIX):]
suffix = key[len(VIDEO_CONNECTOR_PREFIX) :]
new_key = "video_embeddings_connector." + sanitize_connector_key(suffix)
extracted[new_key] = value
# audio_embeddings_connector
for key, value in weights.items():
if key.startswith(AUDIO_CONNECTOR_PREFIX):
suffix = key[len(AUDIO_CONNECTOR_PREFIX):]
suffix = key[len(AUDIO_CONNECTOR_PREFIX) :]
new_key = "audio_embeddings_connector." + sanitize_connector_key(suffix)
extracted[new_key] = value
@@ -369,11 +368,15 @@ def save_config(config: dict, output_dir: Path):
# ─── Source resolution ─────────────────────────────────────────────────────────
# Matches monolithic model files: ltx-2-19b-distilled.safetensors, ltx-2.3-22b-dev.safetensors, etc.
# The named group "variant" captures either "distilled" or "dev".
MONOLITHIC_PATTERN = re.compile(
    r"^ltx-[\d.]+-\d+b-(?P<variant>distilled|dev)\.safetensors$"
)

# Matches upscaler files like ltx-2-spatial-upscaler-x2-1.0.safetensors,
# ltx-2.3-spatial-upscaler-x2-1.0.safetensors, etc.
UPSCALER_PATTERN = re.compile(
    r"^ltx-[\d.]+-(?:spatial|temporal)-upscaler-.+\.safetensors$"
)
def resolve_source(source: str, variant: str) -> Path:
@@ -506,7 +509,9 @@ def infer_transformer_config(weights: Dict[str, mx.array]) -> dict:
def infer_vae_decoder_config(weights: Dict[str, mx.array], variant: str) -> dict:
"""Infer VAE decoder config from weights."""
# Check for timestep conditioning keys
has_timestep = any("last_time_embedder" in k or "last_scale_shift_table" in k for k in weights)
has_timestep = any(
"last_time_embedder" in k or "last_scale_shift_table" in k for k in weights
)
# Count channel multipliers from up_blocks
max_block = -1
@@ -658,7 +663,9 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
config = infer_transformer_config(transformer_weights)
save_config(config, output_path / "transformer")
t_params = sum(v.size for v in transformer_weights.values())
print(f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards")
print(
f" {len(transformer_weights)} keys, {t_params:,} params, {num_shards} shards"
)
# 2. VAE Decoder
print(" [2/7] VAE Decoder...")
@@ -728,7 +735,8 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
]
else:
upscaler_files = [
f.name for f in source_dir.iterdir()
f.name
for f in source_dir.iterdir()
if f.is_file() and UPSCALER_PATTERN.match(f.name)
]
@@ -800,12 +808,21 @@ def convert(source: str, output_path: Path, variant: str = "distilled"):
print(f"\nDone! Converted {all_converted}/{total_keys} keys")
if all_converted < total_keys:
known_prefixes = (
TRANSFORMER_PREFIX, VAE_DECODER_PREFIX, VAE_ENCODER_PREFIX,
VAE_STATS_PREFIX, AUDIO_DECODER_PREFIX, AUDIO_ENCODER_PREFIX,
AUDIO_STATS_PREFIX, VOCODER_PREFIX, TEXT_PROJ_PREFIX,
VIDEO_CONNECTOR_PREFIX, AUDIO_CONNECTOR_PREFIX,
TRANSFORMER_PREFIX,
VAE_DECODER_PREFIX,
VAE_ENCODER_PREFIX,
VAE_STATS_PREFIX,
AUDIO_DECODER_PREFIX,
AUDIO_ENCODER_PREFIX,
AUDIO_STATS_PREFIX,
VOCODER_PREFIX,
TEXT_PROJ_PREFIX,
VIDEO_CONNECTOR_PREFIX,
AUDIO_CONNECTOR_PREFIX,
)
skipped = [k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)]
skipped = [
k for k in all_weights if not any(k.startswith(p) for p in known_prefixes)
]
if skipped:
print(f" Skipped {len(skipped)} keys:")
for k in sorted(skipped)[:20]:

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +1,14 @@
from pathlib import Path
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.config import (
LTXModelConfig,
LTXModelType,
LTXRopeType,
TransformerConfig,
)
from mlx_video.models.ltx_2.adaln import AdaLayerNormSingle
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
from mlx_video.models.ltx_2.text_projection import PixArtAlphaTextProjection
from mlx_video.models.ltx_2.transformer import (
@@ -58,11 +57,17 @@ class TransformerArgsPreprocessor:
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
timestep_emb, embedded_timestep = self.adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
# Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
timestep_emb = mx.reshape(
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
)
embedded_timestep = mx.reshape(
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
)
return timestep_emb, embedded_timestep
@@ -74,9 +79,15 @@ class TransformerArgsPreprocessor:
hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
embedded_timestep = mx.reshape(embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1]))
timestep_emb, embedded_timestep = adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
timestep_emb = mx.reshape(
timestep_emb, (batch_size, -1, timestep_emb.shape[-1])
)
embedded_timestep = mx.reshape(
embedded_timestep, (batch_size, -1, embedded_timestep.shape[-1])
)
return timestep_emb, embedded_timestep
def _prepare_context(
@@ -107,7 +118,9 @@ class TransformerArgsPreprocessor:
# Convert boolean/int mask to float mask
# 0 -> -inf (masked), 1 -> 0 (not masked)
mask = (attention_mask.astype(x_dtype) - 1) * 1e9
mask = mx.reshape(mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
mask = mx.reshape(
mask, (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
)
return mask
def _prepare_positional_embeddings(
@@ -132,9 +145,15 @@ class TransformerArgsPreprocessor:
def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
timestep, embedded_timestep = self._prepare_timestep(
modality.timesteps, x.shape[0], hidden_dtype=x.dtype
)
context, attention_mask = self._prepare_context(
modality.context, x, modality.context_mask
)
attention_mask = self._prepare_attention_mask(
attention_mask, modality.latent.dtype
)
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
if modality.positional_embeddings is not None:
@@ -152,8 +171,13 @@ class TransformerArgsPreprocessor:
prompt_timestep = None
prompt_embedded_timestep = None
if self.prompt_adaln is not None and modality.sigma is not None:
prompt_timestep, prompt_embedded_timestep = self._prepare_timestep_with_adaln(
self.prompt_adaln, modality.sigma, x.shape[0], hidden_dtype=x.dtype,
prompt_timestep, prompt_embedded_timestep = (
self._prepare_timestep_with_adaln(
self.prompt_adaln,
modality.sigma,
x.shape[0],
hidden_dtype=x.dtype,
)
)
return TransformerArgs(
@@ -229,11 +253,13 @@ class MultiModalTransformerArgsPreprocessor:
)
# Prepare cross-attention timestep embeddings
cross_scale_shift_timestep, cross_gate_timestep = self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
cross_scale_shift_timestep, cross_gate_timestep = (
self._prepare_cross_attention_timestep(
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
)
)
return replace(
@@ -254,17 +280,25 @@ class MultiModalTransformerArgsPreprocessor:
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
scale_shift_timestep, _ = self.cross_scale_shift_adaln(
timestep.reshape(-1), hidden_dtype=hidden_dtype
)
scale_shift_timestep = mx.reshape(
scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])
)
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
gate_timestep, _ = self.cross_gate_adaln(
timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype
)
gate_timestep = mx.reshape(
gate_timestep, (batch_size, -1, gate_timestep.shape[-1])
)
return scale_shift_timestep, gate_timestep
class LTXModel(nn.Module):
def __init__(self, config: LTXModelConfig):
super().__init__()
@@ -285,18 +319,25 @@ class LTXModel(nn.Module):
self._init_video(config)
if config.model_type.is_audio_enabled():
self.audio_positional_embedding_max_pos = config.audio_positional_embedding_max_pos
self.audio_positional_embedding_max_pos = (
config.audio_positional_embedding_max_pos
)
self.audio_num_attention_heads = config.audio_num_attention_heads
self.audio_inner_dim = config.audio_inner_dim
self._init_audio(config)
# Initialize cross-modal components
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
if (
config.model_type.is_video_enabled()
and config.model_type.is_audio_enabled()
):
cross_pe_max_pos = max(
config.positional_embedding_max_pos[0],
config.audio_positional_embedding_max_pos[0],
)
self.av_ca_timestep_scale_multiplier = config.av_ca_timestep_scale_multiplier
self.av_ca_timestep_scale_multiplier = (
config.av_ca_timestep_scale_multiplier
)
self.audio_cross_attention_dim = config.audio_cross_attention_dim
self._init_audio_video(config)
@@ -308,10 +349,14 @@ class LTXModel(nn.Module):
self.patchify_proj = nn.Linear(config.in_channels, self.inner_dim, bias=True)
adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=adaln_coefficient)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=adaln_coefficient
)
if config.has_prompt_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(self.inner_dim, embedding_coefficient=2)
self.prompt_adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=2
)
else:
self.caption_projection = PixArtAlphaTextProjection(
in_features=config.caption_channels,
@@ -323,13 +368,19 @@ class LTXModel(nn.Module):
self.proj_out = nn.Linear(self.inner_dim, config.out_channels)
def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_patchify_proj = nn.Linear(
config.audio_in_channels, self.audio_inner_dim, bias=True
)
audio_adaln_coefficient = 9 if config.has_prompt_adaln else 6
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient)
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim, embedding_coefficient=audio_adaln_coefficient
)
if config.has_prompt_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(self.audio_inner_dim, embedding_coefficient=2)
self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim, embedding_coefficient=2
)
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels,
@@ -338,7 +389,9 @@ class LTXModel(nn.Module):
# Output components
self.audio_scale_shift_table = mx.zeros((2, self.audio_inner_dim))
self.audio_norm_out = nn.LayerNorm(self.audio_inner_dim, eps=config.norm_eps, affine=False)
self.audio_norm_out = nn.LayerNorm(
self.audio_inner_dim, eps=config.norm_eps, affine=False
)
self.audio_proj_out = nn.Linear(self.audio_inner_dim, config.audio_out_channels)
def _init_audio_video(self, config: LTXModelConfig) -> None:
@@ -361,8 +414,13 @@ class LTXModel(nn.Module):
embedding_coefficient=1,
)
def _init_preprocessors(self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]) -> None:
if config.model_type.is_video_enabled() and config.model_type.is_audio_enabled():
def _init_preprocessors(
self, config: LTXModelConfig, cross_pe_max_pos: Optional[int]
) -> None:
if (
config.model_type.is_video_enabled()
and config.model_type.is_audio_enabled()
):
# Multi-modal preprocessors
self.video_args_preprocessor = MultiModalTransformerArgsPreprocessor(
patchify_proj=self.patchify_proj,
@@ -468,7 +526,8 @@ class LTXModel(nn.Module):
stg_a_set = set(stg_audio_blocks) if stg_audio_blocks else set()
for idx, block in self.transformer_blocks.items():
video, audio = block(
video=video, audio=audio,
video=video,
audio=audio,
skip_video_self_attn=(idx in stg_v_set),
skip_audio_self_attn=(idx in stg_a_set),
skip_cross_modal=skip_cross_modal,
@@ -483,7 +542,7 @@ class LTXModel(nn.Module):
x: mx.array,
embedded_timestep: mx.array,
) -> mx.array:
# scale_shift_table: (2, dim) -> expand to (1, 1, 2, dim)
# embedded_timestep: (B, 1, dim) -> expand to (B, 1, 1, dim)
table_expanded = scale_shift_table[None, None, :, :] # (1, 1, 2, dim)
@@ -526,8 +585,12 @@ class LTXModel(nn.Module):
raise ValueError("Audio is not enabled for this model")
# Preprocess arguments
video_args = self.video_args_preprocessor.prepare(video) if video is not None else None
audio_args = self.audio_args_preprocessor.prepare(audio) if audio is not None else None
video_args = (
self.video_args_preprocessor.prepare(video) if video is not None else None
)
audio_args = (
self.audio_args_preprocessor.prepare(audio) if audio is not None else None
)
# Process transformer blocks
video_out, audio_out = self._process_transformer_blocks(
@@ -567,7 +630,7 @@ class LTXModel(nn.Module):
def sanitize(self, weights: dict) -> dict:
sanitized = {}
has_raw_prefix = any(k.startswith("model.diffusion_model.") for k in weights)
if not has_raw_prefix:
return weights
@@ -577,7 +640,10 @@ class LTXModel(nn.Module):
if not key.startswith("model.diffusion_model."):
continue
if "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
if (
"audio_embeddings_connector" in key
or "video_embeddings_connector" in key
):
continue
# Remove 'model.diffusion_model.' prefix
@@ -612,9 +678,11 @@ class LTXModel(nn.Module):
for weight_file in model_path.glob("*.safetensors"):
weights.update(mx.load(str(weight_file)))
sanitized = model.sanitize(weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
sanitized = {
k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v
for k, v in sanitized.items()
}
model.load_weights(list(sanitized.items()), strict=strict)
mx.eval(model.parameters())
@@ -625,7 +693,7 @@ class LTXModel(nn.Module):
class X0Model(nn.Module):
def __init__(self, velocity_model: LTXModel):
super().__init__()
self.velocity_model = velocity_model
@@ -639,13 +707,18 @@ class X0Model(nn.Module):
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
vx, ax = self.velocity_model(
video, audio,
video,
audio,
stg_video_blocks=stg_video_blocks,
stg_audio_blocks=stg_audio_blocks,
skip_cross_modal=skip_cross_modal,
)
denoised_video = to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
denoised_audio = to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
denoised_video = (
to_denoised(video.latent, vx, video.timesteps) if vx is not None else None
)
denoised_audio = (
to_denoised(audio.latent, ax, audio.timesteps) if ax is not None else None
)
return denoised_video, denoised_audio

View File

@@ -1,9 +1,10 @@
import numpy as np
from typing import Optional
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
def bilateral_filter(
image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75
) -> np.ndarray:
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
Args:
@@ -17,6 +18,7 @@ def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sig
"""
try:
import cv2
return cv2.bilateralFilter(image, d, sigma_color, sigma_space)
except ImportError:
# Fallback to simple Gaussian blur if cv2 not available
@@ -35,14 +37,20 @@ def gaussian_blur(image: np.ndarray, kernel_size: int = 3) -> np.ndarray:
"""
try:
import cv2
return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)
except ImportError:
# Simple box blur fallback
from scipy.ndimage import uniform_filter
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(np.uint8)
return uniform_filter(image, size=(kernel_size, kernel_size, 1)).astype(
np.uint8
)
def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0) -> np.ndarray:
def unsharp_mask(
image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, amount: float = 1.0
) -> np.ndarray:
"""Apply unsharp masking to enhance edges after blur.
Args:
@@ -56,6 +64,7 @@ def unsharp_mask(image: np.ndarray, kernel_size: int = 5, sigma: float = 1.0, am
"""
try:
import cv2
blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
sharpened = cv2.addWeighted(image, 1 + amount, blurred, -amount, 0)
return np.clip(sharpened, 0, 255).astype(np.uint8)
@@ -81,23 +90,23 @@ def reduce_grid_artifacts(
if method == "bilateral":
d = max(3, int(5 * strength))
sigma = 50 + 50 * strength
processed = np.stack([
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
for frame in video
])
processed = np.stack(
[
bilateral_filter(frame, d=d, sigma_color=sigma, sigma_space=sigma)
for frame in video
]
)
elif method == "gaussian":
kernel_size = max(3, int(3 + 4 * strength))
if kernel_size % 2 == 0:
kernel_size += 1
processed = np.stack([
gaussian_blur(frame, kernel_size=kernel_size)
for frame in video
])
processed = np.stack(
[gaussian_blur(frame, kernel_size=kernel_size) for frame in video]
)
elif method == "frequency":
processed = np.stack([
remove_grid_frequency(frame, grid_size=8)
for frame in video
])
processed = np.stack(
[remove_grid_frequency(frame, grid_size=8) for frame in video]
)
else:
raise ValueError(f"Unknown method: {method}")
@@ -160,6 +169,3 @@ def remove_grid_frequency(frame: np.ndarray, grid_size: int = 8) -> np.ndarray:
result[:, :, c] = np.clip(channel_filtered, 0, 255).astype(np.uint8)
return result

View File

@@ -1,4 +1,3 @@
import math
from typing import List, Optional, Tuple
@@ -86,11 +85,12 @@ def rotate_half_interleaved(x: mx.array) -> mx.array:
"""
# x: (..., dim) where dim is even
x_even = x[..., 0::2] # [x0, x2, x4, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
x_odd = x[..., 1::2] # [x1, x3, x5, ...]
# Stack: [[-x1, x0], [-x3, x2], ...] then flatten to [-x1, x0, -x3, x2, ...]
rotated = mx.stack([-x_odd, x_even], axis=-1)
return mx.reshape(rotated, x.shape)
def apply_rotary_emb_1d(
q: mx.array,
k: mx.array,
@@ -228,9 +228,9 @@ def get_fractional_positions(
Fractional positions in range [-1, 1] after scaling
"""
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), (
f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
)
assert n_pos_dims == len(
max_pos
), f"Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})"
# Divide each dimension by its max position
fractional_positions = []
@@ -392,11 +392,15 @@ def precompute_freqs_cis(
if max_pos is None:
max_pos = [20, 2048, 2048]
if double_precision:
return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
num_attention_heads, rope_type
indices_grid,
dim,
theta,
max_pos,
use_middle_indices_grid,
num_attention_heads,
rope_type,
)
# Keep positions in float32 for RoPE computation.
@@ -495,7 +499,9 @@ def _precompute_freqs_cis_double_precision(
# Compute frequencies: outer product
# scaled_positions: (B, T, n_dims) -> (B, T, n_dims, 1)
# freq_indices: (num_indices,) -> (1, 1, 1, num_indices)
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(freq_indices, (1, 1, 1, -1))
freqs = mx.expand_dims(scaled_positions, axis=-1) * mx.reshape(
freq_indices, (1, 1, 1, -1)
)
# freqs: (B, T, n_dims, num_indices)
# Transpose and flatten: (B, T, n_dims, num_indices) -> (B, T, num_indices, n_dims) -> (B, T, num_indices * n_dims)

View File

@@ -5,15 +5,14 @@ noise injection, ported from the LTX-2 PyTorch implementation.
"""
import math
from typing import Optional
import mlx.core as mx
# ---------------------------------------------------------------------------
# Phi functions and RK coefficients (pure Python math, no MLX needed)
# ---------------------------------------------------------------------------
def phi(j: int, neg_h: float) -> float:
"""Compute phi_j(z) where z = -h (negative step size in log-space).
@@ -43,6 +42,7 @@ def get_res2s_coefficients(
Returns:
(a21, b1, b2): RK coefficients.
"""
def get_phi(j: int, neg_h: float) -> float:
cache_key = (j, neg_h)
if cache_key in phi_cache:
@@ -69,6 +69,7 @@ def get_res2s_coefficients(
# SDE noise injection
# ---------------------------------------------------------------------------
def get_sde_coeff(
sigma_next: float,
) -> tuple[float, float, float]:
@@ -139,7 +140,9 @@ def sde_noise_step(
denoised_next = sample_f32 - sigma * eps_next
# Mix deterministic and stochastic components
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
x_noised = (
alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
)
return x_noised
@@ -148,6 +151,7 @@ def sde_noise_step(
# Noise generation
# ---------------------------------------------------------------------------
def channelwise_normalize(x: mx.array) -> mx.array:
"""Normalize each channel to zero mean and unit variance over spatial dims.

View File

@@ -1,25 +1,25 @@
import functools
import logging
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn
from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx_2.rope import apply_interleaved_rotary_emb
from mlx_vlm.models.gemma3.language import Gemma3Model
from mlx_vlm.models.gemma3.config import TextConfig
from mlx_vlm.models.gemma3.language import Gemma3Model
from rich.console import Console
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeRemainingColumn,
)
from mlx_video.utils import apply_quantization, rms_norm
# Path to system prompts
PROMPTS_DIR = Path(__file__).parent / "prompts"
@@ -36,11 +36,10 @@ def _load_system_prompt(prompt_name: str) -> str:
class LanguageModel(nn.Module):
def __init__(self, config: TextConfig):
    """Store the text config and build the Gemma3 backbone from it.

    Args:
        config: ``TextConfig`` (from ``mlx_vlm.models.gemma3.config``) that
            is kept on the instance and handed to ``Gemma3Model``.
    """
    super().__init__()

    # Keep the caller-supplied config; it is stored as-is, not modified here.
    self.config = config

    # Create the Gemma3Model from mlx-vlm
    self.model = Gemma3Model(self.config)
@@ -51,7 +50,7 @@ class LanguageModel(nn.Module):
attention_mask: Optional[mx.array],
dtype: mx.Dtype,
) -> mx.array:
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
if attention_mask is not None:
@@ -59,15 +58,25 @@ class LanguageModel(nn.Module):
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
mx.full(combined.shape, min_val, dtype=dtype))
min_val = (
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
)
mask = mx.where(
combined,
mx.zeros(combined.shape, dtype=dtype),
mx.full(combined.shape, min_val, dtype=dtype),
)
return mask[:, None, :, :]
else:
# No padding mask, just causal
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
mx.full((seq_len, seq_len), min_val, dtype=dtype))
min_val = (
mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
)
mask = mx.where(
causal_mask,
mx.zeros((seq_len, seq_len), dtype=dtype),
mx.full((seq_len, seq_len), min_val, dtype=dtype),
)
return mask[None, None, :, :] # (1, 1, seq, seq)
def __call__(
@@ -91,7 +100,11 @@ class LanguageModel(nn.Module):
batch_size, seq_len = inputs.shape
# Get embeddings
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
h = (
input_embeddings
if input_embeddings is not None
else self.model.embed_tokens(inputs)
)
# Apply Gemma scaling
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
@@ -103,11 +116,12 @@ class LanguageModel(nn.Module):
if cache is None:
cache = [None] * len(self.model.layers)
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
full_causal_mask = self._create_causal_mask_with_padding(
seq_len, attention_mask, h.dtype
)
sliding_mask = full_causal_mask
num_layers = len(self.model.layers)
for i, layer in enumerate(self.model.layers):
is_global = (
@@ -147,9 +161,9 @@ class LanguageModel(nn.Module):
for key, value in weights.items():
if key.startswith(prefix):
if hasattr(value, "dtype") and value.dtype == mx.float32:
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16)
else:
sanitized[key[len(prefix):]] = value
sanitized[key[len(prefix) :]] = value
return sanitized
@property
@@ -158,6 +172,7 @@ class LanguageModel(nn.Module):
def make_cache(self):
from mlx_vlm.models.cache import KVCache, RotatingKVCache
caches = []
for i in range(len(self.layers)):
if (
@@ -172,6 +187,7 @@ class LanguageModel(nn.Module):
@classmethod
def from_pretrained(cls, model_path: str):
import json
weight_files = sorted(Path(model_path).glob("*.safetensors"))
config_file = Path(model_path) / "config.json"
config_dict = {}
@@ -179,7 +195,9 @@ class LanguageModel(nn.Module):
with open(config_file, "r") as f:
config_dict = json.load(f)
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
language_model = cls(
config=TextConfig.from_dict(config_dict["text_config"])
)
else:
raise ValueError(f"Config file not found at {model_path}")
@@ -188,19 +206,18 @@ class LanguageModel(nn.Module):
for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights)
apply_quantization(model=language_model, weights=weights, quantization=quantization)
apply_quantization(
model=language_model, weights=weights, quantization=quantization
)
language_model.load_weights(list(weights.items()), strict=False)
return language_model
class ConnectorAttention(nn.Module):
def __init__(
@@ -250,9 +267,15 @@ class ConnectorAttention(nn.Module):
k = self.k_norm(k)
# Reshape to (B, H, T, D) for SPLIT RoPE
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
q = mx.reshape(
q, (batch_size, seq_len, self.num_heads, self.head_dim)
).transpose(0, 2, 1, 3)
k = mx.reshape(
k, (batch_size, seq_len, self.num_heads, self.head_dim)
).transpose(0, 2, 1, 3)
v = mx.reshape(
v, (batch_size, seq_len, self.num_heads, self.head_dim)
).transpose(0, 2, 1, 3)
if pe is not None:
q = self._apply_split_rope(q, pe[0], pe[1])
@@ -304,7 +327,7 @@ class ConnectorAttention(nn.Module):
out2 = x2 * cos_freq + x1 * sin_freq
return mx.concatenate([out1, out2], axis=-1).astype(input_dtype)
class GEGLU(nn.Module):
"""GELU-gated linear unit."""
@@ -336,9 +359,17 @@ class ConnectorFeedForward(nn.Module):
class ConnectorTransformerBlock(nn.Module):
def __init__(
    self,
    dim: int = 3840,
    num_heads: int = 30,
    head_dim: int = 128,
    has_gate_logits: bool = False,
):
    """Build one connector block: an attention module and a feed-forward module.

    Args:
        dim: Model width, passed to both ``ConnectorAttention`` and
            ``ConnectorFeedForward``.
        num_heads: Passed to ``ConnectorAttention`` as its head count.
        head_dim: Passed to ``ConnectorAttention`` as its per-head width.
        has_gate_logits: Forwarded unchanged to ``ConnectorAttention``.
    """
    super().__init__()
    # Attribute names (attn1 / ff) are kept as-is.
    # NOTE(review): presumably they must match checkpoint weight keys — confirm.
    self.attn1 = ConnectorAttention(
        dim, num_heads, head_dim, has_gate_logits=has_gate_logits
    )
    self.ff = ConnectorFeedForward(dim)
def __call__(
@@ -388,14 +419,18 @@ class Embeddings1DConnector(nn.Module):
self.positional_embedding_max_pos = positional_embedding_max_pos or [1]
self.transformer_1d_blocks = {
i: ConnectorTransformerBlock(dim, num_heads, head_dim, has_gate_logits=has_gate_logits)
i: ConnectorTransformerBlock(
dim, num_heads, head_dim, has_gate_logits=has_gate_logits
)
for i in range(num_layers)
}
if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
def _precompute_freqs_cis(
self, seq_len: int, dtype: mx.Dtype
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
@@ -464,11 +499,15 @@ class Embeddings1DConnector(nn.Module):
# Binary mask: 1 for valid tokens, 0 for padded
# attention_mask is additive: 0 for valid, large negative for padded
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(
mx.int32
) # (batch, seq)
# Tile registers to match sequence length, cast to hidden_states dtype
num_tiles = seq_len // self.num_learnable_registers
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(
dtype
) # (seq_len, dim)
# Process each batch item (PyTorch uses advanced indexing)
result_list = []
@@ -481,25 +520,33 @@ class Embeddings1DConnector(nn.Module):
# Extract valid tokens (where mask is 1)
# Since we have left-padded input, valid tokens are at the end
valid_tokens = hs_b[seq_len - num_valid:] # (num_valid, dim)
valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim)
# Pad with zeros on the right to get back to seq_len
pad_length = seq_len - num_valid
if pad_length > 0:
padding = mx.zeros((pad_length, dim), dtype=dtype)
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
adjusted = mx.concatenate(
[valid_tokens, padding], axis=0
) # (seq_len, dim)
else:
adjusted = valid_tokens
# Create flipped mask: 1s at front (where valid tokens now are), 0s at back
flipped_mask = mx.concatenate([
mx.ones((num_valid,), dtype=mx.int32),
mx.zeros((pad_length,), dtype=mx.int32)
], axis=0) # (seq,)
flipped_mask = mx.concatenate(
[
mx.ones((num_valid,), dtype=mx.int32),
mx.zeros((pad_length,), dtype=mx.int32),
],
axis=0,
) # (seq,)
# Combine: valid tokens at front, registers at back
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
combined = (
flipped_mask_expanded * adjusted
+ (1 - flipped_mask_expanded) * registers
)
result_list.append(combined)
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
@@ -526,7 +573,9 @@ class Embeddings1DConnector(nn.Module):
# Process through transformer blocks
for i in range(len(self.transformer_1d_blocks)):
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
hidden_states = self.transformer_1d_blocks[i](
hidden_states, attention_mask, freqs_cis
)
# Final RMS norm
hidden_states = rms_norm(hidden_states)
@@ -534,7 +583,6 @@ class Embeddings1DConnector(nn.Module):
return hidden_states, attention_mask
def norm_and_concat_hidden_states(
hidden_states: List[mx.array],
attention_mask: mx.array,
@@ -567,8 +615,12 @@ def norm_and_concat_hidden_states(
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
x_for_min = mx.where(
mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype)
)
x_for_max = mx.where(
mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype)
)
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min
@@ -603,7 +655,9 @@ def norm_and_concat_per_token_rms(
dtype = encoded_text.dtype
# Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D
variance = mx.mean(encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True) # (B, T, 1, L)
variance = mx.mean(
encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True
) # (B, T, 1, L)
normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6)
normed = normed.astype(dtype)
@@ -625,7 +679,9 @@ def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array:
class GemmaFeaturesExtractor(nn.Module):
"""V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization."""
def __init__(self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False):
def __init__(
self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False
):
super().__init__()
self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias)
@@ -674,13 +730,14 @@ class GemmaFeaturesExtractorV2(nn.Module):
if mode == "video":
target_dim = self.video_aggregate_embed.weight.shape[0]
return self.video_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
return self.video_aggregate_embed(
_rescale_norm(normed, target_dim, self.embedding_dim)
)
else:
target_dim = self.audio_aggregate_embed.weight.shape[0]
return self.audio_aggregate_embed(_rescale_norm(normed, target_dim, self.embedding_dim))
return self.audio_aggregate_embed(
_rescale_norm(normed, target_dim, self.embedding_dim)
)
class AudioEmbeddingsConnector(nn.Module):
@@ -717,8 +774,8 @@ class LTX2TextEncoder(nn.Module):
video_output_dim = 4096
audio_output_dim = 2048
self.feature_extractor_v2 = GemmaFeaturesExtractorV2(
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated)
embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale)
video_output_dim=video_output_dim,
audio_output_dim=audio_output_dim,
bias=True,
@@ -728,37 +785,57 @@ class LTX2TextEncoder(nn.Module):
# connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors
# config (nested under config.transformer.connector_positional_embedding_max_pos)
self.video_embeddings_connector = Embeddings1DConnector(
dim=video_output_dim, num_heads=32, head_dim=128,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[4096], has_gate_logits=True,
dim=video_output_dim,
num_heads=32,
head_dim=128,
num_layers=8,
num_learnable_registers=128,
positional_embedding_max_pos=[4096],
has_gate_logits=True,
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=audio_output_dim, num_heads=32, head_dim=64,
num_layers=8, num_learnable_registers=128,
positional_embedding_max_pos=[4096], has_gate_logits=True,
dim=audio_output_dim,
num_heads=32,
head_dim=64,
num_layers=8,
num_learnable_registers=128,
positional_embedding_max_pos=[4096],
has_gate_logits=True,
)
else:
# LTX-2: shared feature extractor, 3840-dim connectors
self.feature_extractor = GemmaFeaturesExtractor(feature_input_dim, hidden_dim)
self.feature_extractor = GemmaFeaturesExtractor(
feature_input_dim, hidden_dim
)
self.video_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128,
num_layers=2, num_learnable_registers=128,
dim=hidden_dim,
num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[1],
)
self.audio_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim, num_heads=30, head_dim=128,
num_layers=2, num_learnable_registers=128,
dim=hidden_dim,
num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[1],
)
self.processor = None
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
def load(
self,
model_path: Optional[str] = None,
text_encoder_path: Optional[str] = "google/gemma-3-12b-it",
):
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
# Load transformer weights for feature extractor and connector.
@@ -785,22 +862,35 @@ class LTX2TextEncoder(nn.Module):
if transformer_weights:
self._load_feature_extractors(transformer_weights, is_reformatted)
self._load_connector("video_embeddings_connector", transformer_weights, is_reformatted)
self._load_connector("audio_embeddings_connector", transformer_weights, is_reformatted)
self._load_connector(
"video_embeddings_connector", transformer_weights, is_reformatted
)
self._load_connector(
"audio_embeddings_connector", transformer_weights, is_reformatted
)
else:
print("WARNING: No transformer weights found for text projection connectors. "
"Text conditioning will use uninitialized weights!")
print(
"WARNING: No transformer weights found for text projection connectors. "
"Text conditioning will use uninitialized weights!"
)
# Load tokenizer
from transformers import AutoTokenizer
tokenizer_path = model_path / "tokenizer"
if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
self.processor = AutoTokenizer.from_pretrained(
str(tokenizer_path), trust_remote_code=True
)
else:
try:
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
self.processor = AutoTokenizer.from_pretrained(
text_encoder_path, trust_remote_code=True
)
except Exception:
self.processor = AutoTokenizer.from_pretrained("google/gemma-3-12b-it", trust_remote_code=True)
self.processor = AutoTokenizer.from_pretrained(
"google/gemma-3-12b-it", trust_remote_code=True
)
# Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left"
@@ -823,7 +913,11 @@ class LTX2TextEncoder(nn.Module):
submodule.bias = weights[b_key]
else:
# LTX-2: single aggregate_embed
agg_key = "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight"
agg_key = (
"aggregate_embed.weight"
if is_reformatted
else "text_embedding_projection.aggregate_embed.weight"
)
if agg_key in weights:
self.feature_extractor.aggregate_embed.weight = weights[agg_key]
@@ -837,12 +931,12 @@ class LTX2TextEncoder(nn.Module):
prefix = f"{name}."
for key, value in weights.items():
if key.startswith(prefix):
connector_weights[key[len(prefix):]] = value
connector_weights[key[len(prefix) :]] = value
else:
mono_prefix = f"model.diffusion_model.{name}."
for key, value in weights.items():
if key.startswith(mono_prefix):
connector_weights[key[len(mono_prefix):]] = value
connector_weights[key[len(mono_prefix) :]] = value
if not connector_weights:
return
@@ -894,21 +988,36 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
_, all_hidden_states = self.language_model(
inputs=input_ids,
input_embeddings=None,
attention_mask=attention_mask,
output_hidden_states=True,
)
if self.has_prompt_adaln:
# LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale)
video_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="video")
video_features = self.feature_extractor_v2(
all_hidden_states, attention_mask, mode="video"
)
additive_mask = (attention_mask - 1).astype(video_features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
additive_mask = (
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
)
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
video_embeddings, _ = self.video_embeddings_connector(
video_features, additive_mask
)
if return_audio_embeddings:
audio_features = self.feature_extractor_v2(all_hidden_states, attention_mask, mode="audio")
audio_features = self.feature_extractor_v2(
all_hidden_states, attention_mask, mode="audio"
)
audio_mask = (attention_mask - 1).astype(audio_features.dtype)
audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
audio_embeddings, _ = self.audio_embeddings_connector(audio_features, audio_mask)
audio_embeddings, _ = self.audio_embeddings_connector(
audio_features, audio_mask
)
return video_embeddings, audio_embeddings
else:
return video_embeddings, attention_mask
@@ -920,12 +1029,18 @@ class LTX2TextEncoder(nn.Module):
video_features = self.feature_extractor(concat_hidden)
additive_mask = (attention_mask - 1).astype(video_features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
additive_mask = (
additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
)
video_embeddings, _ = self.video_embeddings_connector(video_features, additive_mask)
video_embeddings, _ = self.video_embeddings_connector(
video_features, additive_mask
)
if return_audio_embeddings:
audio_embeddings, _ = self.audio_embeddings_connector(video_features, additive_mask)
audio_embeddings, _ = self.audio_embeddings_connector(
video_features, additive_mask
)
return video_embeddings, audio_embeddings
else:
return video_embeddings, attention_mask
@@ -964,7 +1079,7 @@ class LTX2TextEncoder(nn.Module):
# Remove leading/trailing whitespace
response = response.strip()
# Remove any leading punctuation
response = re.sub(r'^[^\w\s]+', '', response)
response = re.sub(r"^[^\w\s]+", "", response)
return response
def _apply_chat_template(
@@ -985,7 +1100,9 @@ class LTX2TextEncoder(nn.Module):
elif isinstance(content, list):
# Handle multimodal content (image + text)
text_parts = [c["text"] for c in content if c.get("type") == "text"]
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
formatted += (
f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
)
elif role == "assistant":
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
# Add generation prompt
@@ -1016,7 +1133,9 @@ class LTX2TextEncoder(nn.Module):
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler
except ImportError:
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
logging.warning(
"mlx-lm not available for prompt enhancement. Using original prompt."
)
return prompt
if self.processor is None:
@@ -1043,7 +1162,11 @@ class LTX2TextEncoder(nn.Module):
)
input_ids = mx.array(inputs["input_ids"])
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1))
sampler = make_sampler(
kwargs.get("temperature", 0.7),
kwargs.get("top_p", 1.0),
top_k=kwargs.get("top_k", -1),
)
logits_processors = make_logits_processors(
kwargs.get("logit_bias", None),
kwargs.get("repetition_penalty", 1.3),
@@ -1094,14 +1217,15 @@ class LTX2TextEncoder(nn.Module):
mx.clear_cache()
# Decode only the new tokens
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
enhanced_prompt = self.processor.decode(
generated_tokens, skip_special_tokens=True
)
enhanced_prompt = self._clean_response(enhanced_prompt)
logging.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
def enhance_i2v(
self,
prompt: str,
@@ -1135,4 +1259,3 @@ def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder()
encoder.load(model_path=model_path)
return encoder

View File

@@ -11,7 +11,7 @@ class PixArtAlphaTextProjection(nn.Module):
out_features: int | None = None,
bias: bool = True,
):
super().__init__()
out_features = out_features or hidden_size

View File

@@ -4,8 +4,8 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx_2.attention import Attention
from mlx_video.models.ltx_2.config import LTXRopeType, TransformerConfig
from mlx_video.models.ltx_2.feed_forward import FeedForward
from mlx_video.utils import rms_norm
@@ -171,8 +171,7 @@ class BasicAVTransformerBlock(nn.Module):
# timestep: (B, seq, num_params * dim) -> reshape to (B, seq, num_params, dim)
timestep_reshaped = mx.reshape(
timestep,
(batch_size, timestep.shape[1], num_ada_params, -1)
timestep, (batch_size, timestep.shape[1], num_ada_params, -1)
)
# Extract the relevant indices
@@ -225,8 +224,12 @@ class BasicAVTransformerBlock(nn.Module):
)
# Squeeze the sequence dimension if it's 1
scale_shift_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada)
gate_squeezed = tuple(mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada)
scale_shift_squeezed = tuple(
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in scale_shift_ada
)
gate_squeezed = tuple(
mx.squeeze(t, axis=1) if t.shape[1] == 1 else t for t in gate_ada
)
return (*scale_shift_squeezed, *gate_squeezed)
@@ -258,8 +261,16 @@ class BasicAVTransformerBlock(nn.Module):
# Check which modalities to run
run_vx = video is not None and video.enabled and vx.size > 0
run_ax = audio is not None and audio.enabled and ax.size > 0
run_a2v = run_vx and (audio is not None and audio.enabled and ax.size > 0) and not skip_cross_modal
run_v2a = run_ax and (video is not None and video.enabled and vx.size > 0) and not skip_cross_modal
run_a2v = (
run_vx
and (audio is not None and audio.enabled and ax.size > 0)
and not skip_cross_modal
)
run_v2a = (
run_ax
and (video is not None and video.enabled and vx.size > 0)
and not skip_cross_modal
)
# Process video self-attention and cross-attention with text
if run_vx:
@@ -269,7 +280,15 @@ class BasicAVTransformerBlock(nn.Module):
# Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_vx = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_msa) + vshift_msa
vx = vx + self.attn1(norm_vx, pe=video.positional_embeddings, skip_attention=skip_video_self_attn) * vgate_msa
vx = (
vx
+ self.attn1(
norm_vx,
pe=video.positional_embeddings,
skip_attention=skip_video_self_attn,
)
* vgate_msa
)
# Cross-attention with text context
if self.has_prompt_adaln:
@@ -278,11 +297,24 @@ class BasicAVTransformerBlock(nn.Module):
self.scale_shift_table, vx.shape[0], video.timesteps, slice(6, 9)
)
vprompt_shift_kv, vprompt_scale_kv = self.get_ada_values(
self.prompt_scale_shift_table, vx.shape[0], video.prompt_timesteps, slice(0, 2)
self.prompt_scale_shift_table,
vx.shape[0],
video.prompt_timesteps,
slice(0, 2),
)
attn_input = rms_norm(vx, eps=self.norm_eps) * (1 + vscale_q) + vshift_q
encoder_hidden_states = video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
vx = vx + self.attn2(attn_input, context=encoder_hidden_states, mask=video.context_mask) * vgate_q
encoder_hidden_states = (
video.context * (1 + vprompt_scale_kv) + vprompt_shift_kv
)
vx = (
vx
+ self.attn2(
attn_input,
context=encoder_hidden_states,
mask=video.context_mask,
)
* vgate_q
)
else:
vx = vx + self.attn2(
rms_norm(vx, eps=self.norm_eps),
@@ -298,20 +330,46 @@ class BasicAVTransformerBlock(nn.Module):
# Self-attention with RoPE (skip_attention=True for STG perturbation)
norm_ax = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_msa) + ashift_msa
ax = ax + self.audio_attn1(norm_ax, pe=audio.positional_embeddings, skip_attention=skip_audio_self_attn) * agate_msa
ax = (
ax
+ self.audio_attn1(
norm_ax,
pe=audio.positional_embeddings,
skip_attention=skip_audio_self_attn,
)
* agate_msa
)
# Cross-attention with text context
if self.has_prompt_adaln:
# LTX-2.3: Q modulated by timestep (indices 6-8), context modulated by prompt_adaln
ashift_q, ascale_q, agate_q = self.get_ada_values(
self.audio_scale_shift_table, ax.shape[0], audio.timesteps, slice(6, 9)
self.audio_scale_shift_table,
ax.shape[0],
audio.timesteps,
slice(6, 9),
)
aprompt_shift_kv, aprompt_scale_kv = self.get_ada_values(
self.audio_prompt_scale_shift_table, ax.shape[0], audio.prompt_timesteps, slice(0, 2)
self.audio_prompt_scale_shift_table,
ax.shape[0],
audio.prompt_timesteps,
slice(0, 2),
)
attn_input_a = (
rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
)
encoder_hidden_states_a = (
audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
)
ax = (
ax
+ self.audio_attn2(
attn_input_a,
context=encoder_hidden_states_a,
mask=audio.context_mask,
)
* agate_q
)
attn_input_a = rms_norm(ax, eps=self.norm_eps) * (1 + ascale_q) + ashift_q
encoder_hidden_states_a = audio.context * (1 + aprompt_scale_kv) + aprompt_shift_kv
ax = ax + self.audio_attn2(attn_input_a, context=encoder_hidden_states_a, mask=audio.context_mask) * agate_q
else:
ax = ax + self.audio_attn2(
rms_norm(ax, eps=self.norm_eps),

View File

@@ -1,4 +1,5 @@
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -36,11 +37,20 @@ class Conv3d(nn.Module):
self.groups = groups
# Weight shape: (C_out, KD, KH, KW, C_in)
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
scale = (
1.0
/ (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
shape=(
out_channels,
kernel_size[0],
kernel_size[1],
kernel_size[2],
in_channels,
),
)
if bias:
@@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module):
n, d, h, w, c = x.shape
input_dtype = x.dtype
x = x.astype(mx.float32)
# Reshape to (N, D*H*W, num_groups, C//num_groups)
@@ -219,7 +228,9 @@ class SpatialRationalResampler(nn.Module):
self.den = den
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1)
self.conv = nn.Conv2d(
mid_channels, num * num * mid_channels, kernel_size=3, padding=1
)
self.pixel_shuffle = PixelShuffle2D(num, num)
self.blur_down = BlurDownsample(stride=den)
@@ -230,7 +241,7 @@ class SpatialRationalResampler(nn.Module):
x = self.conv(x)
x = self.pixel_shuffle(x) # H*num, W*num
x = self.blur_down(x) # H*num/den, W*num/den
x = self.blur_down(x) # H*num/den, W*num/den
_, h_out, w_out, _ = x.shape
x = mx.reshape(x, (n, d, h_out, w_out, c))
@@ -240,6 +251,7 @@ class SpatialRationalResampler(nn.Module):
def _rational_for_scale(scale: float) -> Tuple[int, int]:
"""Convert a float scale to a rational fraction (numerator, denominator)."""
from fractions import Fraction
frac = Fraction(scale).limit_denominator(10)
return frac.numerator, frac.denominator
@@ -290,16 +302,22 @@ class LatentUpsampler(nn.Module):
self.initial_norm = GroupNorm3d(32, mid_channels)
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
self.res_blocks = {
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
}
# Upsampler: 2D spatial upsampling (frame-by-frame)
if rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale)
self.upsampler = SpatialRationalResampler(
mid_channels=mid_channels, scale=spatial_scale
)
else:
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
self.post_upsample_res_blocks = {
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
}
# Final projection
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
@@ -314,10 +332,13 @@ class LatentUpsampler(nn.Module):
Returns:
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
"""
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
print(
f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}"
)
if debug:
print(" [DEBUG] LatentUpsampler forward pass:")
@@ -404,7 +425,11 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
# Both formats may have upsampler.blur_down.kernel, so use channel count
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
conv_key = (
"upsampler.conv.weight"
if "upsampler.conv.weight" in raw_weights
else "upsampler.0.weight"
)
if conv_key in raw_weights:
out_channels = raw_weights[conv_key].shape[0]
ratio = out_channels // mid_channels
@@ -414,7 +439,9 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
rational_resampler = False
spatial_scale = 2.0
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
print(
f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}"
)
# Create model
upsampler = LatentUpsampler(

View File

@@ -109,6 +109,7 @@ def convert_audio_encoder(
return encoder_dir
from huggingface_hub import hf_hub_download
vae_path = hf_hub_download(
source_repo,
"audio_vae/diffusion_pytorch_model.safetensors",

View File

@@ -1,8 +1,8 @@
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.decoder import LTX2VideoDecoder, VideoDecoder
from mlx_video.models.ltx_2.video_vae.encoder import encode_image
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,
TemporalTilingConfig,
TilingConfig,
)
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -27,14 +27,18 @@ def reflect_pad_2d(x: mx.array, pad_h: int, pad_w: int) -> mx.array:
# Height padding (axis 2)
if pad_h > 0:
# Get reflection indices - exclude boundary
top_pad = x[:, :, 1:pad_h+1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h-1:-1, :, :][:, :, ::-1, :, :] # Flip bottom portion
top_pad = x[:, :, 1 : pad_h + 1, :, :][:, :, ::-1, :, :] # Flip top portion
bottom_pad = x[:, :, -pad_h - 1 : -1, :, :][
:, :, ::-1, :, :
] # Flip bottom portion
x = mx.concatenate([top_pad, x, bottom_pad], axis=2)
# Width padding (axis 3)
if pad_w > 0:
left_pad = x[:, :, :, 1:pad_w+1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w-1:-1, :][:, :, :, ::-1, :] # Flip right portion
left_pad = x[:, :, :, 1 : pad_w + 1, :][:, :, :, ::-1, :] # Flip left portion
right_pad = x[:, :, :, -pad_w - 1 : -1, :][
:, :, :, ::-1, :
] # Flip right portion
x = mx.concatenate([left_pad, x, right_pad], axis=3)
return x
@@ -50,7 +54,7 @@ def make_conv_nd(
causal: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
) -> nn.Module:
if dims == 2:
return CausalConv2d(
in_channels=in_channels,
@@ -118,15 +122,17 @@ class CausalConv3d(nn.Module):
)
def __call__(self, x: mx.array, causal: Optional[bool] = None) -> mx.array:
use_causal = causal if causal is not None else self.causal
# Apply temporal padding via frame replication
# Apply temporal padding via frame replication
# Only apply if kernel_size > 1
if self.time_kernel_size > 1:
if use_causal:
# Causal: replicate first frame kernel_size-1 times at the beginning
first_frame_pad = mx.repeat(x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2)
first_frame_pad = mx.repeat(
x[:, :, :1, :, :], self.time_kernel_size - 1, axis=2
)
x = mx.concatenate([first_frame_pad, x], axis=2)
else:
# Non-causal: replicate first frame at start, last frame at end
@@ -176,7 +182,6 @@ class CausalConv3d(nn.Module):
"""
b, d, h, w, c = x.shape
total_elements = d * h * w * c
max_safe_elements = 30 * 192 * 192 * 128 # ~140M elements per chunk
@@ -191,11 +196,10 @@ class CausalConv3d(nn.Module):
overlap = kernel_t - 1
expected_output_frames = d - overlap
outputs = []
out_idx = 0
out_idx = 0
# Process chunks
in_start = 0

View File

@@ -15,14 +15,14 @@ Architecture (from PyTorch weights):
"""
import math
from typing import Optional, Dict
from pathlib import Path
from typing import Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, unpatchify
from mlx_video.models.ltx_2.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx_2.video_vae.tiling import TilingConfig, decode_with_tiling
@@ -77,16 +77,14 @@ class PixArtAlphaTimestepEmbedder(nn.Module):
def __init__(self, embedding_dim: int):
    """Create the timestep embedder.

    Args:
        embedding_dim: Passed to ``TimestepEmbedding`` as ``time_embed_dim``.
    """
    super().__init__()
    # 256 input channels: the same width that __call__ requests from
    # get_timestep_embedding (embedding_dim=256).
    self.timestep_embedder = TimestepEmbedding(
        in_channels=256, time_embed_dim=embedding_dim
    )
def __call__(
    self, timestep: mx.array, hidden_dtype: mx.Dtype = mx.float32
) -> mx.array:
    """Embed timesteps.

    Args:
        timestep: Timestep values handed to ``get_timestep_embedding``.
        hidden_dtype: dtype the 256-wide projection is cast to before it is
            fed to ``self.timestep_embedder``.

    Returns:
        The output of ``self.timestep_embedder`` for the projected timesteps.
    """
    # Fixed-width (256) projection; must agree with in_channels=256 used when
    # the embedder was constructed.
    timesteps_proj = get_timestep_embedding(
        timestep, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0
    )
    timesteps_emb = self.timestep_embedder(timesteps_proj.astype(hidden_dtype))
    return timesteps_emb
@@ -119,6 +117,7 @@ class ResnetBlock3DSimple(nn.Module):
def _make_conv_wrapper(self, in_ch, out_ch, padding_mode):
"""Create a wrapper object with a 'conv' attribute to match PyTorch naming."""
class ConvWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
@@ -130,13 +129,15 @@ class ResnetBlock3DSimple(nn.Module):
padding=1,
spatial_padding_mode=padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
return ConvWrapper()
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
    """Apply pixel normalization.

    Divides ``x`` by the root-mean-square of its values along axis 1.

    Args:
        x: Input array; the mean of squares is taken over axis 1 with
            ``keepdims=True`` so the result broadcasts against ``x``.
        eps: Added to the mean of squares before the square root, so the
            divisor is never zero.

    Returns:
        ``x`` scaled by the inverse RMS along axis 1, same shape as ``x``.
    """
    return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__(
self,
@@ -153,7 +154,9 @@ class ResnetBlock3DSimple(nn.Module):
if self.timestep_conditioning and timestep_embed is not None:
# scale_shift_table: (4, C), timestep_embed: (B, 4*C, 1, 1, 1)
# Combine table with timestep embedding
ada_values = self.scale_shift_table[None, :, :, None, None, None] # (1, 4, C, 1, 1, 1)
ada_values = self.scale_shift_table[
None, :, :, None, None, None
] # (1, 4, C, 1, 1, 1)
# Reshape timestep_embed from (B, 4*C, 1, 1, 1) to (B, 4, C, 1, 1, 1)
channels = self.scale_shift_table.shape[1]
ts_reshaped = timestep_embed.reshape(batch_size, 4, channels, 1, 1, 1)
@@ -199,16 +202,14 @@ class ResBlockGroup(nn.Module):
# Time embedder for this block group: embed_dim = 4 * channels
if timestep_conditioning:
self.time_embedder = PixArtAlphaTimestepEmbedder(
embedding_dim=channels * 4
)
self.time_embedder = PixArtAlphaTimestepEmbedder(embedding_dim=channels * 4)
# Use dict with int keys for MLX to track parameters properly
self.res_blocks = {
i: ResnetBlock3DSimple(
channels,
spatial_padding_mode,
timestep_conditioning=timestep_conditioning
timestep_conditioning=timestep_conditioning,
)
for i in range(num_layers)
}
@@ -224,8 +225,7 @@ class ResBlockGroup(nn.Module):
if self.timestep_conditioning and timestep is not None:
batch_size = x.shape[0]
timestep_embed = self.time_embedder(
timestep.flatten(),
hidden_dtype=x.dtype
timestep.flatten(), hidden_dtype=x.dtype
)
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
@@ -301,8 +301,10 @@ class LTX2VideoDecoder(nn.Module):
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_in = ConvInWrapper()
# Build up blocks from config
@@ -311,8 +313,12 @@ class LTX2VideoDecoder(nn.Module):
block_type = block_def[0]
ch = block_def[1]
if block_type == "res":
num_layers = block_def[2] if len(block_def) > 2 else num_layers_per_block
self.up_blocks[idx] = ResBlockGroup(ch, num_layers, spatial_padding_mode, timestep_conditioning)
num_layers = (
block_def[2] if len(block_def) > 2 else num_layers_per_block
)
self.up_blocks[idx] = ResBlockGroup(
ch, num_layers, spatial_padding_mode, timestep_conditioning
)
elif block_type == "d2s":
reduction = block_def[2] if len(block_def) > 2 else 2
stride = block_def[3] if len(block_def) > 3 else (2, 2, 2)
@@ -327,6 +333,7 @@ class LTX2VideoDecoder(nn.Module):
)
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
def __init__(self_inner):
super().__init__()
@@ -338,8 +345,10 @@ class LTX2VideoDecoder(nn.Module):
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self_inner, x, causal=False):
return self_inner.conv(x, causal=causal)
self.conv_out = ConvOutWrapper()
self.act = nn.SiLU()
@@ -358,7 +367,7 @@ class LTX2VideoDecoder(nn.Module):
return weights
for key, value in weights.items():
new_key = key
if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
@@ -374,7 +383,6 @@ class LTX2VideoDecoder(nn.Module):
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
@@ -384,7 +392,10 @@ class LTX2VideoDecoder(nn.Module):
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
if (
".conv.conv.weight" not in new_key
and ".conv.conv.bias" not in new_key
):
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
@@ -392,7 +403,9 @@ class LTX2VideoDecoder(nn.Module):
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
def from_pretrained(
cls, model_path: Path, strict: bool = True
) -> "LTX2VideoDecoder":
"""Load a pretrained decoder from a directory with config.json and weights.
Args:
@@ -422,7 +435,6 @@ class LTX2VideoDecoder(nn.Module):
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Infer block structure from weights
decoder_blocks = cls._infer_blocks(weights)
@@ -537,11 +549,9 @@ class LTX2VideoDecoder(nn.Module):
return final_blocks
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
return x / mx.sqrt(mx.mean(x ** 2, axis=1, keepdims=True) + eps)
return x / mx.sqrt(mx.mean(x**2, axis=1, keepdims=True) + eps)
def __call__(
self,
@@ -551,20 +561,15 @@ class LTX2VideoDecoder(nn.Module):
debug: bool = False,
chunked_conv: bool = False,
) -> mx.array:
batch_size = sample.shape[0]
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
@@ -574,7 +579,6 @@ class LTX2VideoDecoder(nn.Module):
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
@@ -583,19 +587,18 @@ class LTX2VideoDecoder(nn.Module):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
x = self.pixel_norm(x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
scaled_timestep.flatten(),
hidden_dtype=x.dtype
scaled_timestep.flatten(), hidden_dtype=x.dtype
)
embedded_timestep = embedded_timestep.reshape(batch_size, -1, 1, 1, 1)
ada_values = self.last_scale_shift_table[None, :, :, None, None, None] # (1, 2, 128, 1, 1, 1)
ada_values = self.last_scale_shift_table[
None, :, :, None, None, None
] # (1, 2, 128, 1, 1, 1)
ts_reshaped = embedded_timestep.reshape(batch_size, 2, 128, 1, 1, 1)
ada_values = ada_values + ts_reshaped
@@ -603,16 +606,13 @@ class LTX2VideoDecoder(nn.Module):
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
x = self.act(x)
x = self.conv_out(x, causal=causal)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
return x
@@ -669,11 +669,23 @@ class LTX2VideoDecoder(nn.Module):
# Auto-enable chunked conv for modes where it helps (larger tiles)
# Chunked conv reduces memory by processing conv+depth_to_space in temporal chunks
use_chunked_conv = tiling_mode in ("conservative", "none", "auto", "default", "spatial")
use_chunked_conv = tiling_mode in (
"conservative",
"none",
"auto",
"default",
"spatial",
)
if not needs_spatial_tiling and not needs_temporal_tiling:
# No tiling needed, use regular decode
return self(sample, causal=causal, timestep=timestep, debug=debug, chunked_conv=use_chunked_conv)
return self(
sample,
causal=causal,
timestep=timestep,
debug=debug,
chunked_conv=use_chunked_conv,
)
return decode_with_tiling(
decoder_fn=self,

View File

@@ -6,8 +6,8 @@ to latent space, which can then be used to condition video generation.
"""
import mlx.core as mx
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx_2.video_vae.video_vae import VideoEncoder
def encode_image(

View File

@@ -1,6 +1,5 @@
"""Operations for Video VAE."""
from typing import List, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -32,7 +31,9 @@ def patchify(x: mx.array, patch_size_hw: int = 4, patch_size_t: int = 1) -> mx.a
new_c = c * patch_size_hw * patch_size_hw * patch_size_t
# Reshape: (B, C, F, H, W) -> (B, C, F/pt, pt, H/ph, ph, W/pw, pw)
x = mx.reshape(x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw))
x = mx.reshape(
x, (b, c, new_f, patch_size_t, new_h, patch_size_hw, new_w, patch_size_hw)
)
# Permute: (B, C, F', pt, H', ph, W', pw) -> (B, C, pt, pw, ph, F', H', W')
# PyTorch einops uses (c, p, r, q) = (c, temporal, width, height), so we need pw before ph
@@ -101,7 +102,7 @@ class PerChannelStatistics(nn.Module):
Normalized tensor
"""
# Expand mean and std for broadcasting: (C,) -> (1, C, 1, 1, 1)
dtype = x.dtype
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
@@ -117,7 +118,7 @@ class PerChannelStatistics(nn.Module):
Returns:
Denormalized tensor
"""
dtype = x.dtype
dtype = x.dtype
# Cast to float32 for precision
mean = self.mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.std.astype(mx.float32).reshape(1, -1, 1, 1, 1)

View File

@@ -44,7 +44,7 @@ class ResnetBlock3D(nn.Module):
timestep_conditioning: bool = False,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
out_channels = out_channels or in_channels
@@ -96,7 +96,7 @@ class ResnetBlock3D(nn.Module):
causal: bool = True,
generator: Optional[int] = None,
) -> mx.array:
residual = x
# First block
@@ -136,7 +136,7 @@ class UNetMidBlock3D(nn.Module):
attention_head_dim: Optional[int] = None,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
self.num_layers = num_layers

View File

@@ -104,7 +104,7 @@ class SpaceToDepthDownsample(nn.Module):
class DepthToSpaceUpsample(nn.Module):
def __init__(
self,
dims: int,
@@ -114,7 +114,7 @@ class DepthToSpaceUpsample(nn.Module):
out_channels_reduction_factor: int = 1,
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
@@ -156,7 +156,9 @@ class DepthToSpaceUpsample(nn.Module):
return x
def __call__(self, x: mx.array, causal: bool = True, chunked_conv: bool = False) -> mx.array:
def __call__(
self, x: mx.array, causal: bool = True, chunked_conv: bool = False
) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
@@ -196,7 +198,9 @@ class DepthToSpaceUpsample(nn.Module):
return x
def _chunked_conv_depth_to_space(self, x: mx.array, causal: bool = True) -> mx.array:
def _chunked_conv_depth_to_space(
self, x: mx.array, causal: bool = True
) -> mx.array:
"""Chunked conv + depth_to_space that processes in temporal chunks.
This reduces peak memory by avoiding the full high-channel intermediate tensor.

View File

@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
# Apply right ramp (fade out)
if ramp_right > 0:
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
fade_out = [
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
]
for i in range(ramp_right):
mask[length - ramp_right + i] *= fade_out[i]
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
def __post_init__(self) -> None:
if self.tile_size_in_pixels < 64:
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
raise ValueError(
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
)
if self.tile_size_in_pixels % 32 != 0:
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
raise ValueError(
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
)
if self.tile_overlap_in_pixels % 32 != 0:
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
raise ValueError(
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
)
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
def __post_init__(self) -> None:
if self.tile_size_in_frames < 16:
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
raise ValueError(
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
)
if self.tile_size_in_frames % 8 != 0:
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
raise ValueError(
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
)
if self.tile_overlap_in_frames % 8 != 0:
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
raise ValueError(
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
)
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
@@ -113,15 +127,21 @@ class TilingConfig:
def default(cls) -> "TilingConfig":
"""Default tiling: 512px spatial, 64 frame temporal."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
),
)
@classmethod
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
"""Spatial tiling only (for short videos with large resolution)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
),
temporal_config=None,
)
@@ -130,23 +150,33 @@ class TilingConfig:
"""Temporal tiling only (for long videos with small resolution)."""
return cls(
spatial_config=None,
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
),
)
@classmethod
def aggressive(cls) -> "TilingConfig":
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=256, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=32, tile_overlap_in_frames=8
),
)
@classmethod
def conservative(cls) -> "TilingConfig":
"""Conservative tiling (larger tiles, less memory savings but faster)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=768, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=96, tile_overlap_in_frames=24
),
)
@classmethod
@@ -186,10 +216,14 @@ class TilingConfig:
temporal_config = None
if needs_spatial:
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
spatial_config = SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
)
if needs_temporal:
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
temporal_config = TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
)
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
@@ -197,16 +231,21 @@ class TilingConfig:
@dataclass
class DimensionIntervals:
"""Intervals for splitting a single dimension."""
starts: List[int]
ends: List[int]
left_ramps: List[int]
right_ramps: List[int]
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
def split_in_spatial(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a spatial dimension into intervals."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
starts = [i * (size - overlap) for i in range(amount)]
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
left_ramps = [0] + [overlap] * (amount - 1)
right_ramps = [overlap] * (amount - 1) + [0]
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
return DimensionIntervals(
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
)
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
def split_in_temporal(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a temporal dimension into intervals with causal adjustment."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
# Start with spatial split
intervals = split_in_spatial(size, overlap, dimension_size)
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
starts[i] = starts[i] - 1
left_ramps[i] = left_ramps[i] + 1
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
return DimensionIntervals(
starts=starts,
ends=intervals.ends,
left_ramps=left_ramps,
right_ramps=intervals.right_ramps,
)
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
def map_temporal_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map temporal latent interval to output coordinates and mask."""
start = begin * scale
stop = 1 + (end - 1) * scale
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, True
)
return slice(start, stop), mask
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
def map_spatial_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map spatial latent interval to output coordinates and mask."""
start = begin * scale
stop = end * scale
left_ramp_scaled = left_ramp * scale
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, False
)
return slice(start, stop), mask
@@ -315,7 +373,9 @@ def decode_with_tiling(
temporal_overlap = 0
# Compute intervals for each dimension
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_temporal(
temporal_tile_size, temporal_overlap, f_latent
)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -338,7 +398,9 @@ def decode_with_tiling(
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_temporal_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -347,7 +409,9 @@ def decode_with_tiling(
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
out_h_slice, h_mask = map_spatial_slice(
h_start, h_end, h_left, h_right, spatial_scale
)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
@@ -356,13 +420,23 @@ def decode_with_tiling(
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
out_w_slice, w_mask = map_spatial_slice(
w_start, w_end, w_left, w_right, spatial_scale
)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
tile_latents = latents[
:, :, t_start:t_end, h_start:h_end, w_start:w_end
]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
tile_output = decoder_fn(
tile_latents,
causal=causal,
timestep=timestep,
debug=False,
chunked_conv=chunked_conv,
)
mx.eval(tile_output)
# Clear tile_latents reference
@@ -385,13 +459,15 @@ def decode_with_tiling(
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
t_mask_slice.reshape(1, 1, -1, 1, 1)
* h_mask_slice.reshape(1, 1, 1, -1, 1)
* w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
tile_output_slice = tile_output[
:, :, :actual_t, :actual_h, :actual_w
].astype(mx.float32)
# Clear full tile_output
del tile_output
@@ -409,11 +485,37 @@ def decode_with_tiling(
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ blend_mask
)
# Force evaluation to free memory
@@ -445,10 +547,12 @@ def decode_with_tiling(
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
next_tile_start_out = (
1 + (next_tile_start_latent - 1) * temporal_scale
)
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
if not hasattr(decode_with_tiling, "_emitted_frames"):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
@@ -456,7 +560,10 @@ def decode_with_tiling(
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = (
output[:, :, emitted:next_tile_start_out, :, :]
/ finalized_weights
)
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
@@ -473,7 +580,7 @@ def decode_with_tiling(
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
@@ -481,7 +588,7 @@ def decode_with_tiling(
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
if hasattr(decode_with_tiling, "_emitted_frames"):
del decode_with_tiling._emitted_frames
# Clean up weights

View File

@@ -8,12 +8,15 @@ import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx_2.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx_2.video_vae.ops import PerChannelStatistics, patchify, unpatchify
from mlx_video.models.ltx_2.video_vae.ops import (
PerChannelStatistics,
patchify,
unpatchify,
)
from mlx_video.models.ltx_2.video_vae.resnet import (
NormLayerType,
ResnetBlock3D,
UNetMidBlock3D,
get_norm_layer,
)
from mlx_video.models.ltx_2.video_vae.sampling import (
DepthToSpaceUpsample,
@@ -24,6 +27,7 @@ from mlx_video.utils import PixelNorm
class LogVarianceType(Enum):
"""Log variance mode for VAE."""
PER_CHANNEL = "per_channel"
UNIFORM = "uniform"
CONSTANT = "constant"
@@ -229,7 +233,6 @@ class VideoEncoder(nn.Module):
config: VideoEncoderModelConfig with encoder parameters
"""
super().__init__()
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
self.patch_size = config.patch_size
self.norm_layer = config.norm_layer
@@ -241,10 +244,12 @@ class VideoEncoder(nn.Module):
encoder_spatial_padding_mode = config.encoder_spatial_padding_mode
# Per-channel statistics for normalizing latents
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.out_channels)
self.per_channel_statistics = PerChannelStatistics(
latent_channels=config.out_channels
)
# After patchify, channels increase by patch_size^2
in_channels = config.in_channels * config.patch_size ** 2
in_channels = config.in_channels * config.patch_size**2
feature_channels = config.out_channels
# Initial convolution
@@ -262,7 +267,11 @@ class VideoEncoder(nn.Module):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {}
for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_encoder_block(
block_name=block_name,
@@ -291,7 +300,10 @@ class VideoEncoder(nn.Module):
conv_out_channels = config.out_channels
if config.latent_log_var == LogVarianceType.PER_CHANNEL:
conv_out_channels *= 2
elif config.latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}:
elif config.latent_log_var in {
LogVarianceType.UNIFORM,
LogVarianceType.CONSTANT,
}:
conv_out_channels += 1
self.conv_out = CausalConv3d(
@@ -349,13 +361,16 @@ class VideoEncoder(nn.Module):
elif self.latent_log_var == LogVarianceType.CONSTANT:
sample = sample[:, :-1, ...]
approx_ln_0 = -30
sample = mx.concatenate([
sample,
mx.full_like(sample, approx_ln_0),
], axis=1)
sample = mx.concatenate(
[
sample,
mx.full_like(sample, approx_ln_0),
],
axis=1,
)
# Split into means and logvar, normalize means
means = sample[:, :self.latent_channels, ...]
means = sample[:, : self.latent_channels, ...]
return self.per_channel_statistics.normalize(means)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
@@ -409,6 +424,7 @@ class VideoEncoder(nn.Module):
Loaded VideoEncoder instance
"""
import json
from mlx_video.models.ltx_2.config import VideoEncoderModelConfig
# Load config
@@ -474,7 +490,7 @@ class VideoDecoder(nn.Module):
decoder_blocks = []
self.patch_size = patch_size
out_channels = out_channels * patch_size ** 2
out_channels = out_channels * patch_size**2
self.causal = causal
self.timestep_conditioning = timestep_conditioning
self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS
@@ -510,7 +526,11 @@ class VideoDecoder(nn.Module):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block_config = (
{"num_layers": block_params}
if isinstance(block_params, int)
else block_params
)
block, feature_channels = _make_decoder_block(
block_name=block_name,

View File

@@ -98,8 +98,12 @@ class WanSelfAttention(nn.Module):
v = self.v(x_w).reshape(b, s, n, d)
# RoPE in float32 for precision (official uses float64)
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
q = rope_apply(
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
k = rope_apply(
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
)
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
@@ -120,9 +124,7 @@ class WanSelfAttention(nn.Module):
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
return self.o(out)
@@ -213,9 +215,7 @@ class WanCrossAttention(nn.Module):
q, k, v, scale=self.scale, mask=mask
)
else:
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=self.scale
)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
return self.o(out)

View File

@@ -7,7 +7,6 @@ from typing import Dict, List, Optional, Tuple
import mlx.core as mx
import mlx.utils
import numpy as np
logger = logging.getLogger(__name__)
@@ -57,7 +56,9 @@ def load_safetensors_weights(path: str) -> Dict[str, mx.array]:
return weights
def sanitize_wan_transformer_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
def sanitize_wan_transformer_weights(
weights: Dict[str, mx.array]
) -> Dict[str, mx.array]:
"""Convert Wan2.2 transformer weight keys to MLX model structure.
Wan2.2 keys follow the pattern:
@@ -246,8 +247,8 @@ def _load_lora_configs(
Shared between weight-merging and runtime-wrapping paths.
"""
from mlx_video.lora import LoRAConfig, load_multiple_loras
from mlx_video.generate_wan import Colors
from mlx_video.lora import LoRAConfig, load_multiple_loras
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
@@ -264,7 +265,9 @@ def _load_lora_configs(
module_to_loras = load_multiple_loras(configs)
if not module_to_loras:
print(f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}")
print(
f"{Colors.YELLOW}Warning: No LoRA weights matched model layers{Colors.RESET}"
)
return module_to_loras
@@ -279,8 +282,8 @@ def load_and_apply_loras(
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
"""
from mlx_video.lora import apply_loras_to_weights
from mlx_video.generate_wan import Colors
from mlx_video.lora import apply_loras_to_weights
if not lora_configs:
return model_weights
@@ -289,12 +292,17 @@ def load_and_apply_loras(
if not module_to_loras:
return model_weights
print(f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}")
print(
f"{Colors.GREEN}Applying LoRAs to {len(module_to_loras)} modules...{Colors.RESET}"
)
if verbose:
print(f" Model has {len(model_weights)} weight keys")
modified_weights = apply_loras_to_weights(
model_weights, module_to_loras, verbose=verbose, quantization_bits=quantization_bits
model_weights,
module_to_loras,
verbose=verbose,
quantization_bits=quantization_bits,
)
print(f"{Colors.GREEN}✓ LoRAs applied successfully{Colors.RESET}")
@@ -435,8 +443,10 @@ def convert_wan_checkpoint(
src_model_type = src_config.get("model_type", "t2v")
src_text_len = src_config.get("text_len", 512)
print(f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}")
print(
f" Source config: dim={src_dim}, layers={src_num_layers}, "
f"heads={src_num_heads}, type={src_model_type}"
)
# Use preset for known TI2V 5B configuration
if src_model_type == "ti2v" and src_dim == 3072:
@@ -513,8 +523,11 @@ def convert_wan_checkpoint(
weights = load_torch_weights(str(vae_path))
if is_wan22_vae:
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
include_encoder = config.model_type in ("ti2v", "i2v")
weights = sanitize_wan22_vae_weights(weights, include_encoder=include_encoder)
weights = sanitize_wan22_vae_weights(
weights, include_encoder=include_encoder
)
else:
weights = sanitize_wan_vae_weights(weights)
# Always save VAE in float32 — official Wan2.2 runs VAE decode in
@@ -527,7 +540,9 @@ def convert_wan_checkpoint(
# Quantize transformer weights if requested
if quantize:
print(f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})...")
print(
f"\nQuantizing transformer weights ({bits}-bit, group_size={group_size})..."
)
_quantize_saved_model(output_dir, config, is_dual, bits, group_size)
print(f"\nConversion complete! Output: {output_dir}")
@@ -543,9 +558,16 @@ def _quantize_predicate(path: str, module) -> bool:
return False
# Quantize attention Q/K/V/O and FFN fc1/fc2
quantize_patterns = (
".self_attn.q", ".self_attn.k", ".self_attn.v", ".self_attn.o",
".cross_attn.q", ".cross_attn.k", ".cross_attn.v", ".cross_attn.o",
".ffn.fc1", ".ffn.fc2",
".self_attn.q",
".self_attn.k",
".self_attn.v",
".self_attn.o",
".cross_attn.q",
".cross_attn.k",
".cross_attn.v",
".cross_attn.o",
".ffn.fc1",
".ffn.fc2",
)
return any(path.endswith(p) for p in quantize_patterns)
@@ -684,14 +706,20 @@ def quantize_mlx_model(
# Build model config
from mlx_video.models.wan.config import WanModelConfig
config_dict = {k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__}
config_dict = {
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
}
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**config_dict)
# Copy non-transformer files to output dir (skip large model weights)
transformer_files = {"low_noise_model.safetensors", "high_noise_model.safetensors", "model.safetensors"}
transformer_files = {
"low_noise_model.safetensors",
"high_noise_model.safetensors",
"model.safetensors",
}
if dst.resolve() != src.resolve():
dst.mkdir(parents=True, exist_ok=True)
for f in src.iterdir():
@@ -763,11 +791,18 @@ if __name__ == "__main__":
if args.quantize_only:
quantize_mlx_model(
args.checkpoint_dir, args.output_dir,
bits=args.bits, group_size=args.group_size,
args.checkpoint_dir,
args.output_dir,
bits=args.bits,
group_size=args.group_size,
)
else:
convert_wan_checkpoint(
args.checkpoint_dir, args.output_dir, args.dtype, args.model_version,
quantize=args.quantize, bits=args.bits, group_size=args.group_size,
args.checkpoint_dir,
args.output_dir,
args.dtype,
args.model_version,
quantize=args.quantize,
bits=args.bits,
group_size=args.group_size,
)

View File

@@ -4,18 +4,15 @@ import argparse
import gc
import math
import random
import sys
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
from mlx_video.models.wan.i2v_utils import build_i2v_mask, preprocess_image
from mlx_video.models.wan.loading import (
_clean_text,
encode_text,
load_t5_encoder,
load_vae_decoder,
@@ -24,6 +21,7 @@ from mlx_video.models.wan.loading import (
)
from mlx_video.models.wan.postprocess import save_video
class Colors:
"""ANSI color codes for terminal output."""
@@ -37,6 +35,7 @@ class Colors:
DIM = "\033[2m"
RESET = "\033[0m"
# Backward-compat alias (tests and external code may use the old name)
_build_i2v_mask = build_i2v_mask
@@ -143,10 +142,13 @@ def generate_video(
for key in ("patch_size", "vae_stride", "window_size", "sample_guide_scale"):
if key in config_dict and isinstance(config_dict[key], list):
config_dict[key] = tuple(config_dict[key])
config = WanModelConfig(**{
k: v for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
})
config = WanModelConfig(
**{
k: v
for k, v in config_dict.items()
if k in WanModelConfig.__dataclass_fields__
}
)
else:
# Auto-detect: dual model files → 2.2, single model → 2.1
if (model_dir / "low_noise_model.safetensors").exists():
@@ -182,7 +184,9 @@ def generate_video(
if "patch_embedding_proj.weight" in k:
actual_dim = v.shape[0]
if actual_dim != config.dim:
print(f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}")
print(
f"{Colors.YELLOW} Config dim={config.dim} doesn't match weights dim={actual_dim}, auto-correcting...{Colors.RESET}"
)
if actual_dim <= 2048:
config = WanModelConfig.wan21_t2v_1_3b()
else:
@@ -192,13 +196,20 @@ def generate_video(
# Auto-correct Wan2.2 VAE params from stale configs
if config.in_dim == 48 and config.vae_z_dim != 48:
print(f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}")
config = WanModelConfig(**{
**{f.name: getattr(config, f.name) for f in config.__dataclass_fields__.values()},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
})
print(
f"{Colors.YELLOW} Auto-correcting Wan2.2 VAE params (in_dim=48 but vae_z_dim={config.vae_z_dim}){Colors.RESET}"
)
config = WanModelConfig(
**{
**{
f.name: getattr(config, f.name)
for f in config.__dataclass_fields__.values()
},
"vae_z_dim": 48,
"vae_stride": (4, 16, 16),
"sample_fps": 24,
}
)
# Apply defaults from config if not overridden
if steps is None:
@@ -227,7 +238,9 @@ def generate_video(
gen_frames = num_frames
if trim_first_frames > 0:
gen_frames = num_frames + trim_first_frames * 4
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
print(
f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}"
)
version_str = f"Wan{config.model_version}"
mode_str = "dual-model" if is_dual else "single-model"
@@ -247,10 +260,16 @@ def generate_video(
if is_i2v:
print(f" Image: {image}")
if neg_prompt_resolved and neg_prompt_resolved.strip():
neg_display = neg_prompt_resolved[:60] + "..." if len(neg_prompt_resolved) > 60 else neg_prompt_resolved
neg_display = (
neg_prompt_resolved[:60] + "..."
if len(neg_prompt_resolved) > 60
else neg_prompt_resolved
)
print(f" Neg prompt: {neg_display}")
print(f" Size: {width}x{height}, Frames: {num_frames}")
print(f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}")
print(
f" Steps: {steps}, Guide: {guide_scale}, Shift: {shift}, Solver: {scheduler}"
)
if cfg_disabled:
print(f" CFG: disabled (guide_scale≤1 → B=1 fast path, 2x denoising speedup)")
print(f"{Colors.RESET}")
@@ -275,12 +294,16 @@ def generate_video(
height = align_h
if width == 0:
width = align_w
print(f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
print(
f"{Colors.DIM} Aligned {old_w}x{old_h}{width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}"
)
# Enforce max_area constraint (model-specific resolution limit)
if config.max_area > 0 and height * width > config.max_area:
old_h, old_w = height, width
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
width, height = _best_output_size(
width, height, align_w, align_h, config.max_area
)
print(
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
@@ -309,6 +332,7 @@ def generate_video(
# Load tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
# Encode prompts
@@ -318,12 +342,15 @@ def generate_video(
context_null = None
mx.eval(context)
else:
context_null = encode_text(t5_encoder, tokenizer, neg_prompt_resolved, config.text_len)
context_null = encode_text(
t5_encoder, tokenizer, neg_prompt_resolved, config.text_len
)
mx.eval(context, context_null)
# Free T5 from memory
del t5_encoder
gc.collect(); mx.clear_cache()
gc.collect()
mx.clear_cache()
print(f"{Colors.DIM} T5 encoding: {time.time() - t1:.1f}s{Colors.RESET}")
# I2V: encode image to latent space
@@ -346,18 +373,25 @@ def generate_video(
img = Image.open(image).convert("RGB")
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
img = img.resize(
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
)
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
img_arr = mx.array(
np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0
) # [H, W, 3]
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
video = mx.concatenate([
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
], axis=1)
video = mx.concatenate(
[
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
],
axis=1,
)
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
vae_enc = load_vae_encoder(vae_path, config)
@@ -367,12 +401,17 @@ def generate_video(
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
msk = mx.concatenate(
[msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1
)
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
msk = mx.concatenate([
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
], axis=1)
msk = mx.concatenate(
[
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
],
axis=1,
)
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
@@ -395,13 +434,16 @@ def generate_video(
del vae_enc, img_tensor
gc.collect(); mx.clear_cache()
gc.collect()
mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
# Load transformer models
print(f"\n{Colors.BLUE}Loading transformer model(s)...{Colors.RESET}")
if quantization:
print(f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}")
print(
f"{Colors.DIM} Using {quantization['bits']}-bit quantized weights (group_size={quantization['group_size']}){Colors.RESET}"
)
t2 = time.time()
# Merge per-model LoRAs with shared LoRAs
@@ -412,10 +454,16 @@ def generate_video(
if is_dual:
low_noise_path = model_dir / "low_noise_model.safetensors"
high_noise_path = model_dir / "high_noise_model.safetensors"
low_noise_model = load_wan_model(low_noise_path, config, quantization, loras=_loras_low)
high_noise_model = load_wan_model(high_noise_path, config, quantization, loras=_loras_high)
low_noise_model = load_wan_model(
low_noise_path, config, quantization, loras=_loras_low
)
high_noise_model = load_wan_model(
high_noise_path, config, quantization, loras=_loras_high
)
else:
single_model = load_wan_model(model_dir / "model.safetensors", config, quantization, loras=_loras_single)
single_model = load_wan_model(
model_dir / "model.safetensors", config, quantization, loras=_loras_single
)
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step)
@@ -437,8 +485,12 @@ def generate_video(
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
context_cfg_low = mx.concatenate(
[context_emb_low[0:1], context_emb_low[1:2]], axis=0
)
context_cfg_high = mx.concatenate(
[context_emb_high[0:1], context_emb_high[1:2]], axis=0
)
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb)
@@ -534,7 +586,7 @@ def generate_video(
rcs = rope_cos_sin
# Use compiled forward when available (faster after first trace)
_call = getattr(model, '_compiled', model)
_call = getattr(model, "_compiled", model)
if cfg_disabled:
# No CFG: B=1 forward pass (2x faster than B=2 CFG batch)
@@ -552,7 +604,9 @@ def generate_video(
y_arg = [y_i2v] if is_i2v_channel_concat else None
if is_dual:
ctx = context_cond_high if timestep_val >= boundary else context_cond_low
ctx = (
context_cond_high if timestep_val >= boundary else context_cond_low
)
else:
ctx = context_cond
preds = _call(
@@ -571,7 +625,11 @@ def generate_video(
if is_dual:
gs = guide_scale[1] if timestep_val >= boundary else guide_scale[0]
else:
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
gs = (
guide_scale
if isinstance(guide_scale, (int, float))
else guide_scale[0]
)
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val
@@ -586,8 +644,10 @@ def generate_video(
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
ctx = (
context_cfg
if not is_dual
else (context_cfg_high if timestep_val >= boundary else context_cfg_low)
)
preds = _call(
[latents, latents],
@@ -618,16 +678,24 @@ def generate_video(
if debug_latents:
lat_np = np.array(latents) # [C, T, H, W]
n_t = lat_np.shape[1]
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
print(
f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}"
)
print(
f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}"
)
for t_pos in range(min(n_t, 8)):
frame = lat_np[:, t_pos, :, :]
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
print(
f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}"
)
if n_t > 8:
interior = lat_np[:, 4:, :, :]
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
print(
f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}"
)
print()
# Free transformer models and text embeddings
@@ -646,7 +714,8 @@ def generate_video(
del model, kv, context
if context_null is not None:
del context_null
gc.collect(); mx.clear_cache()
gc.collect()
mx.clear_cache()
# Load VAE and decode
print(f"\n{Colors.BLUE}Decoding with VAE...{Colors.RESET}")
@@ -677,13 +746,25 @@ def generate_video(
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
print(
f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}"
)
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
spatial_info = (
f"{tiling_config.spatial_config.tile_size_in_pixels}px"
if tiling_config.spatial_config
else "none"
)
temporal_info = (
f"{tiling_config.temporal_config.tile_size_in_frames}f"
if tiling_config.temporal_config
else "none"
)
print(
f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}"
)
if is_wan22_vae:
from mlx_video.models.wan.vae22 import denormalize_latents
@@ -718,7 +799,9 @@ def generate_video(
if trim_first_frames > 0:
trim_pixels = trim_first_frames * 4
video = video[trim_pixels:]
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
print(
f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}"
)
save_video(video, output_path, fps=config.sample_fps)
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
@@ -727,58 +810,124 @@ def generate_video(
def main():
parser = argparse.ArgumentParser(description="Wan Text-to-Video Generation (MLX)")
parser.add_argument("--model-dir", type=str, required=True, help="Path to converted MLX model directory")
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument("--image", type=str, default=None,
help="Path to input image for I2V (omit for T2V mode)")
parser.add_argument("--negative-prompt", type=str, default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)")
parser.add_argument("--no-negative-prompt", action="store_true",
help="Disable negative prompt (use empty string instead of config default)")
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
parser.add_argument("--shift", type=float, default=None, help="Noise schedule shift (default: from config)")
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument("--output-path", type=str, default="output.mp4", help="Output video path")
parser.add_argument(
"--scheduler", type=str, default="unipc",
"--model-dir",
type=str,
required=True,
help="Path to converted MLX model directory",
)
parser.add_argument("--prompt", type=str, required=True, help="Text prompt")
parser.add_argument(
"--image",
type=str,
default=None,
help="Path to input image for I2V (omit for T2V mode)",
)
parser.add_argument(
"--negative-prompt",
type=str,
default=None,
help="Negative prompt for CFG (default: official Chinese prompt from config)",
)
parser.add_argument(
"--no-negative-prompt",
action="store_true",
help="Disable negative prompt (use empty string instead of config default)",
)
parser.add_argument(
"--width", type=int, default=1280, help="Video width (default: 1280)"
)
parser.add_argument(
"--height",
type=int,
default=704,
help="Video height (default: 704; 720p models use 704)",
)
parser.add_argument(
"--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)"
)
parser.add_argument(
"--steps",
type=int,
default=None,
help="Number of diffusion steps (default: from config)",
)
parser.add_argument(
"--guide-scale",
type=str,
default=None,
help="Guidance scale: single float or low,high pair",
)
parser.add_argument(
"--shift",
type=float,
default=None,
help="Noise schedule shift (default: from config)",
)
parser.add_argument("--seed", type=int, default=-1, help="Random seed")
parser.add_argument(
"--output-path", type=str, default="output.mp4", help="Output video path"
)
parser.add_argument(
"--scheduler",
type=str,
default="unipc",
choices=["euler", "dpm++", "unipc"],
help="Diffusion solver: euler (1st order), dpm++ (2nd order), unipc (2nd order PC, default/official)",
)
parser.add_argument(
"--lora", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
"--lora",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to all models (repeatable). Format: --lora path.safetensors 0.8",
)
parser.add_argument(
"--lora-high", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
"--lora-high",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to high-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--lora-low", nargs=2, action="append", metavar=("PATH", "STRENGTH"),
"--lora-low",
nargs=2,
action="append",
metavar=("PATH", "STRENGTH"),
help="Apply a LoRA to low-noise model only (dual-model, repeatable)",
)
parser.add_argument(
"--tiling",
type=str,
default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
choices=[
"auto",
"none",
"default",
"aggressive",
"conservative",
"spatial",
"temporal",
],
help="VAE tiling mode to reduce memory during decoding (default: auto)",
)
parser.add_argument(
"--no-compile", action="store_true",
"--no-compile",
action="store_true",
help="Disable mx.compile on models (for debugging)",
)
parser.add_argument(
"--trim-first-frames", type=int, default=0, metavar="N",
"--trim-first-frames",
type=int,
default=0,
metavar="N",
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
"Default: 0 (disabled)",
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
"Default: 0 (disabled)",
)
parser.add_argument(
"--debug-latents", action="store_true",
"--debug-latents",
action="store_true",
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
)
args = parser.parse_args()

View File

@@ -21,7 +21,9 @@ def preprocess_image(image_path: str, width: int, height: int) -> mx.array:
# Resize so that the image covers the target size (LANCZOS)
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
img = img.resize(
(round(img.width * scale), round(img.height * scale)), Image.LANCZOS
)
# Center crop
x1 = (img.width - width) // 2

View File

@@ -6,7 +6,12 @@ import mlx.core as mx
import mlx.nn as nn
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
def load_wan_model(
model_path: Path,
config,
quantization: dict | None = None,
loras: list | None = None,
):
"""Load and initialize WanModel, with optional quantization and LoRA support.
Args:
@@ -93,9 +98,11 @@ def load_vae_decoder(model_path: Path, config=None):
if is_wan22:
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
vae = Wan22VAEDecoder(z_dim=48)
else:
from mlx_video.models.wan.vae import WanVAE
vae = WanVAE(z_dim=16)
weights = mx.load(str(model_path))
@@ -140,6 +147,7 @@ def _clean_text(text: str) -> str:
try:
import ftfy
text = ftfy.fix_text(text)
except ImportError:
pass

View File

@@ -1,4 +1,5 @@
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
@@ -37,7 +38,9 @@ class Head(nn.Module):
proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
"""
@@ -111,20 +114,23 @@ class WanModel(nn.Module):
# Reference computes three rope_params with different dim normalizations
# so each axis (temporal/height/width) gets its own full frequency range.
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
self.freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Precompute sinusoidal inv_freq for time embedding.
half = config.freq_dim // 2
self._inv_freq = mx.array(
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
).astype(np.float32)
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
np.float32
)
)
def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings.
@@ -297,12 +303,19 @@ class WanModel(nn.Module):
seq_lens_list.append(p.shape[1])
x = mx.concatenate(
[
mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
(
mx.concatenate(
[
p,
mx.zeros(
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
),
],
axis=1,
)
if p.shape[1] < seq_len
else p
)
if p.shape[1] < seq_len
else p
for p in patches
],
axis=0,
@@ -315,9 +328,7 @@ class WanModel(nn.Module):
t = t[None]
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
sin_emb = mx.concatenate(
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
)
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]

View File

@@ -1,6 +1,8 @@
import numpy as np
from pathlib import Path
import numpy as np
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""Save video frames to MP4.
@@ -11,6 +13,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
"""
try:
import imageio
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
for frame in frames:
writer.append_data(frame)
@@ -18,6 +21,7 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
except ImportError:
try:
import cv2
h, w = frames.shape[1], frames.shape[2]
fourcc = cv2.VideoWriter_fourcc(*"avc1")
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
@@ -27,9 +31,11 @@ def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
except (ImportError, Exception):
# Last resort: save as individual PNGs
from PIL import Image
out_dir = Path(output_path).parent / Path(output_path).stem
out_dir.mkdir(parents=True, exist_ok=True)
for i, frame in enumerate(frames):
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
print(
f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)"
)

View File

@@ -1,4 +1,3 @@
import math
import mlx.core as mx
import numpy as np
@@ -11,13 +10,16 @@ def rope_params(max_seq_len: int, dim: int, theta: float = 10000.0) -> mx.array:
Complex frequency tensor of shape [max_seq_len, dim // 2].
"""
assert dim % 2 == 0
freqs = np.arange(max_seq_len, dtype=np.float64)[:, None] * (
1.0
/ np.power(
theta,
np.arange(0, dim, 2, dtype=np.float64) / dim,
)
)[None, :]
freqs = (
np.arange(max_seq_len, dtype=np.float64)[:, None]
* (
1.0
/ np.power(
theta,
np.arange(0, dim, 2, dtype=np.float64) / dim,
)
)[None, :]
)
# Store as (cos, sin) pairs: shape [max_seq_len, dim // 2, 2]
cos_freqs = np.cos(freqs).astype(np.float32)
sin_freqs = np.sin(freqs).astype(np.float32)
@@ -46,9 +48,9 @@ def rope_apply(
# Check if all batch elements have the same grid (common for CFG B=2)
f0, h0, w0 = grid_sizes[0]
seq_len = f0 * h0 * w0
all_same_grid = all(
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
) if b > 1 else True
all_same_grid = (
all(grid_sizes[i] == grid_sizes[0] for i in range(1, b)) if b > 1 else True
)
if all_same_grid:
# Vectorized path: apply RoPE to all batch elements at once
@@ -57,7 +59,9 @@ def rope_apply(
x_imag = x_seq[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(
b, seq_len, n, d
)
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
return x_rotated
@@ -102,17 +106,11 @@ def rope_apply(
# Build per-position frequencies by expanding along grid dims
# temporal: [f,1,1,d_t,2] -> [f,h,w,d_t,2]
ft = mx.broadcast_to(
freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2)
)
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
# height: [1,h,1,d_h,2] -> [f,h,w,d_h,2]
fh = mx.broadcast_to(
freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2)
)
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
# width: [1,1,w,d_w,2] -> [f,h,w,d_w,2]
fw = mx.broadcast_to(
freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2)
)
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
# Concatenate: [f*h*w, half_d, 2]
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)

View File

@@ -7,9 +7,8 @@ for the same quality as Euler.
import math
import numpy as np
import mlx.core as mx
import numpy as np
def _compute_sigmas(
@@ -25,9 +24,7 @@ def _compute_sigmas(
Returns num_steps+1 values (the last being 0.0 for the terminal state).
"""
# sigma bounds from unshifted training schedule (constructor uses shift=1)
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
::-1
]
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1]
sigmas_unshifted = 1.0 - alphas
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
sigma_min = float(sigmas_unshifted[-1]) # 0.0
@@ -65,7 +62,10 @@ class FlowMatchEulerScheduler:
sample: mx.array,
) -> mx.array:
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
dt = (
self._sigmas_float[self._step_index + 1]
- self._sigmas_float[self._step_index]
)
x_next = sample + dt * model_output
self._step_index += 1
return x_next
@@ -139,13 +139,8 @@ class FlowDPMPP2MScheduler:
# Decide order: 1st for first step, last step (if lower_order_final
# and few steps), otherwise 2nd
use_first_order = (
self._prev_x0 is None
or (
self.lower_order_final
and i == self._num_steps - 1
and self._num_steps < 15
)
use_first_order = self._prev_x0 is None or (
self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15
)
if use_first_order or sigma_next == 0.0:

View File

@@ -49,20 +49,19 @@ class T5RelativeEmbedding(nn.Module):
is_small = rel_pos < max_exact
rel_pos_f = rel_pos.astype(mx.float32)
rel_pos_large = (
max_exact
+ (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
)
rel_pos_large = max_exact + (
mx.log(rel_pos_f / max_exact)
/ math.log(self.max_dist / max_exact)
* (num_buckets - max_exact)
).astype(mx.int32)
rel_pos_large = mx.minimum(
rel_pos_large,
mx.full(rel_pos_large.shape, num_buckets - 1, dtype=mx.int32),
)
rel_buckets = rel_buckets + mx.where(is_small, rel_pos.astype(mx.int32), rel_pos_large)
rel_buckets = rel_buckets + mx.where(
is_small, rel_pos.astype(mx.int32), rel_pos_large
)
return rel_buckets
def __call__(self, lq: int, lk: int) -> mx.array:
@@ -115,7 +114,7 @@ class T5Attention(nn.Module):
v = v.transpose(0, 2, 1, 3)
# QK^T (no scaling) — compute in float32 for precision
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
attn = q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2)
# Add position bias
if pos_bias is not None:

View File

@@ -75,7 +75,11 @@ def decode_with_tiling(
b, c, f_latent, h_latent, w_latent = latents.shape
# Compute output shape
out_f = (1 + (f_latent - 1) * temporal_scale) if causal_temporal else (f_latent * temporal_scale)
out_f = (
(1 + (f_latent - 1) * temporal_scale)
if causal_temporal
else (f_latent * temporal_scale)
)
out_h = h_latent * spatial_scale
out_w = w_latent * spatial_scale
@@ -98,9 +102,13 @@ def decode_with_tiling(
# Compute intervals for each dimension
if causal_temporal:
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_temporal(
temporal_tile_size, temporal_overlap, f_latent
)
else:
temporal_intervals = split_in_spatial(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_spatial(
temporal_tile_size, temporal_overlap, f_latent
)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -124,9 +132,13 @@ def decode_with_tiling(
# Map temporal coordinates
if causal_temporal:
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_temporal_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
else:
out_t_slice, t_mask = map_spatial_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_spatial_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -135,7 +147,9 @@ def decode_with_tiling(
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
out_h_slice, h_mask = map_spatial_slice(
h_start, h_end, h_left, h_right, spatial_scale
)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
@@ -144,13 +158,23 @@ def decode_with_tiling(
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
out_w_slice, w_mask = map_spatial_slice(
w_start, w_end, w_left, w_right, spatial_scale
)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
tile_latents = latents[
:, :, t_start:t_end, h_start:h_end, w_start:w_end
]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
tile_output = decoder_fn(
tile_latents,
causal=causal,
timestep=timestep,
debug=False,
chunked_conv=chunked_conv,
)
mx.eval(tile_output)
# Clear tile_latents reference
@@ -173,13 +197,15 @@ def decode_with_tiling(
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
t_mask_slice.reshape(1, 1, -1, 1, 1)
* h_mask_slice.reshape(1, 1, 1, -1, 1)
* w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
tile_output_slice = tile_output[
:, :, :actual_t, :actual_h, :actual_w
].astype(mx.float32)
# Clear full tile_output
del tile_output
@@ -196,11 +222,37 @@ def decode_with_tiling(
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ blend_mask
)
# Force evaluation to free memory
@@ -232,12 +284,14 @@ def decode_with_tiling(
if next_tile_start_latent == 0:
next_tile_start_out = 0
elif causal_temporal:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
next_tile_start_out = (
1 + (next_tile_start_latent - 1) * temporal_scale
)
else:
next_tile_start_out = next_tile_start_latent * temporal_scale
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
if not hasattr(decode_with_tiling, "_emitted_frames"):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
@@ -245,7 +299,10 @@ def decode_with_tiling(
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = (
output[:, :, emitted:next_tile_start_out, :, :]
/ finalized_weights
)
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
@@ -262,7 +319,7 @@ def decode_with_tiling(
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
@@ -270,7 +327,7 @@ def decode_with_tiling(
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
if hasattr(decode_with_tiling, "_emitted_frames"):
del decode_with_tiling._emitted_frames
# Clean up weights

View File

@@ -25,9 +25,7 @@ class WanAttentionBlock(nn.Module):
# Cross-attention (with optional norm on context)
self.norm3 = (
WanLayerNorm(dim, eps, elementwise_affine=True)
if cross_attn_norm
else None
WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else None
)
self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
@@ -36,7 +34,9 @@ class WanAttentionBlock(nn.Module):
self.ffn = WanFFN(dim, ffn_dim)
# Learned modulation: 6 vectors for scale/shift/gate (kept in float32 for precision)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(mx.float32)
self.modulation = (mx.random.normal((1, 6, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__(
self,
@@ -67,7 +67,14 @@ class WanAttentionBlock(nn.Module):
# Self-attention with modulation (hidden state stays in w_dtype)
x_mod = self.norm1(x) * (1 + e1) + e0
y = self.self_attn(x_mod, seq_lens, grid_sizes, freqs, rope_cos_sin=rope_cos_sin, attn_mask=attn_mask)
y = self.self_attn(
x_mod,
seq_lens,
grid_sizes,
freqs,
rope_cos_sin=rope_cos_sin,
attn_mask=attn_mask,
)
x = x + y * e2
# Cross-attention (no modulation, just norm)

View File

@@ -6,19 +6,45 @@ so weights load directly without key sanitization.
import mlx.core as mx
import mlx.nn as nn
import numpy as np
CACHE_T = 2
# Per-channel normalization statistics for z_dim=16
VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
]
VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
]
@@ -50,7 +76,9 @@ class CausalConv3d(nn.Module):
self._pad_w = padding[2]
# MLX Conv3d: weight shape [O, D, H, W, I]
self.weight = mx.zeros((out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels))
self.weight = mx.zeros(
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
)
self.bias = mx.zeros((out_channels,))
def __call__(self, x: mx.array, cache_x: mx.array = None) -> mx.array:
@@ -67,8 +95,16 @@ class CausalConv3d(nn.Module):
x = mx.concatenate([pad_t, x], axis=2)
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (0, 0),
(self._pad_h, self._pad_h), (self._pad_w, self._pad_w)])
x = mx.pad(
x,
[
(0, 0),
(0, 0),
(0, 0),
(self._pad_h, self._pad_h),
(self._pad_w, self._pad_w),
],
)
x = x.transpose(0, 2, 3, 4, 1) # [B, T, H, W, C]
out = self._conv3d(x)
@@ -118,7 +154,11 @@ class RMS_norm(nn.Module):
def __call__(self, x: mx.array) -> mx.array:
norm_dim = 1 if self.channel_first else -1
# L2 normalize along channel dim (matches F.normalize)
norm = mx.sqrt(mx.clip(mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None))
norm = mx.sqrt(
mx.clip(
mx.sum(x * x, axis=norm_dim, keepdims=True), a_min=1e-12, a_max=None
)
)
return (x / norm) * self.scale * self.gamma
@@ -133,12 +173,12 @@ class ResidualBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.residual = [
RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU
RMS_norm(in_dim, images=False), # [0]
None, # [1] SiLU
CausalConv3d(in_dim, out_dim, 3, padding=1), # [2]
RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU
None, # [5] Dropout
RMS_norm(out_dim, images=False), # [3]
None, # [4] SiLU
None, # [5] Dropout
CausalConv3d(out_dim, out_dim, 3, padding=1), # [6]
]
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
@@ -226,13 +266,16 @@ class Resample(nn.Module):
# resample.0 = Upsample (no params), resample.1 = Conv2d
self.resample = [None, nn.Conv2d(dim, dim // 2, 3, padding=1)]
if mode == "upsample3d":
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)
)
else:
# resample.0 = ZeroPad2d (no params), resample.1 = Conv2d(stride=2)
self.resample = [None, nn.Conv2d(dim, dim, 3, stride=2)]
if mode == "downsample3d":
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
def __call__(self, x: mx.array, feat_cache=None, feat_idx=None) -> mx.array:
"""x: [B, C, T, H, W]"""
@@ -272,8 +315,7 @@ class Resample(nn.Module):
else:
# Subsequent chunks: use cached frame as temporal context
cache_x = x[:, :, -1:]
x = self.time_conv(
x, cache_x=feat_cache[idx][:, :, -1:])
x = self.time_conv(x, cache_x=feat_cache[idx][:, :, -1:])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
@@ -328,8 +370,8 @@ class Decoder3d(nn.Module):
# Output head: [RMS_norm, SiLU (no params), CausalConv3d]
self.head = [
RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU
RMS_norm(dims[-1], images=False), # [0]
None, # [1] SiLU
CausalConv3d(dims[-1], 3, 3, padding=1), # [2]
]
@@ -405,8 +447,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -431,8 +472,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:]
if cache_x.shape[2] < CACHE_T and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, :, -1:], cache_x], axis=2)
cache_x = mx.concatenate([feat_cache[idx][:, :, -1:], cache_x], axis=2)
x = self.head[2](x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -583,7 +623,7 @@ class WanVAE(nn.Module):
decoder_fn=tile_decode,
latents=z_denorm,
tiling_config=tiling_config,
spatial_scale=8, # 3× spatial 2× upsamples = 8×
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
spatial_scale=8, # 3× spatial 2× upsamples = 8×
temporal_scale=4, # 2× temporal upsamples × 2 = 4×
causal_temporal=False, # Wan2.1 uses non-causal temporal (T → 4T)
)

View File

@@ -8,7 +8,6 @@ conversion (channels-first → channels-last) is needed.
"""
import logging
import math
import mlx.core as mx
import mlx.nn as nn
@@ -19,23 +18,111 @@ logger = logging.getLogger(__name__)
CACHE_T = 2
# Per-channel normalization for z_dim=48 latent space
VAE22_MEAN = mx.array([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
])
VAE22_MEAN = mx.array(
[
-0.2289,
-0.0052,
-0.1323,
-0.2339,
-0.2799,
0.0174,
0.1838,
0.1557,
-0.1382,
0.0542,
0.2813,
0.0891,
0.1570,
-0.0098,
0.0375,
-0.1825,
-0.2246,
-0.1207,
-0.0698,
0.5109,
0.2665,
-0.2108,
-0.2158,
0.2502,
-0.2055,
-0.0322,
0.1109,
0.1567,
-0.0729,
0.0899,
-0.2799,
-0.1230,
-0.0313,
-0.1649,
0.0117,
0.0723,
-0.2839,
-0.2083,
-0.0520,
0.3748,
0.0152,
0.1957,
0.1433,
-0.2944,
0.3573,
-0.0548,
-0.1681,
-0.0667,
]
)
VAE22_STD = mx.array([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
])
VAE22_STD = mx.array(
[
0.4765,
1.0364,
0.4514,
1.1677,
0.5313,
0.4990,
0.4818,
0.5013,
0.8158,
1.0344,
0.5894,
1.0901,
0.6885,
0.6165,
0.8454,
0.4978,
0.5759,
0.3523,
0.7135,
0.6804,
0.5833,
1.4146,
0.8986,
0.5659,
0.7069,
0.5338,
0.4889,
0.4917,
0.4069,
0.4999,
0.6866,
0.4093,
0.5709,
0.6065,
0.6415,
0.4944,
0.5726,
1.2042,
0.5458,
1.6887,
0.3971,
1.0600,
0.3943,
0.5537,
0.5444,
0.4089,
0.7468,
0.7744,
]
)
class CausalConv3d(nn.Module):
@@ -65,9 +152,9 @@ class CausalConv3d(nn.Module):
self._pad_w = padding[2]
# Weight: [O, D, H, W, I] for MLX
self.weight = mx.zeros((
out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels
))
self.weight = mx.zeros(
(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels)
)
self.bias = mx.zeros((out_channels,))
def __call__(self, x, cache_x=None):
@@ -96,8 +183,16 @@ class CausalConv3d(nn.Module):
# Spatial padding
if self._pad_h > 0 or self._pad_w > 0:
x = mx.pad(x, [(0, 0), (0, 0), (self._pad_h, self._pad_h),
(self._pad_w, self._pad_w), (0, 0)])
x = mx.pad(
x,
[
(0, 0),
(0, 0),
(self._pad_h, self._pad_h),
(self._pad_w, self._pad_w),
(0, 0),
],
)
T_padded = x.shape[1]
H_padded, W_padded = x.shape[2], x.shape[3]
@@ -113,8 +208,9 @@ class CausalConv3d(nn.Module):
for d in range(kd):
frame = x[:, t_start + d] # [B, H_padded, W_padded, C]
w2d = self.weight[:, d, :, :, :] # [O, kh, kw, I]
conv_out = mx.conv_general(frame, w2d,
stride=(self.stride[1], self.stride[2]))
conv_out = mx.conv_general(
frame, w2d, stride=(self.stride[1], self.stride[2])
)
accum = conv_out if accum is None else accum + conv_out
outputs.append(accum + self.bias)
@@ -126,7 +222,7 @@ class RMS_norm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.scale = dim**0.5
# Weight stored as (dim,) — PyTorch stores (dim, 1, 1, 1) but we squeeze
self.gamma = mx.ones((dim,))
@@ -134,7 +230,9 @@ class RMS_norm(nn.Module):
# x: [..., C] (channels-last)
# PyTorch uses F.normalize (L2 norm), not RMS: x / max(||x||_2, eps)
l2_sq = mx.sum(x * x, axis=-1, keepdims=True)
return x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
return (
x * mx.rsqrt(mx.maximum(l2_sq, mx.array(1e-24))) * self.scale * self.gamma
)
class ResidualBlock(nn.Module):
@@ -145,11 +243,7 @@ class ResidualBlock(nn.Module):
# Sequential residual path: [norm, silu, conv3d, norm, silu, dropout, conv3d]
# We store as named layers matching PyTorch's indices
self.residual = ResidualBlockLayers(in_dim, out_dim)
self.shortcut = (
CausalConv3d(in_dim, out_dim, 1)
if in_dim != out_dim
else None
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else None
def __call__(self, x, feat_cache=None, feat_idx=None):
h = self.shortcut(x) if self.shortcut is not None else x
@@ -182,9 +276,7 @@ class ResidualBlockLayers(nn.Module):
# Save last CACHE_T frames before conv (for next chunk's context)
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
out = conv(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -231,7 +323,9 @@ class AttentionBlock(nn.Module):
x = self.norm(x)
# QKV via 1x1 conv2d (equivalent to linear on last dim)
qkv = mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias # [BT, H, W, 3C]
qkv = (
mx.conv_general(x, self.to_qkv_weight) + self.to_qkv_bias
) # [BT, H, W, 3C]
qkv = qkv.reshape(B * T, H * W, 3 * C)
q, k, v = mx.split(qkv, 3, axis=-1) # each [BT, HW, C]
@@ -240,8 +334,10 @@ class AttentionBlock(nn.Module):
k = k[:, None, :, :]
v = v[:, None, :, :]
scale = C ** -0.5
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) # [BT, 1, HW, C]
scale = C**-0.5
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=scale
) # [BT, 1, HW, C]
out = out.squeeze(1).reshape(B * T, H, W, C)
# Project output
@@ -270,16 +366,24 @@ class DupUp3D(nn.Module):
x = mx.repeat(x, self.repeats, axis=-1) # [B, T, H, W, C*repeats]
# Reshape to [B, T, H, W, out_C, factor_t, factor_s, factor_s]
x = x.reshape(B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s)
x = x.reshape(
B, T, H, W, self.out_channels, self.factor_t, self.factor_s, self.factor_s
)
# Permute to interleave: [B, T, factor_t, H, factor_s, W, factor_s, out_C]
x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)
# Reshape to final: [B, T*factor_t, H*factor_s, W*factor_s, out_C]
x = x.reshape(B, T * self.factor_t, H * self.factor_s, W * self.factor_s, self.out_channels)
x = x.reshape(
B,
T * self.factor_t,
H * self.factor_s,
W * self.factor_s,
self.out_channels,
)
if first_chunk:
x = x[:, self.factor_t - 1:, :, :, :]
x = x[:, self.factor_t - 1 :, :, :, :]
return x
@@ -348,7 +452,9 @@ class Resample(nn.Module):
self.resample_weight = mx.zeros((dim, 3, 3, dim))
self.resample_bias = mx.zeros((dim,))
# time_conv: CausalConv3d(dim, dim, (3,1,1), stride=(2,1,1))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
)
else:
raise ValueError(f"Unsupported mode: {mode}")
@@ -369,7 +475,9 @@ class Resample(nn.Module):
"""Apply strided Conv2d for downsampling. x: [N, H, W, C]."""
# ZeroPad2d((0,1,0,1)): pad right=1, bottom=1
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
return mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
return (
mx.conv_general(x, self.resample_weight, stride=(2, 2)) + self.resample_bias
)
def __call__(self, x, first_chunk=False, feat_cache=None, feat_idx=None):
# x: [B, T, H, W, C]
@@ -444,14 +552,17 @@ class Resample(nn.Module):
class Up_ResidualBlock(nn.Module):
"""Upsampling residual block with optional DupUp3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False):
def __init__(
self, in_dim, out_dim, num_res_blocks, temperal_upsample=False, up_flag=False
):
super().__init__()
self.up_flag = up_flag
# DupUp3D shortcut (no learnable params)
if up_flag:
self.avg_shortcut = DupUp3D(
in_dim, out_dim,
in_dim,
out_dim,
factor_t=2 if temperal_upsample else 1,
factor_s=2 if up_flag else 1,
)
@@ -490,13 +601,21 @@ class Up_ResidualBlock(nn.Module):
class Down_ResidualBlock(nn.Module):
"""Downsampling residual block with AvgDown3D shortcut."""
def __init__(self, in_dim, out_dim, num_res_blocks, temperal_downsample=False, down_flag=False):
def __init__(
self,
in_dim,
out_dim,
num_res_blocks,
temperal_downsample=False,
down_flag=False,
):
super().__init__()
self.down_flag = down_flag
# AvgDown3D shortcut (no learnable params, always present)
self.avg_shortcut = AvgDown3D(
in_dim, out_dim,
in_dim,
out_dim,
factor_t=2 if temperal_downsample else 1,
factor_s=2 if down_flag else 1,
)
@@ -562,13 +681,15 @@ class Decoder3d(nn.Module):
self.upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
t_up = temperal_upsample[i] if i < len(temperal_upsample) else False
self.upsamples.append(Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks + 1,
temperal_upsample=t_up,
up_flag=(i != len(dim_mult) - 1),
))
self.upsamples.append(
Up_ResidualBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks + 1,
temperal_upsample=t_up,
up_flag=(i != len(dim_mult) - 1),
)
)
# Output head: [RMS_norm, SiLU, CausalConv3d]
self.head = Head22(dims[-1])
@@ -612,13 +733,15 @@ class Encoder3d(nn.Module):
for i in range(len(dim_mult)):
in_d, out_d = dims[i], dims[i + 1]
t_down = temperal_downsample[i] if i < len(temperal_downsample) else False
self.downsamples.append(Down_ResidualBlock(
in_dim=in_d,
out_dim=out_d,
num_res_blocks=num_res_blocks,
temperal_downsample=t_down,
down_flag=(i < len(dim_mult) - 1),
))
self.downsamples.append(
Down_ResidualBlock(
in_dim=in_d,
out_dim=out_d,
num_res_blocks=num_res_blocks,
temperal_downsample=t_down,
down_flag=(i < len(dim_mult) - 1),
)
)
# Middle blocks (same as decoder)
out_dim = dims[-1]
@@ -658,9 +781,7 @@ class Encoder3d(nn.Module):
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
x = self.conv1(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -700,9 +821,7 @@ class Head22(nn.Module):
idx = feat_idx[0]
cache_x = x[:, -CACHE_T:]
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
cache_x = mx.concatenate(
[feat_cache[idx][:, -1:], cache_x], axis=1
)
cache_x = mx.concatenate([feat_cache[idx][:, -1:], cache_x], axis=1)
x = self.layer_2(x, cache_x=feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -768,7 +887,7 @@ class Wan22VAEEncoder(nn.Module):
if i == 0:
chunk = x[:, :1]
else:
chunk = x[:, 1 + 4 * (i - 1):1 + 4 * i]
chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i]
chunk_out = self.encoder(chunk, feat_cache=feat_cache, feat_idx=feat_idx)
if out is None:
out = chunk_out
@@ -778,7 +897,7 @@ class Wan22VAEEncoder(nn.Module):
# conv1 (pointwise) + split into mu, log_var
out = self.conv1(out)
mu = out[:, :, :, :, :self.z_dim]
mu = out[:, :, :, :, : self.z_dim]
# Normalize
mu = normalize_latents(mu)
@@ -885,8 +1004,8 @@ class Wan22VAEDecoder(nn.Module):
decoder_fn=tile_decode,
latents=z_cf,
tiling_config=tiling_config,
spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
spatial_scale=16, # 8× conv upsample + 2× unpatchify
temporal_scale=4, # two 2× temporal upsamples (first_chunk=True → causal)
causal_temporal=True,
)