Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -1,11 +1,59 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights
import os
from mlx_video.convert import (
load_transformer_weights,
load_vae_weights,
load_audio_vae_weights,
load_vocoder_weights,
sanitize_audio_vae_weights,
sanitize_vocoder_weights,
)
# Audio VAE components
from mlx_video.models.ltx.audio_vae import (
AudioEncoder,
AudioDecoder,
Vocoder,
AudioProcessor,
decode_audio,
)
# Patchifiers
from mlx_video.components.patchifiers import (
VideoLatentPatchifier,
AudioPatchifier,
VideoLatentShape,
AudioLatentShape,
)
# Conditioning
from mlx_video.conditioning import (
VideoConditionByKeyframeIndex,
VideoConditionByLatentIndex,
)
__all__ = [
# Models
"LTXModel",
"LTXModelConfig",
# Weight loading
"load_transformer_weights",
"load_vae_weights",
]
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
"load_audio_vae_weights",
"load_vocoder_weights",
"sanitize_audio_vae_weights",
"sanitize_vocoder_weights",
# Audio VAE
"AudioEncoder",
"AudioDecoder",
"Vocoder",
"AudioProcessor",
"decode_audio",
# Patchifiers
"VideoLatentPatchifier",
"AudioPatchifier",
"VideoLatentShape",
"AudioLatentShape",
# Conditioning
"VideoConditionByKeyframeIndex",
"VideoConditionByLatentIndex",
]

View File

