Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -1,11 +1,59 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights from mlx_video.convert import (
import os load_transformer_weights,
load_vae_weights,
load_audio_vae_weights,
load_vocoder_weights,
sanitize_audio_vae_weights,
sanitize_vocoder_weights,
)
# Audio VAE components
from mlx_video.models.ltx.audio_vae import (
AudioEncoder,
AudioDecoder,
Vocoder,
AudioProcessor,
decode_audio,
)
# Patchifiers
from mlx_video.components.patchifiers import (
VideoLatentPatchifier,
AudioPatchifier,
VideoLatentShape,
AudioLatentShape,
)
# Conditioning
from mlx_video.conditioning import (
VideoConditionByKeyframeIndex,
VideoConditionByLatentIndex,
)
__all__ = [ __all__ = [
# Models
"LTXModel", "LTXModel",
"LTXModelConfig", "LTXModelConfig",
# Weight loading
"load_transformer_weights", "load_transformer_weights",
"load_vae_weights", "load_vae_weights",
] "load_audio_vae_weights",
"load_vocoder_weights",
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1" "sanitize_audio_vae_weights",
"sanitize_vocoder_weights",
# Audio VAE
"AudioEncoder",
"AudioDecoder",
"Vocoder",
"AudioProcessor",
"decode_audio",
# Patchifiers
"VideoLatentPatchifier",
"AudioPatchifier",
"VideoLatentShape",
"AudioLatentShape",
# Conditioning
"VideoConditionByKeyframeIndex",
"VideoConditionByLatentIndex",
]

View File

@@ -355,6 +355,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr
""" """
sanitized = {} sanitized = {}
if "audio_vae." in weights:
return weights
for key, value in weights.items(): for key, value in weights.items():
new_key = key new_key = key
@@ -364,9 +367,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr
elif key.startswith("audio_vae.per_channel_statistics."): elif key.startswith("audio_vae.per_channel_statistics."):
# Map per-channel statistics # Map per-channel statistics
if "mean-of-means" in key: if "mean-of-means" in key:
new_key = "per_channel_statistics._mean_of_means" new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key: elif "std-of-means" in key:
new_key = "per_channel_statistics._std_of_means" new_key = "per_channel_statistics.std_of_means"
else: else:
continue # Skip other statistics keys continue # Skip other statistics keys
else: else:

View File

@@ -3,7 +3,7 @@
from .attention import AttentionType, AttnBlock, make_attn from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, decode_audio from .audio_vae import AudioDecoder, decode_audio
from .causal_conv_2d import CausalConv2d, make_conv2d from .causal_conv_2d import CausalConv2d, make_conv2d
from .causality_axis import CausalityAxis from ..config import CausalityAxis
from .downsample import Downsample, build_downsampling_path from .downsample import Downsample, build_downsampling_path
from .normalization import NormType, PixelNorm, build_normalization_layer from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics

View File

