From 02bfa228d92e748e112e93e3207ca156afb77769 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 23 Jan 2026 17:31:25 +0100 Subject: [PATCH] Refactor weight loading and sanitization processes for audio models --- mlx_video/__init__.py | 58 ++++- mlx_video/convert.py | 7 +- mlx_video/models/ltx/audio_vae/__init__.py | 2 +- mlx_video/models/ltx/audio_vae/audio_vae.py | 154 +++++++----- .../models/ltx/audio_vae/causal_conv_2d.py | 2 +- .../models/ltx/audio_vae/causality_axis.py | 12 - mlx_video/models/ltx/audio_vae/downsample.py | 2 +- mlx_video/models/ltx/audio_vae/ops.py | 12 +- mlx_video/models/ltx/audio_vae/resnet.py | 2 +- mlx_video/models/ltx/audio_vae/upsample.py | 2 +- mlx_video/models/ltx/audio_vae/vocoder.py | 111 ++++++--- mlx_video/models/ltx/config.py | 140 ++++++++++- mlx_video/models/ltx/ltx.py | 87 ++----- mlx_video/models/ltx/rope.py | 27 ++- mlx_video/models/ltx/video_vae/__init__.py | 4 +- mlx_video/models/ltx/video_vae/decoder.py | 221 +++++++----------- mlx_video/models/ltx/video_vae/encoder.py | 145 +----------- mlx_video/models/ltx/video_vae/video_vae.py | 20 +- 18 files changed, 510 insertions(+), 498 deletions(-) delete mode 100644 mlx_video/models/ltx/audio_vae/causality_axis.py diff --git a/mlx_video/__init__.py b/mlx_video/__init__.py index f6a1720..07fd7c1 100644 --- a/mlx_video/__init__.py +++ b/mlx_video/__init__.py @@ -1,11 +1,59 @@ from mlx_video.models.ltx import LTXModel, LTXModelConfig -from mlx_video.convert import load_transformer_weights, load_vae_weights -import os +from mlx_video.convert import ( + load_transformer_weights, + load_vae_weights, + load_audio_vae_weights, + load_vocoder_weights, + sanitize_audio_vae_weights, + sanitize_vocoder_weights, +) + +# Audio VAE components +from mlx_video.models.ltx.audio_vae import ( + AudioEncoder, + AudioDecoder, + Vocoder, + AudioProcessor, + decode_audio, +) + +# Patchifiers +from mlx_video.components.patchifiers import ( + VideoLatentPatchifier, + AudioPatchifier, + VideoLatentShape, + AudioLatentShape, +) + +# Conditioning +from mlx_video.conditioning import ( + VideoConditionByKeyframeIndex, + VideoConditionByLatentIndex, +) + __all__ = [ + # Models "LTXModel", "LTXModelConfig", + # Weight loading "load_transformer_weights", "load_vae_weights", -] - -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1" + "load_audio_vae_weights", + "load_vocoder_weights", + "sanitize_audio_vae_weights", + "sanitize_vocoder_weights", + # Audio VAE + "AudioEncoder", + "AudioDecoder", + "Vocoder", + "AudioProcessor", + "decode_audio", + # Patchifiers + "VideoLatentPatchifier", + "AudioPatchifier", + "VideoLatentShape", + "AudioLatentShape", + # Conditioning + "VideoConditionByKeyframeIndex", + "VideoConditionByLatentIndex", +] \ No newline at end of file diff --git a/mlx_video/convert.py b/mlx_video/convert.py index 11491e0..cbefd68 100644 --- a/mlx_video/convert.py +++ b/mlx_video/convert.py @@ -355,6 +355,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr """ sanitized = {} + if "audio_vae." in weights: + return weights + for key, value in weights.items(): new_key = key @@ -364,9 +367,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr elif key.startswith("audio_vae.per_channel_statistics."): # Map per-channel statistics if "mean-of-means" in key: - new_key = "per_channel_statistics._mean_of_means" + new_key = "per_channel_statistics.mean_of_means" elif "std-of-means" in key: - new_key = "per_channel_statistics._std_of_means" + new_key = "per_channel_statistics.std_of_means" else: continue # Skip other statistics keys else: diff --git a/mlx_video/models/ltx/audio_vae/__init__.py b/mlx_video/models/ltx/audio_vae/__init__.py index 5907e2d..8786118 100644 --- a/mlx_video/models/ltx/audio_vae/__init__.py +++ b/mlx_video/models/ltx/audio_vae/__init__.py @@ -3,7 +3,7 @@ from .attention import AttentionType, AttnBlock, make_attn from .audio_vae import AudioDecoder, decode_audio from .causal_conv_2d import CausalConv2d, make_conv2d -from .causality_axis import CausalityAxis +from ..config import CausalityAxis from .downsample import Downsample, build_downsampling_path from .normalization import NormType, PixelNorm, build_normalization_layer from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics diff --git a/mlx_video/models/ltx/audio_vae/audio_vae.py b/mlx_video/models/ltx/audio_vae/audio_vae.py index 08caec5..4c6f97b 100644 --- a/mlx_video/models/ltx/audio_vae/audio_vae.py +++ b/mlx_video/models/ltx/audio_vae/audio_vae.py @@ -1,14 +1,15 @@ """Audio VAE encoder and decoder for LTX-2.""" -from typing import Set, Tuple +from typing import Dict +from pathlib import Path import mlx.core as mx import mlx.nn as nn - +from mlx_vlm.models.base import check_array_shape +from ..config import AudioDecoderModelConfig from .attention import AttentionType, make_attn from .causal_conv_2d import make_conv2d -from .causality_axis import CausalityAxis -from .downsample import build_downsampling_path +from ..config import CausalityAxis from .normalization import NormType, build_normalization_layer from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics from .resnet import ResnetBlock @@ -67,22 +68,7 @@ class AudioDecoder(nn.Module): def __init__( self, - *, - ch: int = 128, - out_ch: int = 2, - ch_mult: Tuple[int, ...] = (1, 2, 4), - num_res_blocks: int = 2, - attn_resolutions: Set[int] = None, - resolution: int = 256, - z_channels: int = 8, - norm_type: NormType = NormType.PIXEL, - causality_axis: CausalityAxis = CausalityAxis.HEIGHT, - dropout: float = 0.0, - mid_block_add_attention: bool = True, - sample_rate: int = 16000, - mel_hop_length: int = 160, - is_causal: bool = True, - mel_bins: int | None = None, + config: AudioDecoderModelConfig, ) -> None: """ Initialize the AudioDecoder. @@ -105,86 +91,132 @@ class AudioDecoder(nn.Module): """ super().__init__() - if attn_resolutions is None: - attn_resolutions = {8, 16, 32} - - # Internal behavioral defaults - resamp_with_conv = True - attn_type = AttentionType.VANILLA # Per-channel statistics for denormalizing latents # Uses ch (base channel count) to match the patchified latent dimension # Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16) # After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128) # ch=128 matches this dimension, so use ch for per_channel_statistics - self.per_channel_statistics = PerChannelStatistics(latent_channels=ch) - self.sample_rate = sample_rate - self.mel_hop_length = mel_hop_length - self.is_causal = is_causal - self.mel_bins = mel_bins + self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch) + self.sample_rate = config.sample_rate + self.mel_hop_length = config.mel_hop_length + self.is_causal = config.is_causal + self.mel_bins = config.mel_bins self.patchifier = AudioPatchifier( patch_size=1, audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, - sample_rate=sample_rate, - hop_length=mel_hop_length, - is_causal=is_causal, + sample_rate=config.sample_rate, + hop_length=config.mel_hop_length, + is_causal=config.is_causal, ) - self.ch = ch + self.ch = config.ch self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.out_ch = out_ch - self.give_pre_end = False - self.tanh_out = False - self.norm_type = norm_type - self.z_channels = z_channels - self.channel_multipliers = ch_mult - self.attn_resolutions = attn_resolutions - self.causality_axis = causality_axis - self.attn_type = attn_type + self.num_resolutions = len(config.ch_mult) + self.num_res_blocks = config.num_res_blocks + self.resolution = config.resolution + self.out_ch = config.out_ch + self.give_pre_end = config.give_pre_end + self.tanh_out = config.tanh_out + self.norm_type = config.norm_type + self.z_channels = config.z_channels + self.channel_multipliers = config.ch_mult + self.attn_resolutions = config.attn_resolutions + self.causality_axis = config.causality_axis + self.attn_type = config.attn_type - base_block_channels = ch * self.channel_multipliers[-1] - base_resolution = resolution // (2 ** (self.num_resolutions - 1)) - self.z_shape = (1, z_channels, base_resolution, base_resolution) + base_block_channels = config.ch * self.channel_multipliers[-1] + base_resolution = config.resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, config.z_channels, base_resolution, base_resolution) self.conv_in = make_conv2d( - z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis ) self.mid = build_mid_block( channels=base_block_channels, temb_channels=self.temb_ch, - dropout=dropout, + dropout=config.dropout, norm_type=self.norm_type, causality_axis=self.causality_axis, attn_type=self.attn_type, - add_attention=mid_block_add_attention, + add_attention=config.mid_block_add_attention, ) self.up, final_block_channels = build_upsampling_path( - ch=ch, - ch_mult=ch_mult, + ch=config.ch, + ch_mult=config.ch_mult, num_resolutions=self.num_resolutions, - num_res_blocks=num_res_blocks, - resolution=resolution, + num_res_blocks=config.num_res_blocks, + resolution=config.resolution, temb_channels=self.temb_ch, - dropout=dropout, + dropout=config.dropout, norm_type=self.norm_type, causality_axis=self.causality_axis, attn_type=self.attn_type, - attn_resolutions=attn_resolutions, - resamp_with_conv=resamp_with_conv, + attn_resolutions=config.attn_resolutions, + resamp_with_conv=config.resamp_with_conv, initial_block_channels=base_block_channels, ) self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type) self.conv_out = make_conv2d( - final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis + final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis ) + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Sanitize audio VAE weight names from PyTorch format to MLX format. + + Args: + weights: Dictionary of weights with PyTorch naming + + Returns: + Dictionary with MLX-compatible naming for audio VAE decoder + """ + sanitized = {} + + for key, value in weights.items(): + new_key = key + + # Handle audio_vae.decoder weights + if key.startswith("audio_vae.decoder."): + new_key = key.replace("audio_vae.decoder.", "") + elif key.startswith("audio_vae.per_channel_statistics."): + # Map per-channel statistics + if "mean-of-means" in key: + new_key = "per_channel_statistics.mean_of_means" + elif "std-of-means" in key: + new_key = "per_channel_statistics.std_of_means" + else: + continue # Skip other statistics keys + else: + continue # Skip non-decoder keys + + # Handle Conv2d weight shape conversion + # PyTorch: (out_channels, in_channels, H, W) + # MLX: (out_channels, H, W, in_channels) + if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: + value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1)) + + sanitized[new_key] = value + + return sanitized + + @classmethod + def from_pretrained(cls, model_path: Path) -> "AudioDecoder": + """Load audio VAE decoder from pretrained model.""" + from mlx_video.models.ltx.config import AudioDecoderModelConfig + import json + + config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json"))) + decoder = cls(config) + weights = mx.load(str(model_path / "model.safetensors")) + # weights = decoder.sanitize(weights) + decoder.load_weights(list(weights.items()), strict=True) + return decoder + + def __call__(self, sample: mx.array) -> mx.array: """ Decode latent features back to audio spectrograms. diff --git a/mlx_video/models/ltx/audio_vae/causal_conv_2d.py b/mlx_video/models/ltx/audio_vae/causal_conv_2d.py index 2a38448..b303268 100644 --- a/mlx_video/models/ltx/audio_vae/causal_conv_2d.py +++ b/mlx_video/models/ltx/audio_vae/causal_conv_2d.py @@ -5,7 +5,7 @@ from typing import Tuple, Union import mlx.core as mx import mlx.nn as nn -from .causality_axis import CausalityAxis +from ..config import CausalityAxis def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]: diff --git a/mlx_video/models/ltx/audio_vae/causality_axis.py b/mlx_video/models/ltx/audio_vae/causality_axis.py deleted file mode 100644 index 15545b3..0000000 --- a/mlx_video/models/ltx/audio_vae/causality_axis.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Causality axis enum for specifying causal convolution dimensions.""" - -from enum import Enum - - -class CausalityAxis(Enum): - """Enum for specifying the causality axis in causal convolutions.""" - - NONE = None - WIDTH = "width" - HEIGHT = "height" - WIDTH_COMPATIBILITY = "width-compatibility" diff --git a/mlx_video/models/ltx/audio_vae/downsample.py b/mlx_video/models/ltx/audio_vae/downsample.py index 2f553c8..8831668 100644 --- a/mlx_video/models/ltx/audio_vae/downsample.py +++ b/mlx_video/models/ltx/audio_vae/downsample.py @@ -6,7 +6,7 @@ import mlx.core as mx import mlx.nn as nn from .attention import AttentionType, make_attn -from .causality_axis import CausalityAxis +from ..config import CausalityAxis from .normalization import NormType from .resnet import ResnetBlock diff --git a/mlx_video/models/ltx/audio_vae/ops.py b/mlx_video/models/ltx/audio_vae/ops.py index bf2d111..ae3cd30 100644 --- a/mlx_video/models/ltx/audio_vae/ops.py +++ b/mlx_video/models/ltx/audio_vae/ops.py @@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module): self.latent_channels = latent_channels # Initialize buffers - will be loaded from weights # Using underscores for MLX compatibility with weight loading - self._std_of_means = mx.ones((latent_channels,)) - self._mean_of_means = mx.zeros((latent_channels,)) + self.std_of_means = mx.ones((latent_channels,)) + self.mean_of_means = mx.zeros((latent_channels,)) def un_normalize(self, x: mx.array) -> mx.array: """Denormalize latent representation.""" # Broadcast statistics to match x shape # x shape: (B, C, ...) or (B, ..., C) - std = self._std_of_means.astype(x.dtype) - mean = self._mean_of_means.astype(x.dtype) + std = self.std_of_means.astype(x.dtype) + mean = self.mean_of_means.astype(x.dtype) return (x * std) + mean def normalize(self, x: mx.array) -> mx.array: """Normalize latent representation.""" - std = self._std_of_means.astype(x.dtype) - mean = self._mean_of_means.astype(x.dtype) + std = self.std_of_means.astype(x.dtype) + mean = self.mean_of_means.astype(x.dtype) return (x - mean) / std diff --git a/mlx_video/models/ltx/audio_vae/resnet.py b/mlx_video/models/ltx/audio_vae/resnet.py index c80d938..ca20f67 100644 --- a/mlx_video/models/ltx/audio_vae/resnet.py +++ b/mlx_video/models/ltx/audio_vae/resnet.py @@ -6,7 +6,7 @@ import mlx.core as mx import mlx.nn as nn from .causal_conv_2d import make_conv2d -from .causality_axis import CausalityAxis +from ..config import CausalityAxis from .normalization import NormType, build_normalization_layer LRELU_SLOPE = 0.1 diff --git a/mlx_video/models/ltx/audio_vae/upsample.py b/mlx_video/models/ltx/audio_vae/upsample.py index 731ac85..734ccab 100644 --- a/mlx_video/models/ltx/audio_vae/upsample.py +++ b/mlx_video/models/ltx/audio_vae/upsample.py @@ -7,7 +7,7 @@ import mlx.nn as nn from .attention import AttentionType, make_attn from .causal_conv_2d import make_conv2d -from .causality_axis import CausalityAxis +from ..config import CausalityAxis from .normalization import NormType from .resnet import ResnetBlock diff --git a/mlx_video/models/ltx/audio_vae/vocoder.py b/mlx_video/models/ltx/audio_vae/vocoder.py index 02b5393..f996d2f 100644 --- a/mlx_video/models/ltx/audio_vae/vocoder.py +++ b/mlx_video/models/ltx/audio_vae/vocoder.py @@ -1,11 +1,12 @@ """Vocoder for converting mel spectrograms to audio waveforms.""" import math -from typing import List - +from typing import Dict +from pathlib import Path import mlx.core as mx import mlx.nn as nn - +from mlx_vlm.models.base import check_array_shape +from ..config import VocoderModelConfig from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu @@ -27,44 +28,29 @@ class Vocoder(nn.Module): def __init__( self, - resblock_kernel_sizes: List[int] | None = None, - upsample_rates: List[int] | None = None, - upsample_kernel_sizes: List[int] | None = None, - resblock_dilation_sizes: List[List[int]] | None = None, - upsample_initial_channel: int = 1024, - stereo: bool = True, - resblock: str = "1", - output_sample_rate: int = 24000, + config: VocoderModelConfig ): super().__init__() - # Initialize default values if not provided - if resblock_kernel_sizes is None: - resblock_kernel_sizes = [3, 7, 11] - if upsample_rates is None: - upsample_rates = [6, 5, 2, 2, 2] - if upsample_kernel_sizes is None: - upsample_kernel_sizes = [16, 15, 8, 4, 4] - if resblock_dilation_sizes is None: - resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] - self.output_sample_rate = output_sample_rate - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.upsample_rates = upsample_rates - self.upsample_kernel_sizes = upsample_kernel_sizes - self.upsample_initial_channel = upsample_initial_channel + + self.output_sample_rate = config.output_sample_rate + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.upsample_rates = config.upsample_rates + self.upsample_kernel_sizes = config.upsample_kernel_sizes + self.upsample_initial_channel = config.upsample_initial_channel - in_channels = 128 if stereo else 64 - self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3) + in_channels = 128 if config.stereo else 64 + self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3) - resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2 # Upsampling layers using ConvTranspose1d self.ups = {} - for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - in_ch = upsample_initial_channel // (2**i) - out_ch = upsample_initial_channel // (2 ** (i + 1)) + for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + in_ch = config.upsample_initial_channel // (2**i) + out_ch = config.upsample_initial_channel // (2 ** (i + 1)) self.ups[i] = nn.ConvTranspose1d( in_ch, out_ch, @@ -77,16 +63,67 @@ class Vocoder(nn.Module): self.resblocks = {} block_idx = 0 for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes): + ch = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations)) block_idx += 1 - out_channels = 2 if stereo else 1 - final_channels = upsample_initial_channel // (2**self.num_upsamples) + out_channels = 2 if config.stereo else 1 + final_channels = config.upsample_initial_channel // (2**self.num_upsamples) self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3) - self.upsample_factor = math.prod(upsample_rates) + self.upsample_factor = math.prod(config.upsample_rates) + + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + sanitized = {} + + if "vocoder." not in weights: + return weights + + for key, value in weights.items(): + new_key = key + + # Handle vocoder weights + if key.startswith("vocoder."): + new_key = key.replace("vocoder.", "") + + # Handle ModuleList indices -> dict keys + # PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ... + # PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ... + + # Handle Conv1d weight shape conversion + # PyTorch: (out_channels, in_channels, kernel) + # MLX: (out_channels, kernel, in_channels) + if "weight" in new_key and value.ndim == 3: + if "ups" in new_key: + # ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch) + value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0)) + else: + # Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch) + value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1)) + + sanitized[new_key] = value + + return sanitized + + @classmethod + def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder": + """Load vocoder from pretrained model.""" + from mlx_video.models.ltx.config import VocoderModelConfig + import json + + config_dict = {} + with open(model_path / "config.json", "r") as f: + config_dict = json.load(f) + + config = VocoderModelConfig.from_dict(config_dict) + model = cls(config) + weights = mx.load(str(model_path / "model.safetensors")) + + # weights = vocoder.sanitize(weights) + model.load_weights(list(weights.items()), strict=strict) + return model + def __call__(self, x: mx.array) -> mx.array: """ diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index 6ac9de2..2ca56a9 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -2,7 +2,7 @@ import inspect from dataclasses import dataclass, field from enum import Enum -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple, Set class LTXModelType(Enum): @@ -180,3 +180,141 @@ class LTXModelConfig(BaseModelConfig): d_head=self.audio_attention_head_dim, context_dim=self.audio_cross_attention_dim, ) + + + +class CausalityAxis(Enum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +@dataclass +class AudioDecoderModelConfig(BaseModelConfig): + ch: int = 128 + out_ch: int = 2 + ch_mult: Tuple[int, ...] = (1, 2, 4) + num_res_blocks: int = 2 + attn_resolutions: Optional[List[int]] = None + resolution: int = 256 + z_channels: int = 8 + norm_type: Enum = None + causality_axis: Enum = None + dropout: float = 0.0 + mid_block_add_attention: bool = True + sample_rate: int = 16000 + mel_hop_length: int = 160 + is_causal: bool = True + mel_bins: int | None = None + resamp_with_conv: bool = True + attn_type: str = None + give_pre_end: bool = False + tanh_out: bool = False + + def to_dict(self) -> dict[str, Any]: + result = super().to_dict() + if self.attn_resolutions is not None: + result["attn_resolutions"] = list(self.attn_resolutions) + return result + + def __post_init__(self): + """Convert string enum values to proper enum types.""" + # Import here to avoid circular imports + from .audio_vae.normalization import NormType + from .audio_vae.attention import AttentionType + + # Convert causality_axis string to enum + if isinstance(self.causality_axis, str): + self.causality_axis = CausalityAxis(self.causality_axis) + + # Convert norm_type string to enum + if isinstance(self.norm_type, str): + self.norm_type = NormType(self.norm_type) + + # Convert attn_type string to enum + if isinstance(self.attn_type, str): + self.attn_type = AttentionType(self.attn_type) + +@dataclass +class VocoderModelConfig(BaseModelConfig): + resblock_kernel_sizes: Optional[List[int]] = None + upsample_rates: Optional[List[int]] = None + upsample_kernel_sizes: Optional[List[int]] = None + resblock_dilation_sizes: Optional[List[List[int]]] = None + upsample_initial_channel: int = 1024 + stereo: bool = True + resblock: str = "1" + output_sample_rate: int = 24000 + + def __post_init__(self): + + if self.resblock_kernel_sizes is None: + self.resblock_kernel_sizes = [3, 7, 11] + if self.upsample_rates is None: + self.upsample_rates = [6, 5, 2, 2, 2] + if self.upsample_kernel_sizes is None: + self.upsample_kernel_sizes = [16, 15, 8, 4, 4] + if self.resblock_dilation_sizes is None: + self.resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + + +@dataclass +class VideoDecoderModelConfig(BaseModelConfig): + ch: int = 128 + out_ch: int = 2 + ch_mult: Tuple[int, ...] = (1, 2, 4) + num_res_blocks: int = 2 + attn_resolutions: Optional[List[int]] = None + resolution: int = 256 + z_channels: int = 8 + norm_type: Enum = None + causality_axis: Enum = None + dropout: float = 0.0 + +@dataclass +class VideoEncoderModelConfig(BaseModelConfig): + convolution_dimensions: int = 3 + in_channels : int = 3, + out_channels: int = 128, + patch_size: int = 4, + norm_layer: Enum = None, + latent_log_var: Enum = None, + encoder_spatial_padding_mode: Enum = None, + encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}), + ("compress_space_res", {"multiplier": 2}), + ("res_x", {"num_layers": 6}), + ("compress_time_res", {"multiplier": 2}), + ("res_x", {"num_layers": 6}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}) + ]) + + def __post_init__(self): + from mlx_video.models.ltx.video_vae.resnet import NormLayerType + from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType + from mlx_video.models.ltx.video_vae.convolution import PaddingModeType + + if self.norm_layer is None: + self.norm_layer = NormLayerType.PIXEL_NORM + if self.latent_log_var is None: + self.latent_log_var = LogVarianceType.UNIFORM + if self.encoder_spatial_padding_mode is None: + self.encoder_spatial_padding_mode = PaddingModeType.ZEROS + + if isinstance(self.norm_layer, str): + self.norm_layer = NormLayerType(self.norm_layer) + if isinstance(self.latent_log_var, str): + self.latent_log_var = LogVarianceType(self.latent_log_var) + if isinstance(self.encoder_spatial_padding_mode, str): + self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode) + + def to_dict(self) -> dict[str, Any]: + result = super().to_dict() + if self.encoder_blocks is not None: + result["encoder_blocks"] = [list(block) for block in self.encoder_blocks] + return result \ No newline at end of file diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index f083485..e89f140 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -2,7 +2,7 @@ from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from pathlib import Path + from mlx_video.models.ltx.config import ( LTXModelConfig, LTXModelType, @@ -52,11 +52,10 @@ class TransformerArgsPreprocessor: self, timestep: mx.array, batch_size: int, - hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * self.timestep_scale_multiplier - timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) + timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1)) # Reshape to (batch, tokens, dim) timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) @@ -71,9 +70,6 @@ class TransformerArgsPreprocessor: attention_mask: Optional[mx.array] = None, ) -> Tuple[mx.array, Optional[mx.array]]: batch_size = x.shape[0] - - # Context is already processed through embeddings connector in text encoder - # Here we just apply the caption projection context = self.caption_projection(context) context = mx.reshape(context, (batch_size, -1, x.shape[-1])) return context, attention_mask @@ -118,21 +114,16 @@ class TransformerArgsPreprocessor: def prepare(self, modality: Modality) -> TransformerArgs: x = self.patchify_proj(modality.latent) - timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) + timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0]) context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) - - # Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation) - if modality.positional_embeddings is not None: - pe = modality.positional_embeddings - else: - pe = self._prepare_positional_embeddings( - positions=modality.positions, - inner_dim=self.inner_dim, - max_pos=self.max_pos, - use_middle_indices_grid=self.use_middle_indices_grid, - num_attention_heads=self.num_attention_heads, - ) + pe = self._prepare_positional_embeddings( + positions=modality.positions, + inner_dim=self.inner_dim, + max_pos=self.max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) return TransformerArgs( x=x, @@ -207,7 +198,6 @@ class MultiModalTransformerArgsPreprocessor: timestep=modality.timesteps, timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, batch_size=transformer_args.x.shape[0], - hidden_dtype=transformer_args.x.dtype, ) return replace( @@ -222,16 +212,15 @@ class MultiModalTransformerArgsPreprocessor: timestep: mx.array, timestep_scale_multiplier: int, batch_size: int, - hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier - scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) + scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1)) scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])) - gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype) + gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor) gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1])) return scale_shift_timestep, gate_timestep @@ -293,8 +282,6 @@ class LTXModel(nn.Module): def _init_audio(self, config: LTXModelConfig) -> None: self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True) self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim) - - # Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector self.audio_caption_projection = PixArtAlphaTextProjection( in_features=config.audio_caption_channels, hidden_size=self.audio_inner_dim, @@ -397,9 +384,8 @@ class LTXModel(nn.Module): video_config = config.get_video_config() audio_config = config.get_audio_config() - - self.transformer_blocks = { - idx: BasicAVTransformerBlock( + self.transformer_blocks = [ + BasicAVTransformerBlock( idx=idx, video=video_config, audio=audio_config, @@ -407,7 +393,7 @@ class LTXModel(nn.Module): norm_eps=config.norm_eps, ) for idx in range(config.num_layers) - } + ] def _process_transformer_blocks( self, @@ -415,7 +401,7 @@ class LTXModel(nn.Module): audio: Optional[TransformerArgs], ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: """Process through all transformer blocks.""" - for block in self.transformer_blocks.values(): + for block in self.transformer_blocks: video, audio = block(video=video, audio=audio) return video, audio @@ -497,50 +483,19 @@ class LTXModel(nn.Module): def sanitize(self, weights: dict) -> dict: sanitized = {} - for key, value in weights.items(): new_key = key - # Skip non-transformer weights (VAE, vocoder, audio_vae, connectors) - if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key: - continue - - # Remove 'model.diffusion_model.' prefix - new_key = new_key.replace("model.diffusion_model.", "") - - new_key = new_key.replace(".to_out.0.", ".to_out.") - - new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") - new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") - new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.") - new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.") - - new_key = new_key.replace(".linear_1.", ".linear1.") - new_key = new_key.replace(".linear_2.", ".linear2.") + # Handle common remappings + # transformer_blocks.X -> transformer_blocks[X] + if "transformer_blocks." in new_key: + # Keep as-is for now, MLX handles this + pass sanitized[new_key] = value return sanitized - @classmethod - def from_pretrained(cls, model_path: [Path, List[Path]], config: LTXModelConfig, strict: bool = True) -> None: - model = cls(config) - - weights = {} - if isinstance(model_path, Path): - model_path = [model_path] - for weight_file in model_path: - weights.update(mx.load(str(weight_file))) - - - sanitized = model.sanitize(weights) - sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} - - model.load_weights(list(sanitized.items()), strict=strict) - mx.eval(model.parameters()) - model.eval() - return model - class X0Model(nn.Module): diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index 9e2db5f..66a8710 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -428,11 +428,14 @@ def _precompute_freqs_cis_double_precision( num_attention_heads: int, rope_type: LTXRopeType, ) -> Tuple[mx.array, mx.array]: - """Compute RoPE frequencies with higher precision using float32. + """Compute RoPE frequencies with higher precision using float64 for frequency grid. - This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips. - Uses float32 for computation precision (sufficient for RoPE). + Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid + computation (log-spaced values), then converts to float32 for the final tensor. + This provides better numerical precision in the frequency generation phase. """ + import numpy as np + # Warn if positions are bfloat16 - this causes quality degradation if indices_grid.dtype == mx.bfloat16: import warnings @@ -443,21 +446,27 @@ def _precompute_freqs_cis_double_precision( stacklevel=2 ) - # Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion) + # Cast to float32 for position computation indices_grid_f32 = indices_grid.astype(mx.float32) n_pos_dims = indices_grid_f32.shape[1] n_elem = 2 * n_pos_dims - # Compute log-spaced frequencies in float32 - log_start = math.log(1.0) / math.log(theta) - log_end = math.log(theta) / math.log(theta) + # Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np) + # This is the critical precision step - PyTorch uses np.float64 here + log_start = np.log(1.0) / np.log(theta) + log_end = np.log(theta) / np.log(theta) # = 1.0 num_indices = dim // n_elem if num_indices == 0: num_indices = 1 - lin_space = mx.linspace(log_start, log_end, num_indices) - freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2) + # Use numpy float64 for the linspace computation (matches PyTorch) + pow_indices = np.power( + theta, + np.linspace(log_start, log_end, num_indices, dtype=np.float64), + ) + # Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32)) + freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32) # Handle middle indices grid # Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise diff --git a/mlx_video/models/ltx/video_vae/__init__.py b/mlx_video/models/ltx/video_vae/__init__.py index bac1644..79f68cd 100644 --- a/mlx_video/models/ltx/video_vae/__init__.py +++ b/mlx_video/models/ltx/video_vae/__init__.py @@ -1,6 +1,6 @@ from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder -from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image -from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder +from mlx_video.models.ltx.video_vae.encoder import encode_image +from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder from mlx_video.models.ltx.video_vae.tiling import ( TilingConfig, SpatialTilingConfig, diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 9a6cbb3..7499238 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -15,13 +15,14 @@ Architecture (from PyTorch weights): """ import math -from typing import Optional +from typing import Optional, Dict +from pathlib import Path import mlx.core as mx import mlx.nn as nn from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType -from mlx_video.models.ltx.video_vae.ops import unpatchify +from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling @@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module): self.decode_timestep = 0.05 # Per-channel statistics for denormalization (loaded from weights) - self.latents_mean = mx.zeros((in_channels,)) - self.latents_std = mx.ones((in_channels,)) + self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) # Initial conv: 128 -> 1024 class ConvInWrapper(nn.Module): @@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module): ) self.last_scale_shift_table = mx.zeros((2, 128)) - def denormalize(self, x: mx.array) -> mx.array: - """Denormalize latents using per-channel statistics.""" - dtype = x.dtype - # Cast to float32 for precision (statistics may be in bfloat16) - mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) - std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1) - return (x * std + mean).astype(dtype) + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + # Build decoder weights dict with key remapping + sanitized = {} + for key, value in weights.items(): + new_key = key + + if not key.startswith("vae.") or key.startswith("vae.encoder."): + continue + + if key.startswith("vae.per_channel_statistics."): + # Map per-channel statistics (use exact key matching) + if key == "vae.per_channel_statistics.mean-of-means": + new_key = "per_channel_statistics.mean" + elif key == "vae.per_channel_statistics.std-of-means": + new_key = "per_channel_statistics.std" + else: + continue # Skip other statistics keys + + if key.startswith("vae.decoder."): + new_key = key.replace("vae.decoder.", "") + + + # Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I) + if ".conv.weight" in key and value.ndim == 5: + value = mx.transpose(value, (0, 2, 3, 4, 1)) + + if ".conv.bias" in key: + pass # bias doesn't need transpose + + if ".conv.weight" in new_key or ".conv.bias" in new_key: + + if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key: + new_key = new_key.replace(".conv.weight", ".conv.conv.weight") + new_key = new_key.replace(".conv.bias", ".conv.conv.bias") + + sanitized[new_key] = value + return sanitized + + @classmethod + def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder": + from safetensors import safe_open + import json + weights = mx.load(str(model_path)) + + # Read config from safetensors metadata to auto-detect timestep_conditioning + if timestep_conditioning is None: + try: + with safe_open(str(model_path), framework="numpy") as f: + metadata = f.metadata() + if metadata and "config" in metadata: + configs = json.loads(metadata["config"]) + vae_config = configs.get("vae", {}) + timestep_conditioning = vae_config.get("timestep_conditioning", False) + print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights") + else: + timestep_conditioning = False + except Exception as e: + print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False") + timestep_conditioning = False + + model = cls(timestep_conditioning=timestep_conditioning) + weights = model.sanitize(weights) + model.load_weights(list(weights.items()), strict=strict) + return model + + def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: """Apply pixel normalization.""" @@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module): chunked_conv: bool = False, ) -> mx.array: - def debug_stats(name, t): - if debug: - mx.eval(t) - print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}") batch_size = sample.shape[0] - if debug: - debug_stats("Input", sample) + # Add noise if timestep conditioning is enabled if self.timestep_conditioning: noise = mx.random.normal(sample.shape) * self.decode_noise_scale sample = noise + (1.0 - self.decode_noise_scale) * sample - if debug: - debug_stats("After noise", sample) + - if debug: - print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]") - sample = self.denormalize(sample) - if debug: - debug_stats("After denormalize", sample) + sample = self.per_channel_statistics.un_normalize(sample) + if timestep is None and self.timestep_conditioning: timestep = mx.full((batch_size,), self.decode_timestep) @@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module): scaled_timestep = timestep * self.timestep_scale_multiplier x = self.conv_in(sample, causal=causal) - if debug: - debug_stats("After conv_in", x) + for i, block in self.up_blocks.items(): if isinstance(block, ResBlockGroup): @@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module): x = block(x, causal=causal, chunked_conv=chunked_conv) else: x = block(x, causal=causal) - if debug: - block_type = type(block).__name__ - debug_stats(f"After up_blocks[{i}] ({block_type})", x) + x = self.pixel_norm(x) - if debug: - debug_stats("After pixel_norm", x) + if self.timestep_conditioning and scaled_timestep is not None: embedded_timestep = self.last_time_embedder( @@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module): scale = ada_values[:, 1] x = x * (1 + scale) + shift - if debug: - debug_stats("After timestep modulation", x) + x = self.act(x) - if debug: - debug_stats("After activation", x) + x = self.conv_out(x, causal=causal) - if debug: - debug_stats("After conv_out", x) - + # Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4) x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1) - if debug: - debug_stats("After unpatchify", x) + return x @@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module): chunked_conv=use_chunked_conv, on_frames_ready=on_frames_ready, ) - - -def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder: - from pathlib import Path - import json - from safetensors import safe_open - - model_path = Path(model_path) - - # Try to find the weights file - if model_path.is_file() and model_path.suffix == ".safetensors": - weights_path = model_path - elif (model_path / "ltx-2-19b-distilled.safetensors").exists(): - weights_path = model_path / "ltx-2-19b-distilled.safetensors" - elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists(): - weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors" - else: - raise FileNotFoundError(f"VAE weights not found at {model_path}") - - print(f"Loading VAE decoder from {weights_path}...") - - # Read config from safetensors metadata to auto-detect timestep_conditioning - if timestep_conditioning is None: - try: - with safe_open(str(weights_path), framework="numpy") as f: - metadata = f.metadata() - if metadata and "config" in metadata: - configs = json.loads(metadata["config"]) - vae_config = configs.get("vae", {}) - timestep_conditioning = vae_config.get("timestep_conditioning", False) - print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights") - else: - timestep_conditioning = False - except Exception as e: - print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False") - timestep_conditioning = False - - decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning) - - weights = mx.load(str(weights_path)) - - # Determine prefix based on weight keys - has_vae_prefix = any(k.startswith("vae.") for k in weights.keys()) - has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys()) - - if has_vae_prefix: - prefix = "vae.decoder." - stats_prefix = "vae.per_channel_statistics." - elif has_decoder_prefix: - prefix = "decoder." - stats_prefix = "" - else: - prefix = "" - stats_prefix = "" - - # Load per-channel statistics for denormalization - # Note: use std-of-means (not mean-of-stds) for proper denormalization - mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean" - std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std" - - if mean_key in weights: - decoder.latents_mean = weights[mean_key] - print(f" Loaded latent mean: shape {decoder.latents_mean.shape}") - if std_key in weights: - decoder.latents_std = weights[std_key] - print(f" Loaded latent std: shape {decoder.latents_std.shape}") - - # Build decoder weights dict with key remapping - decoder_weights = {} - for key, value in weights.items(): - if not key.startswith(prefix): - continue - - # Remove prefix - new_key = key[len(prefix):] - - # Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I) - if ".conv.weight" in key and value.ndim == 5: - value = mx.transpose(value, (0, 2, 3, 4, 1)) - if ".conv.bias" in key: - pass # bias doesn't need transpose - - - if ".conv.weight" in new_key or ".conv.bias" in new_key: - if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key: - new_key = new_key.replace(".conv.weight", ".conv.conv.weight") - new_key = new_key.replace(".conv.bias", ".conv.conv.bias") - - decoder_weights[new_key] = value - - print(f" Found {len(decoder_weights)} decoder weights") - - ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k] - print(f" Found {len(ts_keys)} timestep conditioning weights") - - # Load weights - decoder.load_weights(list(decoder_weights.items()), strict=False) - - print("VAE decoder loaded successfully") - return decoder diff --git a/mlx_video/models/ltx/video_vae/encoder.py b/mlx_video/models/ltx/video_vae/encoder.py index 6c90a4b..ed4dcc4 100644 --- a/mlx_video/models/ltx/video_vae/encoder.py +++ b/mlx_video/models/ltx/video_vae/encoder.py @@ -5,152 +5,9 @@ Used for I2V (image-to-video) conditioning by encoding the input image to latent space, which can then be used to condition video generation. """ -from pathlib import Path -from typing import List, Tuple, Any, Optional -import json - import mlx.core as mx -import mlx.nn as nn +from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder -from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType - - -def load_vae_encoder(model_path: str) -> VideoEncoder: - """Load VAE encoder from safetensors file. - - Args: - model_path: Path to the model weights (safetensors file or directory) - - Returns: - Loaded VideoEncoder instance - """ - from safetensors import safe_open - - model_path = Path(model_path) - - # Try to find the weights file - if model_path.is_file() and model_path.suffix == ".safetensors": - weights_path = model_path - elif (model_path / "ltx-2-19b-distilled.safetensors").exists(): - weights_path = model_path / "ltx-2-19b-distilled.safetensors" - elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists(): - weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors" - else: - raise FileNotFoundError(f"VAE weights not found at {model_path}") - - print(f"Loading VAE encoder from {weights_path}...") - - # Read config from safetensors metadata - encoder_blocks = [] - norm_layer = NormLayerType.PIXEL_NORM - latent_log_var = LogVarianceType.UNIFORM - patch_size = 4 - - try: - with safe_open(str(weights_path), framework="numpy") as f: - metadata = f.metadata() - if metadata and "config" in metadata: - configs = json.loads(metadata["config"]) - vae_config = configs.get("vae", {}) - - # Parse encoder blocks - raw_blocks = vae_config.get("encoder_blocks", []) - for block in raw_blocks: - if isinstance(block, list) and len(block) == 2: - name, params = block - encoder_blocks.append((name, params)) - - # Parse other config - norm_str = vae_config.get("norm_layer", "pixel_norm") - norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM - - var_str = vae_config.get("latent_log_var", "uniform") - if var_str == "uniform": - latent_log_var = LogVarianceType.UNIFORM - elif var_str == "per_channel": - latent_log_var = LogVarianceType.PER_CHANNEL - elif var_str == "constant": - latent_log_var = LogVarianceType.CONSTANT - else: - latent_log_var = LogVarianceType.NONE - - patch_size = vae_config.get("patch_size", 4) - - print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}") - except Exception as e: - print(f" Could not read config from metadata: {e}") - # Use default config - encoder_blocks = [ - ("res_x", {"num_layers": 4}), - ("compress_space_res", {"multiplier": 2}), - ("res_x", {"num_layers": 6}), - ("compress_time_res", {"multiplier": 2}), - ("res_x", {"num_layers": 6}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 2}), - ("compress_all_res", {"multiplier": 2}), - ("res_x", {"num_layers": 2}), - ] - print(f" Using default encoder config with {len(encoder_blocks)} blocks") - - # Create encoder - encoder = VideoEncoder( - convolution_dimensions=3, - in_channels=3, - out_channels=128, - encoder_blocks=encoder_blocks, - patch_size=patch_size, - norm_layer=norm_layer, - latent_log_var=latent_log_var, - encoder_spatial_padding_mode=PaddingModeType.ZEROS, - ) - - # Load weights - weights = mx.load(str(weights_path)) - - # Determine prefix based on weight keys - has_vae_prefix = any(k.startswith("vae.") for k in weights.keys()) - - if has_vae_prefix: - prefix = "vae.encoder." - stats_prefix = "vae.per_channel_statistics." - else: - prefix = "encoder." - stats_prefix = "per_channel_statistics." - - # Load per-channel statistics for normalization - mean_key = f"{stats_prefix}mean-of-means" - std_key = f"{stats_prefix}std-of-means" - - if mean_key in weights: - encoder.per_channel_statistics.mean = weights[mean_key] - print(f" Loaded latent mean: shape {weights[mean_key].shape}") - if std_key in weights: - encoder.per_channel_statistics.std = weights[std_key] - print(f" Loaded latent std: shape {weights[std_key].shape}") - - # Build encoder weights dict with key remapping - encoder_weights = {} - for key, value in weights.items(): - if not key.startswith(prefix): - continue - - # Remove prefix - new_key = key[len(prefix):] - - # Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I) - if ".weight" in key and value.ndim == 5: - value = mx.transpose(value, (0, 2, 3, 4, 1)) - - encoder_weights[new_key] = value - - print(f" Found {len(encoder_weights)} encoder weights") - - # Load weights - encoder.load_weights(list(encoder_weights.items()), strict=False) - - print("VAE encoder loaded successfully") - return encoder def encode_image( diff --git a/mlx_video/models/ltx/video_vae/video_vae.py b/mlx_video/models/ltx/video_vae/video_vae.py index cc3ec3a..af4349e 100644 --- a/mlx_video/models/ltx/video_vae/video_vae.py +++ b/mlx_video/models/ltx/video_vae/video_vae.py @@ -273,9 +273,10 @@ class VideoEncoder(nn.Module): spatial_padding_mode=encoder_spatial_padding_mode, ) - # Build encoder blocks - use dict with int keys for MLX parameter tracking + # Build encoder blocks + # Use dict with int keys for MLX to track parameters (lists are NOT tracked) self.down_blocks = {} - for i, (block_name, block_params) in enumerate(encoder_blocks): + for idx, (block_name, block_params) in enumerate(encoder_blocks): block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params block, feature_channels = _make_encoder_block( @@ -287,7 +288,7 @@ class VideoEncoder(nn.Module): norm_num_groups=self._norm_num_groups, spatial_padding_mode=encoder_spatial_padding_mode, ) - self.down_blocks[i] = block + self.down_blocks[idx] = block # Output normalization and convolution if norm_layer == NormLayerType.GROUP_NORM: @@ -341,7 +342,8 @@ class VideoEncoder(nn.Module): sample = self.conv_in(sample, causal=True) # Process through encoder blocks - for down_block in self.down_blocks.values(): + for i in range(len(self.down_blocks)): + down_block = self.down_blocks[i] if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)): sample = down_block(sample, causal=True) else: @@ -440,8 +442,9 @@ class VideoDecoder(nn.Module): ) # Build decoder blocks (reversed order) - self.up_blocks = [] - for block_name, block_params in list(reversed(decoder_blocks)): + # Use dict with int keys for MLX to track parameters (lists are NOT tracked) + self.up_blocks = {} + for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)): block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params block, feature_channels = _make_decoder_block( @@ -454,7 +457,7 @@ class VideoDecoder(nn.Module): norm_num_groups=self._norm_num_groups, spatial_padding_mode=decoder_spatial_padding_mode, ) - self.up_blocks.append(block) + self.up_blocks[idx] = block # Output normalization if norm_layer == NormLayerType.GROUP_NORM: @@ -509,7 +512,8 @@ class VideoDecoder(nn.Module): sample = self.conv_in(sample, causal=self.causal) # Process through decoder blocks - for up_block in self.up_blocks: + for i in range(len(self.up_blocks)): + up_block = self.up_blocks[i] if isinstance(up_block, UNetMidBlock3D): sample = up_block(sample, causal=self.causal) elif isinstance(up_block, ResnetBlock3D):