@@ -355,6 +355,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr
"""
sanitized = {}
if "audio_vae." in weights:
return weights
for key, value in weights.items():
new_key = key
@@ -364,9 +367,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr
elif key.startswith("audio_vae.per_channel_statistics."):
# Map per-channel statistics
if "mean-of-means" in key:
new_key = "per_channel_statistics._mean_of_means"
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics._std_of_means"
new_key = "per_channel_statistics.std_of_means"
else:
continue # Skip other statistics keys
else:

View File

@@ -3,7 +3,7 @@
from .attention import AttentionType, AttnBlock, make_attn
from .audio_vae import AudioDecoder, decode_audio
from .causal_conv_2d import CausalConv2d, make_conv2d
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
from .downsample import Downsample, build_downsampling_path
from .normalization import NormType, PixelNorm, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics

View File

@@ -1,14 +1,15 @@
"""Audio VAE encoder and decoder for LTX-2."""
from typing import Set, Tuple
from typing import Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import AudioDecoderModelConfig
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from .downsample import build_downsampling_path
from ..config import CausalityAxis
from .normalization import NormType, build_normalization_layer
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
from .resnet import ResnetBlock
@@ -67,22 +68,7 @@ class AudioDecoder(nn.Module):
def __init__(
self,
*,
ch: int = 128,
out_ch: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Set[int] = None,
resolution: int = 256,
z_channels: int = 8,
norm_type: NormType = NormType.PIXEL,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
dropout: float = 0.0,
mid_block_add_attention: bool = True,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: int | None = None,
config: AudioDecoderModelConfig,
) -> None:
"""
Initialize the AudioDecoder.
@@ -105,86 +91,132 @@ class AudioDecoder(nn.Module):
"""
super().__init__()
if attn_resolutions is None:
attn_resolutions = {8, 16, 32}
# Internal behavioral defaults
resamp_with_conv = True
attn_type = AttentionType.VANILLA
# Per-channel statistics for denormalizing latents
# Uses ch (base channel count) to match the patchified latent dimension
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
# ch=128 matches this dimension, so use ch for per_channel_statistics
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
self.sample_rate = config.sample_rate
self.mel_hop_length = config.mel_hop_length
self.is_causal = config.is_causal
self.mel_bins = config.mel_bins
self.patchifier = AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate,
hop_length=mel_hop_length,
is_causal=is_causal,
sample_rate=config.sample_rate,
hop_length=config.mel_hop_length,
is_causal=config.is_causal,
)
self.ch = ch
self.ch = config.ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.out_ch = out_ch
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.z_channels = z_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
self.attn_type = attn_type
self.num_resolutions = len(config.ch_mult)
self.num_res_blocks = config.num_res_blocks
self.resolution = config.resolution
self.out_ch = config.out_ch
self.give_pre_end = config.give_pre_end
self.tanh_out = config.tanh_out
self.norm_type = config.norm_type
self.z_channels = config.z_channels
self.channel_multipliers = config.ch_mult
self.attn_resolutions = config.attn_resolutions
self.causality_axis = config.causality_axis
self.attn_type = config.attn_type
base_block_channels = ch * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, z_channels, base_resolution, base_resolution)
base_block_channels = config.ch * self.channel_multipliers[-1]
base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d(
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
self.mid = build_mid_block(
channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
add_attention=mid_block_add_attention,
add_attention=config.mid_block_add_attention,
)
self.up, final_block_channels = build_upsampling_path(
ch=ch,
ch_mult=ch_mult,
ch=config.ch,
ch_mult=config.ch_mult,
num_resolutions=self.num_resolutions,
num_res_blocks=num_res_blocks,
resolution=resolution,
num_res_blocks=config.num_res_blocks,
resolution=config.resolution,
temb_channels=self.temb_ch,
dropout=dropout,
dropout=config.dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
attn_type=self.attn_type,
attn_resolutions=attn_resolutions,
resamp_with_conv=resamp_with_conv,
attn_resolutions=config.attn_resolutions,
resamp_with_conv=config.resamp_with_conv,
initial_block_channels=base_block_channels,
)
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
self.conv_out = make_conv2d(
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
Args:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for audio VAE decoder
"""
sanitized = {}
for key, value in weights.items():
new_key = key
# Handle audio_vae.decoder weights
if key.startswith("audio_vae.decoder."):
new_key = key.replace("audio_vae.decoder.", "")
elif key.startswith("audio_vae.per_channel_statistics."):
# Map per-channel statistics
if "mean-of-means" in key:
new_key = "per_channel_statistics.mean_of_means"
elif "std-of-means" in key:
new_key = "per_channel_statistics.std_of_means"
else:
continue # Skip other statistics keys
else:
continue # Skip non-decoder keys
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
"""Load audio VAE decoder from pretrained model."""
from mlx_video.models.ltx.config import AudioDecoderModelConfig
import json
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
decoder = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
# weights = decoder.sanitize(weights)
decoder.load_weights(list(weights.items()), strict=True)
return decoder
def __call__(self, sample: mx.array) -> mx.array:
"""
Decode latent features back to audio spectrograms.

View File

@@ -5,7 +5,7 @@ from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:

View File

@@ -1,12 +0,0 @@
"""Causality axis enum for specifying causal convolution dimensions."""
from enum import Enum
class CausalityAxis(Enum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"

View File

@@ -6,7 +6,7 @@ import mlx.core as mx
import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock

View File

@@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module):
self.latent_channels = latent_channels
# Initialize buffers - will be loaded from weights
# Using underscores for MLX compatibility with weight loading
self._std_of_means = mx.ones((latent_channels,))
self._mean_of_means = mx.zeros((latent_channels,))
self.std_of_means = mx.ones((latent_channels,))
self.mean_of_means = mx.zeros((latent_channels,))
def un_normalize(self, x: mx.array) -> mx.array:
"""Denormalize latent representation."""
# Broadcast statistics to match x shape
# x shape: (B, C, ...) or (B, ..., C)
std = self._std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype)
std = self.std_of_means.astype(x.dtype)
mean = self.mean_of_means.astype(x.dtype)
return (x * std) + mean
def normalize(self, x: mx.array) -> mx.array:
"""Normalize latent representation."""
std = self._std_of_means.astype(x.dtype)
mean = self._mean_of_means.astype(x.dtype)
std = self.std_of_means.astype(x.dtype)
mean = self.mean_of_means.astype(x.dtype)
return (x - mean) / std

View File

@@ -6,7 +6,7 @@ import mlx.core as mx
import mlx.nn as nn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
from .normalization import NormType, build_normalization_layer
LRELU_SLOPE = 0.1

View File

@@ -7,7 +7,7 @@ import mlx.nn as nn
from .attention import AttentionType, make_attn
from .causal_conv_2d import make_conv2d
from .causality_axis import CausalityAxis
from ..config import CausalityAxis
from .normalization import NormType
from .resnet import ResnetBlock

View File

@@ -1,11 +1,12 @@
"""Vocoder for converting mel spectrograms to audio waveforms."""
import math
from typing import List
from typing import Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm.models.base import check_array_shape
from ..config import VocoderModelConfig
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
@@ -27,44 +28,29 @@ class Vocoder(nn.Module):
def __init__(
self,
resblock_kernel_sizes: List[int] | None = None,
upsample_rates: List[int] | None = None,
upsample_kernel_sizes: List[int] | None = None,
resblock_dilation_sizes: List[List[int]] | None = None,
upsample_initial_channel: int = 1024,
stereo: bool = True,
resblock: str = "1",
output_sample_rate: int = 24000,
config: VocoderModelConfig
):
super().__init__()
# Initialize default values if not provided
if resblock_kernel_sizes is None:
resblock_kernel_sizes = [3, 7, 11]
if upsample_rates is None:
upsample_rates = [6, 5, 2, 2, 2]
if upsample_kernel_sizes is None:
upsample_kernel_sizes = [16, 15, 8, 4, 4]
if resblock_dilation_sizes is None:
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
self.output_sample_rate = output_sample_rate
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.upsample_rates = upsample_rates
self.upsample_kernel_sizes = upsample_kernel_sizes
self.upsample_initial_channel = upsample_initial_channel
self.output_sample_rate = config.output_sample_rate
self.num_kernels = len(config.resblock_kernel_sizes)
self.num_upsamples = len(config.upsample_rates)
self.upsample_rates = config.upsample_rates
self.upsample_kernel_sizes = config.upsample_kernel_sizes
self.upsample_initial_channel = config.upsample_initial_channel
in_channels = 128 if stereo else 64
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
in_channels = 128 if config.stereo else 64
self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
# Upsampling layers using ConvTranspose1d
self.ups = {}
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
in_ch = upsample_initial_channel // (2**i)
out_ch = upsample_initial_channel // (2 ** (i + 1))
for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
in_ch = config.upsample_initial_channel // (2**i)
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
self.ups[i] = nn.ConvTranspose1d(
in_ch,
out_ch,
@@ -77,16 +63,67 @@ class Vocoder(nn.Module):
self.resblocks = {}
block_idx = 0
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
ch = config.upsample_initial_channel // (2 ** (i + 1))
for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
block_idx += 1
out_channels = 2 if stereo else 1
final_channels = upsample_initial_channel // (2**self.num_upsamples)
out_channels = 2 if config.stereo else 1
final_channels = config.upsample_initial_channel // (2**self.num_upsamples)
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
self.upsample_factor = math.prod(upsample_rates)
self.upsample_factor = math.prod(config.upsample_rates)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
if "vocoder." not in weights:
return weights
for key, value in weights.items():
new_key = key
# Handle vocoder weights
if key.startswith("vocoder."):
new_key = key.replace("vocoder.", "")
# Handle ModuleList indices -> dict keys
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
# Handle Conv1d weight shape conversion
# PyTorch: (out_channels, in_channels, kernel)
# MLX: (out_channels, kernel, in_channels)
if "weight" in new_key and value.ndim == 3:
if "ups" in new_key:
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0))
else:
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1))
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder":
"""Load vocoder from pretrained model."""
from mlx_video.models.ltx.config import VocoderModelConfig
import json
config_dict = {}
with open(model_path / "config.json", "r") as f:
config_dict = json.load(f)
config = VocoderModelConfig.from_dict(config_dict)
model = cls(config)
weights = mx.load(str(model_path / "model.safetensors"))
# weights = vocoder.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
def __call__(self, x: mx.array) -> mx.array:
"""