@@ -1,14 +1,15 @@
"""Audio VAE encoder and decoder for LTX-2.""" """Audio VAE encoder and decoder for LTX-2."""
from typing import Set, Tuple from typing import Dict
from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig
from .attention import AttentionType, make_attn from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis from ..config import CausalityAxis
from .downsample import build_downsampling_path
from .normalization import NormType, build_normalization_layer from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock from .resnet import ResnetBlock
@@ -67,22 +68,7 @@ class AudioDecoder(nn.Module):
def __init__( def __init__(
self, self,
*, config: AudioDecoderModelConfig,
ch: int = 128,
out_ch: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Set[int] = None,
resolution: int = 256,
z_channels: int = 8,
norm_type: NormType = NormType.PIXEL,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
dropout: float = 0.0,
mid_block_add_attention: bool = True,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: int | None = None,
) -> None: ) -> None:
""" """
Initialize the AudioDecoder. Initialize the AudioDecoder.
@@ -105,86 +91,132 @@ class AudioDecoder(nn.Module):
""" """
super().__init__() super().__init__()
if attn_resolutions is None:
attn_resolutions = {8, 16, 32}
# Internal behavioral defaults
resamp_with_conv = True
attn_type = AttentionType.VANILLA
# Per-channel statistics for denormalizing latents # Per-channel statistics for denormalizing latents
# Uses ch (base channel count) to match the patchified latent dimension # Uses ch (base channel count) to match the patchified latent dimension
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16) # Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128) # After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
# ch=128 matches this dimension, so use ch for per_channel_statistics # ch=128 matches this dimension, so use ch for per_channel_statistics
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
self.sample_rate = sample_rate self.sample_rate = config.sample_rate
self.mel_hop_length = mel_hop_length self.mel_hop_length = config.mel_hop_length
self.is_causal = is_causal self.is_causal = config.is_causal
self.mel_bins = mel_bins self.mel_bins = config.mel_bins
self.patchifier = AudioPatchifier( self.patchifier = AudioPatchifier(
patch_size=1, patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate, sample_rate=config.sample_rate,
hop_length=mel_hop_length, hop_length=config.mel_hop_length,
is_causal=is_causal, is_causal=config.is_causal,
) )
self.ch = ch self.ch = config.ch
self.temb_ch = 0 self.temb_ch = 0
self.num_resolutions = len(ch_mult) self.num_resolutions = len(config.ch_mult)
self.num_res_blocks = num_res_blocks self.num_res_blocks = config.num_res_blocks
self.resolution = resolution self.resolution = config.resolution
self.out_ch = out_ch self.out_ch = config.out_ch
self.give_pre_end = False self.give_pre_end = config.give_pre_end
self.tanh_out = False self.tanh_out = config.tanh_out
self.norm_type = norm_type self.norm_type = config.norm_type
self.z_channels = z_channels self.z_channels = config.z_channels
self.channel_multipliers = ch_mult self.channel_multipliers = config.ch_mult
self.attn_resolutions = attn_resolutions self.attn_resolutions = config.attn_resolutions
self.causality_axis = causality_axis self.causality_axis = config.causality_axis
self.attn_type = attn_type self.attn_type = config.attn_type
base_block_channels = ch * self.channel_multipliers[-1] base_block_channels = config.ch * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1)) base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, z_channels, base_resolution, base_resolution) self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d( self.conv_in = make_conv2d(
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
) )
self.mid = build_mid_block( self.mid = build_mid_block(
channels=base_block_channels, channels=base_block_channels,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout, dropout=config.dropout,
norm_type=self.norm_type, norm_type=self.norm_type,
causality_axis=self.causality_axis, causality_axis=self.causality_axis,
attn_type=self.attn_type, attn_type=self.attn_type,
add_attention=mid_block_add_attention, add_attention=config.mid_block_add_attention,
) )
self.up, final_block_channels = build_upsampling_path( self.up, final_block_channels = build_upsampling_path(
ch=ch, ch=config.ch,
ch_mult=ch_mult, ch_mult=config.ch_mult,
num_resolutions=self.num_resolutions, num_resolutions=self.num_resolutions,
num_res_blocks=num_res_blocks, num_res_blocks=config.num_res_blocks,
resolution=resolution, resolution=config.resolution,
temb_channels=self.temb_ch, temb_channels=self.temb_ch,
dropout=dropout, dropout=config.dropout,
norm_type=self.norm_type, norm_type=self.norm_type,
causality_axis=self.causality_axis, causality_axis=self.causality_axis,
attn_type=self.attn_type, attn_type=self.attn_type,
attn_resolutions=attn_resolutions, attn_resolutions=config.attn_resolutions,
resamp_with_conv=resamp_with_conv, resamp_with_conv=config.resamp_with_conv,
initial_block_channels=base_block_channels, initial_block_channels=base_block_channels,
) )
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
self.conv_out = make_conv2d( self.conv_out = make_conv2d(
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
) )
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for audio VAE decoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle audio_vae.decoder weights
if key.startswith("audio_vae.decoder."):
new_key = key.replace("audio_vae.decoder.", "")
elif key.startswith("audio_vae.per_channel_statistics."):
# Map per-channel statistics
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue # Skip other statistics keys
else:
continue # Skip non-decoder keys
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
"""Load audio VAE decoder from pretrained model."""
from mlx_video.models.ltx.config import AudioDecoderModelConfig
import json
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
decoder = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
# weights = decoder.sanitize(weights)
decoder.load_weights(list(weights.items()), strict=True)
return decoder
def __call__(self, sample: mx.array) -> mx.array: def __call__(self, sample: mx.array) -> mx.array:
""" """
Decode latent features back to audio spectrograms. Decode latent features back to audio spectrograms.

View File

@@ -5,7 +5,7 @@ from typing import Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .causality_axis import CausalityAxis from ..config import CausalityAxis
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]: def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:

View File

@@ -1,12 +0,0 @@
"""Causality axis enum for specifying causal convolution dimensions."""
from enum import Enum
class CausalityAxis(Enum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"

View File

@@ -6,7 +6,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .attention import AttentionType, make_attn from .attention import AttentionType, make_attn
from .causality_axis import CausalityAxis from ..config import CausalityAxis
from .normalization import NormType from .normalization import NormType
from .resnet import ResnetBlock from .resnet import ResnetBlock

View File

@@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module):
self.latent_channels = latent_channels self.latent_channels = latent_channels
# Initialize buffers - will be loaded from weights # Initialize buffers - will be loaded from weights
# Using underscores for MLX compatibility with weight loading # Using underscores for MLX compatibility with weight loading
self._std_of_means = mx.ones((latent_channels,)) self.std_of_means = mx.ones((latent_channels,))
self._mean_of_means = mx.zeros((latent_channels,)) self.mean_of_means = mx.zeros((latent_channels,))
def un_normalize(self, x: mx.array) -> mx.array: def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latent representation.""" """Denormalize latent representation."""
# Broadcast statistics to match x shape # Broadcast statistics to match x shape
# x shape: (B, C, ...) or (B, ..., C) # x shape: (B, C, ...) or (B, ..., C)
std = self._std_of_means.astype(x.dtype) std = self.std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype) mean = self.mean_of_means.astype(x.dtype)
return (x * std) + mean return (x * std) + mean
def normalize(self, x: mx.array) -> mx.array: def normalize(self, x: mx.array) -> mx.array:
"""Normalize latent representation.""" """Normalize latent representation."""
std = self._std_of_means.astype(x.dtype) std = self.std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype) mean = self.mean_of_means.astype(x.dtype)
return (x - mean) / std return (x - mean) / std

View File

@@ -6,7 +6,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .causal_conv_2d import make_conv2d from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis from ..config import CausalityAxis
from .normalization import NormType, build_normalization_layer from .normalization import NormType, build_normalization_layer
LRELU_SLOPE = 0.1 LRELU_SLOPE = 0.1

View File

@@ -7,7 +7,7 @@ import mlx.nn as nn
from .attention import AttentionType, make_attn from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis from ..config import CausalityAxis
from .normalization import NormType from .normalization import NormType
from .resnet import ResnetBlock from .resnet import ResnetBlock

View File

@@ -1,11 +1,12 @@
"""Vocoder for converting mel spectrograms to audio waveforms.""" """Vocoder for converting mel spectrograms to audio waveforms."""
import math import math
from typing import List from typing import Dict
from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import VocoderModelConfig
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
@@ -27,44 +28,29 @@ class Vocoder(nn.Module):
def __init__( def __init__(
self, self,
resblock_kernel_sizes: List[int] | None = None, config: VocoderModelConfig
upsample_rates: List[int] | None = None,
upsample_kernel_sizes: List[int] | None = None,
resblock_dilation_sizes: List[List[int]] | None = None,
upsample_initial_channel: int = 1024,
stereo: bool = True,
resblock: str = "1",
output_sample_rate: int = 24000,
): ):
super().__init__() super().__init__()
# Initialize default values if not provided
if resblock_kernel_sizes is None:
resblock_kernel_sizes = [3, 7, 11]
if upsample_rates is None:
upsample_rates = [6, 5, 2, 2, 2]
if upsample_kernel_sizes is None:
upsample_kernel_sizes = [16, 15, 8, 4, 4]
if resblock_dilation_sizes is None:
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
self.output_sample_rate = output_sample_rate
self.num_kernels = len(resblock_kernel_sizes) self.output_sample_rate = config.output_sample_rate
self.num_upsamples = len(upsample_rates) self.num_kernels = len(config.resblock_kernel_sizes)
self.upsample_rates = upsample_rates self.num_upsamples = len(config.upsample_rates)
self.upsample_kernel_sizes = upsample_kernel_sizes self.upsample_rates = config.upsample_rates
self.upsample_initial_channel = upsample_initial_channel self.upsample_kernel_sizes = config.upsample_kernel_sizes
self.upsample_initial_channel = config.upsample_initial_channel
in_channels = 128 if stereo else 64 in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3) self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2 resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
# Upsampling layers using ConvTranspose1d # Upsampling layers using ConvTranspose1d
self.ups = {} self.ups = {}
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
in_ch = upsample_initial_channel // (2**i) in_ch = config.upsample_initial_channel // (2**i)
out_ch = upsample_initial_channel // (2 ** (i + 1)) out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d( self.ups[i] = nn.ConvTranspose1d(
in_ch, in_ch,
out_ch, out_ch,
@@ -77,16 +63,67 @@ class Vocoder(nn.Module):
self.resblocks = {} self.resblocks = {}
block_idx = 0 block_idx = 0
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1)) ch = config.upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes): for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations)) self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
block_idx += 1 block_idx += 1
out_channels = 2 if stereo else 1 out_channels = 2 if config.stereo else 1
final_channels = upsample_initial_channel // (2**self.num_upsamples) final_channels = config.upsample_initial_channel // (2**self.num_upsamples)
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3) self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
self.upsample_factor = math.prod(upsample_rates) self.upsample_factor = math.prod(config.upsample_rates)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
if "vocoder." not in weights:
return weights
for key, value in weights.items():
new_key = key
# Handle vocoder weights
if key.startswith("vocoder."):
new_key = key.replace("vocoder.", "")
# Handle ModuleList indices -> dict keys
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
# Handle Conv1d weight shape conversion
# PyTorch: (out_channels, in_channels, kernel)
# MLX: (out_channels, kernel, in_channels)
if "weight" in new_key and value.ndim == 3:
if "ups" in new_key:
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0))
else:
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder":
"""Load vocoder from pretrained model."""
from mlx_video.models.ltx.config import VocoderModelConfig
import json
config_dict = {}
with open(model_path / "config.json", "r") as f:
config_dict = json.load(f)
config = VocoderModelConfig.from_dict(config_dict)
model = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
# weights = vocoder.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
""" """

View File

@@ -2,7 +2,7 @@
import inspect import inspect
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, List, Optional from typing import Any, List, Optional, Tuple, Set
class LTXModelType(Enum): class LTXModelType(Enum):
@@ -180,3 +180,141 @@ class LTXModelConfig(BaseModelConfig):
d_head=self.audio_attention_head_dim, d_head=self.audio_attention_head_dim,
context_dim=self.audio_cross_attention_dim, context_dim=self.audio_cross_attention_dim,
) )
class CausalityAxis(Enum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"
@dataclass
class AudioDecoderModelConfig(BaseModelConfig):
ch: int = 128
out_ch: int = 2
ch_mult: Tuple[int, ...] = (1, 2, 4)
num_res_blocks: int = 2
attn_resolutions: Optional[List[int]] = None
resolution: int = 256
z_channels: int = 8
norm_type: Enum = None
causality_axis: Enum = None
dropout: float = 0.0
mid_block_add_attention: bool = True
sample_rate: int = 16000
mel_hop_length: int = 160
is_causal: bool = True
mel_bins: int | None = None
resamp_with_conv: bool = True
attn_type: str = None
give_pre_end: bool = False
tanh_out: bool = False
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if self.attn_resolutions is not None:
result["attn_resolutions"] = list(self.attn_resolutions)
return result
def __post_init__(self):
"""Convert string enum values to proper enum types."""
# Import here to avoid circular imports
from .audio_vae.normalization import NormType
from .audio_vae.attention import AttentionType
# Convert causality_axis string to enum
if isinstance(self.causality_axis, str):
self.causality_axis = CausalityAxis(self.causality_axis)
# Convert norm_type string to enum
if isinstance(self.norm_type, str):
self.norm_type = NormType(self.norm_type)
# Convert attn_type string to enum
if isinstance(self.attn_type, str):
self.attn_type = AttentionType(self.attn_type)
@dataclass
class VocoderModelConfig(BaseModelConfig):
resblock_kernel_sizes: Optional[List[int]] = None
upsample_rates: Optional[List[int]] = None
upsample_kernel_sizes: Optional[List[int]] = None
resblock_dilation_sizes: Optional[List[List[int]]] = None
upsample_initial_channel: int = 1024
stereo: bool = True
resblock: str = "1"
output_sample_rate: int = 24000
def __post_init__(self):
if self.resblock_kernel_sizes is None:
self.resblock_kernel_sizes = [3, 7, 11]
if self.upsample_rates is None:
self.upsample_rates = [6, 5, 2, 2, 2]
if self.upsample_kernel_sizes is None:
self.upsample_kernel_sizes = [16, 15, 8, 4, 4]
if self.resblock_dilation_sizes is None:
self.resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
@dataclass
class VideoDecoderModelConfig(BaseModelConfig):
ch: int = 128
out_ch: int = 2
ch_mult: Tuple[int, ...] = (1, 2, 4)
num_res_blocks: int = 2
attn_resolutions: Optional[List[int]] = None
resolution: int = 256
z_channels: int = 8
norm_type: Enum = None
causality_axis: Enum = None
dropout: float = 0.0
@dataclass
class VideoEncoderModelConfig(BaseModelConfig):
convolution_dimensions: int = 3
in_channels : int = 3,
out_channels: int = 128,
patch_size: int = 4,
norm_layer: Enum = None,
latent_log_var: Enum = None,
encoder_spatial_padding_mode: Enum = None,
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2})
])
def __post_init__(self):
from mlx_video.models.ltx.video_vae.resnet import NormLayerType
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType
if self.norm_layer is None:
self.norm_layer = NormLayerType.PIXEL_NORM
if self.latent_log_var is None:
self.latent_log_var = LogVarianceType.UNIFORM
if self.encoder_spatial_padding_mode is None:
self.encoder_spatial_padding_mode = PaddingModeType.ZEROS
if isinstance(self.norm_layer, str):
self.norm_layer = NormLayerType(self.norm_layer)
if isinstance(self.latent_log_var, str):
self.latent_log_var = LogVarianceType(self.latent_log_var)
if isinstance(self.encoder_spatial_padding_mode, str):
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if self.encoder_blocks is not None:
result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
return result

View File

@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from pathlib import Path
from mlx_video.models.ltx.config import ( from mlx_video.models.ltx.config import (
LTXModelConfig, LTXModelConfig,
LTXModelType, LTXModelType,
@@ -52,11 +52,10 @@ class TransformerArgsPreprocessor:
self, self,
timestep: mx.array, timestep: mx.array,
batch_size: int, batch_size: int,
hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
# Reshape to (batch, tokens, dim) # Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
@@ -71,9 +70,6 @@ class TransformerArgsPreprocessor:
attention_mask: Optional[mx.array] = None, attention_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, Optional[mx.array]]: ) -> Tuple[mx.array, Optional[mx.array]]:
batch_size = x.shape[0] batch_size = x.shape[0]
# Context is already processed through embeddings connector in text encoder
# Here we just apply the caption projection
context = self.caption_projection(context) context = self.caption_projection(context)
context = mx.reshape(context, (batch_size, -1, x.shape[-1])) context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
return context, attention_mask return context, attention_mask
@@ -118,21 +114,16 @@ class TransformerArgsPreprocessor:
def prepare(self, modality: Modality) -> TransformerArgs: def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent) x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation) positions=modality.positions,
if modality.positional_embeddings is not None: inner_dim=self.inner_dim,
pe = modality.positional_embeddings max_pos=self.max_pos,
else: use_middle_indices_grid=self.use_middle_indices_grid,
pe = self._prepare_positional_embeddings( num_attention_heads=self.num_attention_heads,
positions=modality.positions, )
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return TransformerArgs( return TransformerArgs(
x=x, x=x,
@@ -207,7 +198,6 @@ class MultiModalTransformerArgsPreprocessor:
timestep=modality.timesteps, timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0], batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
) )
return replace( return replace(
@@ -222,16 +212,15 @@ class MultiModalTransformerArgsPreprocessor:
timestep: mx.array, timestep: mx.array,
timestep_scale_multiplier: int, timestep_scale_multiplier: int,
batch_size: int, batch_size: int,
hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
timestep = timestep * timestep_scale_multiplier timestep = timestep * timestep_scale_multiplier
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])) scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype) gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1])) gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
return scale_shift_timestep, gate_timestep return scale_shift_timestep, gate_timestep
@@ -293,8 +282,6 @@ class LTXModel(nn.Module):
def _init_audio(self, config: LTXModelConfig) -> None: def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim) self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
self.audio_caption_projection = PixArtAlphaTextProjection( self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels, in_features=config.audio_caption_channels,
hidden_size=self.audio_inner_dim, hidden_size=self.audio_inner_dim,
@@ -397,9 +384,8 @@ class LTXModel(nn.Module):
video_config = config.get_video_config() video_config = config.get_video_config()
audio_config = config.get_audio_config() audio_config = config.get_audio_config()
self.transformer_blocks = [
self.transformer_blocks = { BasicAVTransformerBlock(
idx: BasicAVTransformerBlock(
idx=idx, idx=idx,
video=video_config, video=video_config,
audio=audio_config, audio=audio_config,
@@ -407,7 +393,7 @@ class LTXModel(nn.Module):
norm_eps=config.norm_eps, norm_eps=config.norm_eps,
) )
for idx in range(config.num_layers) for idx in range(config.num_layers)
} ]
def _process_transformer_blocks( def _process_transformer_blocks(
self, self,
@@ -415,7 +401,7 @@ class LTXModel(nn.Module):
audio: Optional[TransformerArgs], audio: Optional[TransformerArgs],
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks.""" """Process through all transformer blocks."""
for block in self.transformer_blocks.values(): for block in self.transformer_blocks:
video, audio = block(video=video, audio=audio) video, audio = block(video=video, audio=audio)
return video, audio return video, audio
@@ -497,50 +483,19 @@ class LTXModel(nn.Module):
def sanitize(self, weights: dict) -> dict: def sanitize(self, weights: dict) -> dict:
sanitized = {} sanitized = {}
for key, value in weights.items(): for key, value in weights.items():
new_key = key new_key = key
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
continue
# Remove 'model.diffusion_model.' prefix
new_key = new_key.replace("model.diffusion_model.", "")
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle common remappings
# transformer_blocks.X -> transformer_blocks[X]
if "transformer_blocks." in new_key:
# Keep as-is for now, MLX handles this
pass
sanitized[new_key] = value sanitized[new_key] = value
return sanitized return sanitized
@classmethod
def from_pretrained(cls, model_path: [Path, List[Path]], config: LTXModelConfig, strict: bool = True) -> None:
model = cls(config)
weights = {}
if isinstance(model_path, Path):
model_path = [model_path]
for weight_file in model_path:
weights.update(mx.load(str(weight_file)))
sanitized = model.sanitize(weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
model.load_weights(list(sanitized.items()), strict=strict)
mx.eval(model.parameters())
model.eval()
return model
class X0Model(nn.Module): class X0Model(nn.Module):

View File

@@ -428,11 +428,14 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int, num_attention_heads: int,
rope_type: LTXRopeType, rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float32. """Compute RoPE frequencies with higher precision using float64 for frequency grid.
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips. Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
Uses float32 for computation precision (sufficient for RoPE). computation (log-spaced values), then converts to float32 for the final tensor.
This provides better numerical precision in the frequency generation phase.
""" """
import numpy as np
# Warn if positions are bfloat16 - this causes quality degradation # Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16: if indices_grid.dtype == mx.bfloat16:
import warnings import warnings
@@ -443,21 +446,27 @@ def _precompute_freqs_cis_double_precision(
stacklevel=2 stacklevel=2
) )
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion) # Cast to float32 for position computation
indices_grid_f32 = indices_grid.astype(mx.float32) indices_grid_f32 = indices_grid.astype(mx.float32)
n_pos_dims = indices_grid_f32.shape[1] n_pos_dims = indices_grid_f32.shape[1]
n_elem = 2 * n_pos_dims n_elem = 2 * n_pos_dims
# Compute log-spaced frequencies in float32 # Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
log_start = math.log(1.0) / math.log(theta) # This is the critical precision step - PyTorch uses np.float64 here
log_end = math.log(theta) / math.log(theta) log_start = np.log(1.0) / np.log(theta)
log_end = np.log(theta) / np.log(theta) # = 1.0
num_indices = dim // n_elem num_indices = dim // n_elem
if num_indices == 0: if num_indices == 0:
num_indices = 1 num_indices = 1
lin_space = mx.linspace(log_start, log_end, num_indices) # Use numpy float64 for the linspace computation (matches PyTorch)
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2) pow_indices = np.power(
theta,
np.linspace(log_start, log_end, num_indices, dtype=np.float64),
)
# Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
# Handle middle indices grid # Handle middle indices grid
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise # Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise

View File

@@ -1,6 +1,6 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image from mlx_video.models.ltx.video_vae.encoder import encode_image
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder
from mlx_video.models.ltx.video_vae.tiling import ( from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig, TilingConfig,
SpatialTilingConfig, SpatialTilingConfig,

View File

@@ -15,13 +15,14 @@ Architecture (from PyTorch weights):
""" """
import math import math
from typing import Optional from typing import Optional, Dict
from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
@@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module):
self.decode_timestep = 0.05 self.decode_timestep = 0.05
# Per-channel statistics for denormalization (loaded from weights) # Per-channel statistics for denormalization (loaded from weights)
self.latents_mean = mx.zeros((in_channels,)) self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
self.latents_std = mx.ones((in_channels,))
# Initial conv: 128 -> 1024 # Initial conv: 128 -> 1024
class ConvInWrapper(nn.Module): class ConvInWrapper(nn.Module):
@@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module):
) )
self.last_scale_shift_table = mx.zeros((2, 128)) self.last_scale_shift_table = mx.zeros((2, 128))
def denormalize(self, x: mx.array) -> mx.array: def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Denormalize latents using per-channel statistics.""" # Build decoder weights dict with key remapping
dtype = x.dtype sanitized = {}
# Cast to float32 for precision (statistics may be in bfloat16) for key, value in weights.items():
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) new_key = key
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return (x * std + mean).astype(dtype) if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
if key.startswith("vae.per_channel_statistics."):
# Map per-channel statistics (use exact key matching)
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue # Skip other statistics keys
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
from safetensors import safe_open
import json
weights = mx.load(str(model_path))
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(model_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
model = cls(timestep_conditioning=timestep_conditioning)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization.""" """Apply pixel normalization."""
@@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv: bool = False, chunked_conv: bool = False,
) -> mx.array: ) -> mx.array:
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
batch_size = sample.shape[0] batch_size = sample.shape[0]
if debug:
debug_stats("Input", sample)
# Add noise if timestep conditioning is enabled # Add noise if timestep conditioning is enabled
if self.timestep_conditioning: if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample sample = noise + (1.0 - self.decode_noise_scale) * sample
if debug:
debug_stats("After noise", sample)
if debug: sample = self.per_channel_statistics.un_normalize(sample)
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
sample = self.denormalize(sample)
if debug:
debug_stats("After denormalize", sample)
if timestep is None and self.timestep_conditioning: if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep) timestep = mx.full((batch_size,), self.decode_timestep)
@@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module):
scaled_timestep = timestep * self.timestep_scale_multiplier scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal) x = self.conv_in(sample, causal=causal)
if debug:
debug_stats("After conv_in", x)
for i, block in self.up_blocks.items(): for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup): if isinstance(block, ResBlockGroup):
@@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module):
x = block(x, causal=causal, chunked_conv=chunked_conv) x = block(x, causal=causal, chunked_conv=chunked_conv)
else: else:
x = block(x, causal=causal) x = block(x, causal=causal)
if debug:
block_type = type(block).__name__
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
x = self.pixel_norm(x) x = self.pixel_norm(x)
if debug:
debug_stats("After pixel_norm", x)
if self.timestep_conditioning and scaled_timestep is not None: if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder( embedded_timestep = self.last_time_embedder(
@@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module):
scale = ada_values[:, 1] scale = ada_values[:, 1]
x = x * (1 + scale) + shift x = x * (1 + scale) + shift
if debug:
debug_stats("After timestep modulation", x)
x = self.act(x) x = self.act(x)
if debug:
debug_stats("After activation", x)
x = self.conv_out(x, causal=causal) x = self.conv_out(x, causal=causal)
if debug:
debug_stats("After conv_out", x)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4) # Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1) x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
if debug:
debug_stats("After unpatchify", x)
return x return x
@@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv=use_chunked_conv, chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready, on_frames_ready=on_frames_ready,
) )
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path
import json
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...")
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.decoder."
stats_prefix = "vae.per_channel_statistics."
elif has_decoder_prefix:
prefix = "decoder."
stats_prefix = ""
else:
prefix = ""
stats_prefix = ""
# Load per-channel statistics for denormalization
# Note: use std-of-means (not mean-of-stds) for proper denormalization
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
if mean_key in weights:
decoder.latents_mean = weights[mean_key]
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
if std_key in weights:
decoder.latents_std = weights[std_key]
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
# Build decoder weights dict with key remapping
decoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
decoder_weights[new_key] = value
print(f" Found {len(decoder_weights)} decoder weights")
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
print(f" Found {len(ts_keys)} timestep conditioning weights")
# Load weights
decoder.load_weights(list(decoder_weights.items()), strict=False)
print("VAE decoder loaded successfully")
return decoder

View File

@@ -5,152 +5,9 @@ Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation. to latent space, which can then be used to condition video generation.
""" """
from pathlib import Path
from typing import List, Tuple, Any, Optional
import json
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
def load_vae_encoder(model_path: str) -> VideoEncoder:
"""Load VAE encoder from safetensors file.
Args:
model_path: Path to the model weights (safetensors file or directory)
Returns:
Loaded VideoEncoder instance
"""
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE encoder from {weights_path}...")
# Read config from safetensors metadata
encoder_blocks = []
norm_layer = NormLayerType.PIXEL_NORM
latent_log_var = LogVarianceType.UNIFORM
patch_size = 4
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
# Parse encoder blocks
raw_blocks = vae_config.get("encoder_blocks", [])
for block in raw_blocks:
if isinstance(block, list) and len(block) == 2:
name, params = block
encoder_blocks.append((name, params))
# Parse other config
norm_str = vae_config.get("norm_layer", "pixel_norm")
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
var_str = vae_config.get("latent_log_var", "uniform")
if var_str == "uniform":
latent_log_var = LogVarianceType.UNIFORM
elif var_str == "per_channel":
latent_log_var = LogVarianceType.PER_CHANNEL
elif var_str == "constant":
latent_log_var = LogVarianceType.CONSTANT
else:
latent_log_var = LogVarianceType.NONE
patch_size = vae_config.get("patch_size", 4)
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
except Exception as e:
print(f" Could not read config from metadata: {e}")
# Use default config
encoder_blocks = [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
# Create encoder
encoder = VideoEncoder(
convolution_dimensions=3,
in_channels=3,
out_channels=128,
encoder_blocks=encoder_blocks,
patch_size=patch_size,
norm_layer=norm_layer,
latent_log_var=latent_log_var,
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
)
# Load weights
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.encoder."
stats_prefix = "vae.per_channel_statistics."
else:
prefix = "encoder."
stats_prefix = "per_channel_statistics."
# Load per-channel statistics for normalization
mean_key = f"{stats_prefix}mean-of-means"
std_key = f"{stats_prefix}std-of-means"
if mean_key in weights:
encoder.per_channel_statistics.mean = weights[mean_key]
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
if std_key in weights:
encoder.per_channel_statistics.std = weights[std_key]
print(f" Loaded latent std: shape {weights[std_key].shape}")
# Build encoder weights dict with key remapping
encoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
encoder_weights[new_key] = value
print(f" Found {len(encoder_weights)} encoder weights")
# Load weights
encoder.load_weights(list(encoder_weights.items()), strict=False)
print("VAE encoder loaded successfully")
return encoder
def encode_image( def encode_image(

View File

@@ -273,9 +273,10 @@ class VideoEncoder(nn.Module):
spatial_padding_mode=encoder_spatial_padding_mode, spatial_padding_mode=encoder_spatial_padding_mode,
) )
# Build encoder blocks - use dict with int keys for MLX parameter tracking # Build encoder blocks
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {} self.down_blocks = {}
for i, (block_name, block_params) in enumerate(encoder_blocks): for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block( block, feature_channels = _make_encoder_block(
@@ -287,7 +288,7 @@ class VideoEncoder(nn.Module):
norm_num_groups=self._norm_num_groups, norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode, spatial_padding_mode=encoder_spatial_padding_mode,
) )
self.down_blocks[i] = block self.down_blocks[idx] = block
# Output normalization and convolution # Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM: if norm_layer == NormLayerType.GROUP_NORM:
@@ -341,7 +342,8 @@ class VideoEncoder(nn.Module):
sample = self.conv_in(sample, causal=True) sample = self.conv_in(sample, causal=True)
# Process through encoder blocks # Process through encoder blocks
for down_block in self.down_blocks.values(): for i in range(len(self.down_blocks)):
down_block = self.down_blocks[i]
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)): if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True) sample = down_block(sample, causal=True)
else: else:
@@ -440,8 +442,9 @@ class VideoDecoder(nn.Module):
) )
# Build decoder blocks (reversed order) # Build decoder blocks (reversed order)
self.up_blocks = [] # Use dict with int keys for MLX to track parameters (lists are NOT tracked)
for block_name, block_params in list(reversed(decoder_blocks)): self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_decoder_block( block, feature_channels = _make_decoder_block(
@@ -454,7 +457,7 @@ class VideoDecoder(nn.Module):
norm_num_groups=self._norm_num_groups, norm_num_groups=self._norm_num_groups,
spatial_padding_mode=decoder_spatial_padding_mode, spatial_padding_mode=decoder_spatial_padding_mode,
) )
self.up_blocks.append(block) self.up_blocks[idx] = block
# Output normalization # Output normalization
if norm_layer == NormLayerType.GROUP_NORM: if norm_layer == NormLayerType.GROUP_NORM:
@@ -509,7 +512,8 @@ class VideoDecoder(nn.Module):
sample = self.conv_in(sample, causal=self.causal) sample = self.conv_in(sample, causal=self.causal)
# Process through decoder blocks # Process through decoder blocks
for up_block in self.up_blocks: for i in range(len(self.up_blocks)):
up_block = self.up_blocks[i]
if isinstance(up_block, UNetMidBlock3D): if isinstance(up_block, UNetMidBlock3D):
sample = up_block(sample, causal=self.causal) sample = up_block(sample, causal=self.causal)
elif isinstance(up_block, ResnetBlock3D): elif isinstance(up_block, ResnetBlock3D):