View File

@@ -2,7 +2,7 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple, Set
class LTXModelType(Enum):
@@ -180,3 +180,141 @@ class LTXModelConfig(BaseModelConfig):
d_head=self.audio_attention_head_dim,
context_dim=self.audio_cross_attention_dim,
)
class CausalityAxis(Enum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"
@dataclass
class AudioDecoderModelConfig(BaseModelConfig):
ch: int = 128
out_ch: int = 2
ch_mult: Tuple[int, ...] = (1, 2, 4)
num_res_blocks: int = 2
attn_resolutions: Optional[List[int]] = None
resolution: int = 256
z_channels: int = 8
norm_type: Enum = None
causality_axis: Enum = None
dropout: float = 0.0
mid_block_add_attention: bool = True
sample_rate: int = 16000
mel_hop_length: int = 160
is_causal: bool = True
mel_bins: int | None = None
resamp_with_conv: bool = True
attn_type: str = None
give_pre_end: bool = False
tanh_out: bool = False
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if self.attn_resolutions is not None:
result["attn_resolutions"] = list(self.attn_resolutions)
return result
def __post_init__(self):
"""Convert string enum values to proper enum types."""
# Import here to avoid circular imports
from .audio_vae.normalization import NormType
from .audio_vae.attention import AttentionType
# Convert causality_axis string to enum
if isinstance(self.causality_axis, str):
self.causality_axis = CausalityAxis(self.causality_axis)
# Convert norm_type string to enum
if isinstance(self.norm_type, str):
self.norm_type = NormType(self.norm_type)
# Convert attn_type string to enum
if isinstance(self.attn_type, str):
self.attn_type = AttentionType(self.attn_type)
@dataclass
class VocoderModelConfig(BaseModelConfig):
resblock_kernel_sizes: Optional[List[int]] = None
upsample_rates: Optional[List[int]] = None
upsample_kernel_sizes: Optional[List[int]] = None
resblock_dilation_sizes: Optional[List[List[int]]] = None
upsample_initial_channel: int = 1024
stereo: bool = True
resblock: str = "1"
output_sample_rate: int = 24000
def __post_init__(self):
if self.resblock_kernel_sizes is None:
self.resblock_kernel_sizes = [3, 7, 11]
if self.upsample_rates is None:
self.upsample_rates = [6, 5, 2, 2, 2]
if self.upsample_kernel_sizes is None:
self.upsample_kernel_sizes = [16, 15, 8, 4, 4]
if self.resblock_dilation_sizes is None:
self.resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
@dataclass
class VideoDecoderModelConfig(BaseModelConfig):
ch: int = 128
out_ch: int = 2
ch_mult: Tuple[int, ...] = (1, 2, 4)
num_res_blocks: int = 2
attn_resolutions: Optional[List[int]] = None
resolution: int = 256
z_channels: int = 8
norm_type: Enum = None
causality_axis: Enum = None
dropout: float = 0.0
@dataclass
class VideoEncoderModelConfig(BaseModelConfig):
convolution_dimensions: int = 3
in_channels : int = 3,
out_channels: int = 128,
patch_size: int = 4,
norm_layer: Enum = None,
latent_log_var: Enum = None,
encoder_spatial_padding_mode: Enum = None,
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2})
])
def __post_init__(self):
from mlx_video.models.ltx.video_vae.resnet import NormLayerType
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType
if self.norm_layer is None:
self.norm_layer = NormLayerType.PIXEL_NORM
if self.latent_log_var is None:
self.latent_log_var = LogVarianceType.UNIFORM
if self.encoder_spatial_padding_mode is None:
self.encoder_spatial_padding_mode = PaddingModeType.ZEROS
if isinstance(self.norm_layer, str):
self.norm_layer = NormLayerType(self.norm_layer)
if isinstance(self.latent_log_var, str):
self.latent_log_var = LogVarianceType(self.latent_log_var)
if isinstance(self.encoder_spatial_padding_mode, str):
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
def to_dict(self) -> dict[str, Any]:
result = super().to_dict()
if self.encoder_blocks is not None:
result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
return result

View File

@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
from mlx_video.models.ltx.config import (
LTXModelConfig,
LTXModelType,
@@ -52,11 +52,10 @@ class TransformerArgsPreprocessor:
self,
timestep: mx.array,
batch_size: int,
hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * self.timestep_scale_multiplier
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
# Reshape to (batch, tokens, dim)
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
@@ -71,9 +70,6 @@ class TransformerArgsPreprocessor:
attention_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, Optional[mx.array]]:
batch_size = x.shape[0]
# Context is already processed through embeddings connector in text encoder
# Here we just apply the caption projection
context = self.caption_projection(context)
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
return context, attention_mask
@@ -118,21 +114,16 @@ class TransformerArgsPreprocessor:
def prepare(self, modality: Modality) -> TransformerArgs:
x = self.patchify_proj(modality.latent)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
if modality.positional_embeddings is not None:
pe = modality.positional_embeddings
else:
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return TransformerArgs(
x=x,
@@ -207,7 +198,6 @@ class MultiModalTransformerArgsPreprocessor:
timestep=modality.timesteps,
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
batch_size=transformer_args.x.shape[0],
hidden_dtype=transformer_args.x.dtype,
)
return replace(
@@ -222,16 +212,15 @@ class MultiModalTransformerArgsPreprocessor:
timestep: mx.array,
timestep_scale_multiplier: int,
batch_size: int,
hidden_dtype: mx.Dtype = None,
) -> Tuple[mx.array, mx.array]:
timestep = timestep * timestep_scale_multiplier
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
return scale_shift_timestep, gate_timestep
@@ -293,8 +282,6 @@ class LTXModel(nn.Module):
def _init_audio(self, config: LTXModelConfig) -> None:
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=config.audio_caption_channels,
hidden_size=self.audio_inner_dim,
@@ -397,9 +384,8 @@ class LTXModel(nn.Module):
video_config = config.get_video_config()
audio_config = config.get_audio_config()
self.transformer_blocks = {
idx: BasicAVTransformerBlock(
self.transformer_blocks = [
BasicAVTransformerBlock(
idx=idx,
video=video_config,
audio=audio_config,
@@ -407,7 +393,7 @@ class LTXModel(nn.Module):
norm_eps=config.norm_eps,
)
for idx in range(config.num_layers)
}
]
def _process_transformer_blocks(
self,
@@ -415,7 +401,7 @@ class LTXModel(nn.Module):
audio: Optional[TransformerArgs],
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks."""
for block in self.transformer_blocks.values():
for block in self.transformer_blocks:
video, audio = block(video=video, audio=audio)
return video, audio
@@ -497,50 +483,19 @@ class LTXModel(nn.Module):
def sanitize(self, weights: dict) -> dict:
sanitized = {}
for key, value in weights.items():
new_key = key
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
continue
# Remove 'model.diffusion_model.' prefix
new_key = new_key.replace("model.diffusion_model.", "")
new_key = new_key.replace(".to_out.0.", ".to_out.")
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
new_key = new_key.replace(".linear_1.", ".linear1.")
new_key = new_key.replace(".linear_2.", ".linear2.")
# Handle common remappings
# transformer_blocks.X -> transformer_blocks[X]
if "transformer_blocks." in new_key:
# Keep as-is for now, MLX handles this
pass
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: [Path, List[Path]], config: LTXModelConfig, strict: bool = True) -> None:
model = cls(config)
weights = {}
if isinstance(model_path, Path):
model_path = [model_path]
for weight_file in model_path:
weights.update(mx.load(str(weight_file)))
sanitized = model.sanitize(weights)
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
model.load_weights(list(sanitized.items()), strict=strict)
mx.eval(model.parameters())
model.eval()
return model
class X0Model(nn.Module):

View File

@@ -428,11 +428,14 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies with higher precision using float32.
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
Uses float32 for computation precision (sufficient for RoPE).
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
computation (log-spaced values), then converts to float32 for the final tensor.
This provides better numerical precision in the frequency generation phase.
"""
import numpy as np
# Warn if positions are bfloat16 - this causes quality degradation
if indices_grid.dtype == mx.bfloat16:
import warnings
@@ -443,21 +446,27 @@ def _precompute_freqs_cis_double_precision(
stacklevel=2
)
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
# Cast to float32 for position computation
indices_grid_f32 = indices_grid.astype(mx.float32)
n_pos_dims = indices_grid_f32.shape[1]
n_elem = 2 * n_pos_dims
# Compute log-spaced frequencies in float32
log_start = math.log(1.0) / math.log(theta)
log_end = math.log(theta) / math.log(theta)
# Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
# This is the critical precision step - PyTorch uses np.float64 here
log_start = np.log(1.0) / np.log(theta)
log_end = np.log(theta) / np.log(theta) # = 1.0
num_indices = dim // n_elem
if num_indices == 0:
num_indices = 1
lin_space = mx.linspace(log_start, log_end, num_indices)
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
# Use numpy float64 for the linspace computation (matches PyTorch)
pow_indices = np.power(
theta,
np.linspace(log_start, log_end, num_indices, dtype=np.float64),
)
# Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
# Handle middle indices grid
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise

View File

@@ -1,6 +1,6 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import encode_image
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder
from mlx_video.models.ltx.video_vae.tiling import (
TilingConfig,
SpatialTilingConfig,

View File

@@ -15,13 +15,14 @@ Architecture (from PyTorch weights):
"""
import math
from typing import Optional
from typing import Optional, Dict
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
from mlx_video.models.ltx.video_vae.ops import unpatchify
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
@@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module):
self.decode_timestep = 0.05
# Per-channel statistics for denormalization (loaded from weights)
self.latents_mean = mx.zeros((in_channels,))
self.latents_std = mx.ones((in_channels,))
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
# Initial conv: 128 -> 1024
class ConvInWrapper(nn.Module):
@@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module):
)
self.last_scale_shift_table = mx.zeros((2, 128))
def denormalize(self, x: mx.array) -> mx.array:
"""Denormalize latents using per-channel statistics."""
dtype = x.dtype
# Cast to float32 for precision (statistics may be in bfloat16)
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
return (x * std + mean).astype(dtype)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping
sanitized = {}
for key, value in weights.items():
new_key = key
if not key.startswith("vae.") or key.startswith("vae.encoder."):
continue
if key.startswith("vae.per_channel_statistics."):
# Map per-channel statistics (use exact key matching)
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
continue # Skip other statistics keys
if key.startswith("vae.decoder."):
new_key = key.replace("vae.decoder.", "")
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
sanitized[new_key] = value
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
from safetensors import safe_open
import json
weights = mx.load(str(model_path))
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(model_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
model = cls(timestep_conditioning=timestep_conditioning)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
"""Apply pixel normalization."""
@@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv: bool = False,
) -> mx.array:
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
batch_size = sample.shape[0]
if debug:
debug_stats("Input", sample)
# Add noise if timestep conditioning is enabled
if self.timestep_conditioning:
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
sample = noise + (1.0 - self.decode_noise_scale) * sample
if debug:
debug_stats("After noise", sample)
if debug:
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
sample = self.denormalize(sample)
if debug:
debug_stats("After denormalize", sample)
sample = self.per_channel_statistics.un_normalize(sample)
if timestep is None and self.timestep_conditioning:
timestep = mx.full((batch_size,), self.decode_timestep)
@@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module):
scaled_timestep = timestep * self.timestep_scale_multiplier
x = self.conv_in(sample, causal=causal)
if debug:
debug_stats("After conv_in", x)
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
@@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module):
x = block(x, causal=causal, chunked_conv=chunked_conv)
else:
x = block(x, causal=causal)
if debug:
block_type = type(block).__name__
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
x = self.pixel_norm(x)
if debug:
debug_stats("After pixel_norm", x)
if self.timestep_conditioning and scaled_timestep is not None:
embedded_timestep = self.last_time_embedder(
@@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module):
scale = ada_values[:, 1]
x = x * (1 + scale) + shift
if debug:
debug_stats("After timestep modulation", x)
x = self.act(x)
if debug:
debug_stats("After activation", x)
x = self.conv_out(x, causal=causal)
if debug:
debug_stats("After conv_out", x)
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
if debug:
debug_stats("After unpatchify", x)
return x
@@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module):
chunked_conv=use_chunked_conv,
on_frames_ready=on_frames_ready,
)
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path
import json
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...")
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.decoder."
stats_prefix = "vae.per_channel_statistics."
elif has_decoder_prefix:
prefix = "decoder."
stats_prefix = ""
else:
prefix = ""
stats_prefix = ""
# Load per-channel statistics for denormalization
# Note: use std-of-means (not mean-of-stds) for proper denormalization
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
if mean_key in weights:
decoder.latents_mean = weights[mean_key]
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
if std_key in weights:
decoder.latents_std = weights[std_key]
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
# Build decoder weights dict with key remapping
decoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".conv.weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
if ".conv.bias" in key:
pass # bias doesn't need transpose
if ".conv.weight" in new_key or ".conv.bias" in new_key:
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
decoder_weights[new_key] = value
print(f" Found {len(decoder_weights)} decoder weights")
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
print(f" Found {len(ts_keys)} timestep conditioning weights")
# Load weights
decoder.load_weights(list(decoder_weights.items()), strict=False)
print("VAE decoder loaded successfully")
return decoder

View File

@@ -5,152 +5,9 @@ Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation.
"""
from pathlib import Path
from typing import List, Tuple, Any, Optional
import json
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
def load_vae_encoder(model_path: str) -> VideoEncoder:
"""Load VAE encoder from safetensors file.
Args:
model_path: Path to the model weights (safetensors file or directory)
Returns:
Loaded VideoEncoder instance
"""
from safetensors import safe_open
model_path = Path(model_path)
# Try to find the weights file
if model_path.is_file() and model_path.suffix == ".safetensors":
weights_path = model_path
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
else:
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE encoder from {weights_path}...")
# Read config from safetensors metadata
encoder_blocks = []
norm_layer = NormLayerType.PIXEL_NORM
latent_log_var = LogVarianceType.UNIFORM
patch_size = 4
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
# Parse encoder blocks
raw_blocks = vae_config.get("encoder_blocks", [])
for block in raw_blocks:
if isinstance(block, list) and len(block) == 2:
name, params = block
encoder_blocks.append((name, params))
# Parse other config
norm_str = vae_config.get("norm_layer", "pixel_norm")
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
var_str = vae_config.get("latent_log_var", "uniform")
if var_str == "uniform":
latent_log_var = LogVarianceType.UNIFORM
elif var_str == "per_channel":
latent_log_var = LogVarianceType.PER_CHANNEL
elif var_str == "constant":
latent_log_var = LogVarianceType.CONSTANT
else:
latent_log_var = LogVarianceType.NONE
patch_size = vae_config.get("patch_size", 4)
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
except Exception as e:
print(f" Could not read config from metadata: {e}")
# Use default config
encoder_blocks = [
("res_x", {"num_layers": 4}),
("compress_space_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_time_res", {"multiplier": 2}),
("res_x", {"num_layers": 6}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
("compress_all_res", {"multiplier": 2}),
("res_x", {"num_layers": 2}),
]
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
# Create encoder
encoder = VideoEncoder(
convolution_dimensions=3,
in_channels=3,
out_channels=128,
encoder_blocks=encoder_blocks,
patch_size=patch_size,
norm_layer=norm_layer,
latent_log_var=latent_log_var,
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
)
# Load weights
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
if has_vae_prefix:
prefix = "vae.encoder."
stats_prefix = "vae.per_channel_statistics."
else:
prefix = "encoder."
stats_prefix = "per_channel_statistics."
# Load per-channel statistics for normalization
mean_key = f"{stats_prefix}mean-of-means"
std_key = f"{stats_prefix}std-of-means"
if mean_key in weights:
encoder.per_channel_statistics.mean = weights[mean_key]
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
if std_key in weights:
encoder.per_channel_statistics.std = weights[std_key]
print(f" Loaded latent std: shape {weights[std_key].shape}")
# Build encoder weights dict with key remapping
encoder_weights = {}
for key, value in weights.items():
if not key.startswith(prefix):
continue
# Remove prefix
new_key = key[len(prefix):]
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
if ".weight" in key and value.ndim == 5:
value = mx.transpose(value, (0, 2, 3, 4, 1))
encoder_weights[new_key] = value
print(f" Found {len(encoder_weights)} encoder weights")
# Load weights
encoder.load_weights(list(encoder_weights.items()), strict=False)
print("VAE encoder loaded successfully")
return encoder
def encode_image(

View File

@@ -273,9 +273,10 @@ class VideoEncoder(nn.Module):
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks - use dict with int keys for MLX parameter tracking
# Build encoder blocks
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.down_blocks = {}
for i, (block_name, block_params) in enumerate(encoder_blocks):
for idx, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block(
@@ -287,7 +288,7 @@ class VideoEncoder(nn.Module):
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks[i] = block
self.down_blocks[idx] = block
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
@@ -341,7 +342,8 @@ class VideoEncoder(nn.Module):
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for down_block in self.down_blocks.values():
for i in range(len(self.down_blocks)):
down_block = self.down_blocks[i]
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else:
@@ -440,8 +442,9 @@ class VideoDecoder(nn.Module):
)
# Build decoder blocks (reversed order)
self.up_blocks = []
for block_name, block_params in list(reversed(decoder_blocks)):
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
self.up_blocks = {}
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_decoder_block(
@@ -454,7 +457,7 @@ class VideoDecoder(nn.Module):
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=decoder_spatial_padding_mode,
)
self.up_blocks.append(block)
self.up_blocks[idx] = block
# Output normalization
if norm_layer == NormLayerType.GROUP_NORM:
@@ -509,7 +512,8 @@ class VideoDecoder(nn.Module):
sample = self.conv_in(sample, causal=self.causal)
# Process through decoder blocks
for up_block in self.up_blocks:
for i in range(len(self.up_blocks)):
up_block = self.up_blocks[i]
if isinstance(up_block, UNetMidBlock3D):
sample = up_block(sample, causal=self.causal)
elif isinstance(up_block, ResnetBlock3D):