Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
from .attention import AttentionType, AttnBlock, make_attn
|
||||
from .audio_vae import AudioDecoder, decode_audio
|
||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
from .downsample import Downsample, build_downsampling_path
|
||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
"""Audio VAE encoder and decoder for LTX-2."""
|
||||
|
||||
from typing import Set, Tuple
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import AudioDecoderModelConfig
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .causality_axis import CausalityAxis
|
||||
from .downsample import build_downsampling_path
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||
from .resnet import ResnetBlock
|
||||
@@ -67,22 +68,7 @@ class AudioDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch: int = 128,
|
||||
out_ch: int = 2,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Set[int] = None,
|
||||
resolution: int = 256,
|
||||
z_channels: int = 8,
|
||||
norm_type: NormType = NormType.PIXEL,
|
||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = True,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: int | None = None,
|
||||
config: AudioDecoderModelConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the AudioDecoder.
|
||||
@@ -105,86 +91,132 @@ class AudioDecoder(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if attn_resolutions is None:
|
||||
attn_resolutions = {8, 16, 32}
|
||||
|
||||
# Internal behavioral defaults
|
||||
resamp_with_conv = True
|
||||
attn_type = AttentionType.VANILLA
|
||||
|
||||
# Per-channel statistics for denormalizing latents
|
||||
# Uses ch (base channel count) to match the patchified latent dimension
|
||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
|
||||
# ch=128 matches this dimension, so use ch for per_channel_statistics
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
|
||||
self.sample_rate = config.sample_rate
|
||||
self.mel_hop_length = config.mel_hop_length
|
||||
self.is_causal = config.is_causal
|
||||
self.mel_bins = config.mel_bins
|
||||
|
||||
self.patchifier = AudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
sample_rate=config.sample_rate,
|
||||
hop_length=config.mel_hop_length,
|
||||
is_causal=config.is_causal,
|
||||
)
|
||||
|
||||
self.ch = ch
|
||||
self.ch = config.ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.out_ch = out_ch
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.z_channels = z_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
self.attn_type = attn_type
|
||||
self.num_resolutions = len(config.ch_mult)
|
||||
self.num_res_blocks = config.num_res_blocks
|
||||
self.resolution = config.resolution
|
||||
self.out_ch = config.out_ch
|
||||
self.give_pre_end = config.give_pre_end
|
||||
self.tanh_out = config.tanh_out
|
||||
self.norm_type = config.norm_type
|
||||
self.z_channels = config.z_channels
|
||||
self.channel_multipliers = config.ch_mult
|
||||
self.attn_resolutions = config.attn_resolutions
|
||||
self.causality_axis = config.causality_axis
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
base_block_channels = ch * self.channel_multipliers[-1]
|
||||
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, z_channels, base_resolution, base_resolution)
|
||||
base_block_channels = config.ch * self.channel_multipliers[-1]
|
||||
base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||
|
||||
self.conv_in = make_conv2d(
|
||||
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
|
||||
self.mid = build_mid_block(
|
||||
channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
add_attention=mid_block_add_attention,
|
||||
add_attention=config.mid_block_add_attention,
|
||||
)
|
||||
|
||||
self.up, final_block_channels = build_upsampling_path(
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
ch=config.ch,
|
||||
ch_mult=config.ch_mult,
|
||||
num_resolutions=self.num_resolutions,
|
||||
num_res_blocks=num_res_blocks,
|
||||
resolution=resolution,
|
||||
num_res_blocks=config.num_res_blocks,
|
||||
resolution=config.resolution,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
dropout=config.dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
attn_type=self.attn_type,
|
||||
attn_resolutions=attn_resolutions,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
attn_resolutions=config.attn_resolutions,
|
||||
resamp_with_conv=config.resamp_with_conv,
|
||||
initial_block_channels=base_block_channels,
|
||||
)
|
||||
|
||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||
self.conv_out = make_conv2d(
|
||||
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
||||
|
||||
Args:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for audio VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle audio_vae.decoder weights
|
||||
if key.startswith("audio_vae.decoder."):
|
||||
new_key = key.replace("audio_vae.decoder.", "")
|
||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||
# Map per-channel statistics
|
||||
if "mean-of-means" in key:
|
||||
new_key = "per_channel_statistics.mean_of_means"
|
||||
elif "std-of-means" in key:
|
||||
new_key = "per_channel_statistics.std_of_means"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
else:
|
||||
continue # Skip non-decoder keys
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||
"""Load audio VAE decoder from pretrained model."""
|
||||
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
||||
import json
|
||||
|
||||
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||
decoder = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
# weights = decoder.sanitize(weights)
|
||||
decoder.load_weights(list(weights.items()), strict=True)
|
||||
return decoder
|
||||
|
||||
|
||||
def __call__(self, sample: mx.array) -> mx.array:
|
||||
"""
|
||||
Decode latent features back to audio spectrograms.
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Tuple, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
|
||||
|
||||
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Causality axis enum for specifying causal convolution dimensions."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CausalityAxis(Enum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
NONE = None
|
||||
WIDTH = "width"
|
||||
HEIGHT = "height"
|
||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
@@ -6,7 +6,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
@@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module):
|
||||
self.latent_channels = latent_channels
|
||||
# Initialize buffers - will be loaded from weights
|
||||
# Using underscores for MLX compatibility with weight loading
|
||||
self._std_of_means = mx.ones((latent_channels,))
|
||||
self._mean_of_means = mx.zeros((latent_channels,))
|
||||
self.std_of_means = mx.ones((latent_channels,))
|
||||
self.mean_of_means = mx.zeros((latent_channels,))
|
||||
|
||||
def un_normalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latent representation."""
|
||||
# Broadcast statistics to match x shape
|
||||
# x shape: (B, C, ...) or (B, ..., C)
|
||||
std = self._std_of_means.astype(x.dtype)
|
||||
mean = self._mean_of_means.astype(x.dtype)
|
||||
std = self.std_of_means.astype(x.dtype)
|
||||
mean = self.mean_of_means.astype(x.dtype)
|
||||
return (x * std) + mean
|
||||
|
||||
def normalize(self, x: mx.array) -> mx.array:
|
||||
"""Normalize latent representation."""
|
||||
std = self._std_of_means.astype(x.dtype)
|
||||
mean = self._mean_of_means.astype(x.dtype)
|
||||
std = self.std_of_means.astype(x.dtype)
|
||||
mean = self.mean_of_means.astype(x.dtype)
|
||||
return (x - mean) / std
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType, build_normalization_layer
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
@@ -7,7 +7,7 @@ import mlx.nn as nn
|
||||
|
||||
from .attention import AttentionType, make_attn
|
||||
from .causal_conv_2d import make_conv2d
|
||||
from .causality_axis import CausalityAxis
|
||||
from ..config import CausalityAxis
|
||||
from .normalization import NormType
|
||||
from .resnet import ResnetBlock
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""Vocoder for converting mel spectrograms to audio waveforms."""
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_vlm.models.base import check_array_shape
|
||||
from ..config import VocoderModelConfig
|
||||
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
|
||||
|
||||
|
||||
@@ -27,44 +28,29 @@ class Vocoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
resblock_kernel_sizes: List[int] | None = None,
|
||||
upsample_rates: List[int] | None = None,
|
||||
upsample_kernel_sizes: List[int] | None = None,
|
||||
resblock_dilation_sizes: List[List[int]] | None = None,
|
||||
upsample_initial_channel: int = 1024,
|
||||
stereo: bool = True,
|
||||
resblock: str = "1",
|
||||
output_sample_rate: int = 24000,
|
||||
config: VocoderModelConfig
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Initialize default values if not provided
|
||||
if resblock_kernel_sizes is None:
|
||||
resblock_kernel_sizes = [3, 7, 11]
|
||||
if upsample_rates is None:
|
||||
upsample_rates = [6, 5, 2, 2, 2]
|
||||
if upsample_kernel_sizes is None:
|
||||
upsample_kernel_sizes = [16, 15, 8, 4, 4]
|
||||
if resblock_dilation_sizes is None:
|
||||
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
|
||||
self.output_sample_rate = config.output_sample_rate
|
||||
self.num_kernels = len(config.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(config.upsample_rates)
|
||||
self.upsample_rates = config.upsample_rates
|
||||
self.upsample_kernel_sizes = config.upsample_kernel_sizes
|
||||
self.upsample_initial_channel = config.upsample_initial_channel
|
||||
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
|
||||
in_channels = 128 if config.stereo else 64
|
||||
self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
|
||||
|
||||
# Upsampling layers using ConvTranspose1d
|
||||
self.ups = {}
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
in_ch = upsample_initial_channel // (2**i)
|
||||
out_ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
|
||||
in_ch = config.upsample_initial_channel // (2**i)
|
||||
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
self.ups[i] = nn.ConvTranspose1d(
|
||||
in_ch,
|
||||
out_ch,
|
||||
@@ -77,16 +63,67 @@ class Vocoder(nn.Module):
|
||||
self.resblocks = {}
|
||||
block_idx = 0
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||
for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
|
||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||
block_idx += 1
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
final_channels = upsample_initial_channel // (2**self.num_upsamples)
|
||||
out_channels = 2 if config.stereo else 1
|
||||
final_channels = config.upsample_initial_channel // (2**self.num_upsamples)
|
||||
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsample_factor = math.prod(upsample_rates)
|
||||
self.upsample_factor = math.prod(config.upsample_rates)
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
sanitized = {}
|
||||
|
||||
if "vocoder." not in weights:
|
||||
return weights
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
# Handle vocoder weights
|
||||
if key.startswith("vocoder."):
|
||||
new_key = key.replace("vocoder.", "")
|
||||
|
||||
# Handle ModuleList indices -> dict keys
|
||||
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
|
||||
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
|
||||
|
||||
# Handle Conv1d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, kernel)
|
||||
# MLX: (out_channels, kernel, in_channels)
|
||||
if "weight" in new_key and value.ndim == 3:
|
||||
if "ups" in new_key:
|
||||
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0))
|
||||
else:
|
||||
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder":
|
||||
"""Load vocoder from pretrained model."""
|
||||
from mlx_video.models.ltx.config import VocoderModelConfig
|
||||
import json
|
||||
|
||||
config_dict = {}
|
||||
with open(model_path / "config.json", "r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
config = VocoderModelConfig.from_dict(config_dict)
|
||||
model = cls(config)
|
||||
weights = mx.load(str(model_path / "model.safetensors"))
|
||||
|
||||
# weights = vocoder.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
return model
|
||||
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Tuple, Set
|
||||
|
||||
|
||||
class LTXModelType(Enum):
|
||||
@@ -180,3 +180,141 @@ class LTXModelConfig(BaseModelConfig):
|
||||
d_head=self.audio_attention_head_dim,
|
||||
context_dim=self.audio_cross_attention_dim,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class CausalityAxis(Enum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
NONE = None
|
||||
WIDTH = "width"
|
||||
HEIGHT = "height"
|
||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioDecoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
out_ch: int = 2
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4)
|
||||
num_res_blocks: int = 2
|
||||
attn_resolutions: Optional[List[int]] = None
|
||||
resolution: int = 256
|
||||
z_channels: int = 8
|
||||
norm_type: Enum = None
|
||||
causality_axis: Enum = None
|
||||
dropout: float = 0.0
|
||||
mid_block_add_attention: bool = True
|
||||
sample_rate: int = 16000
|
||||
mel_hop_length: int = 160
|
||||
is_causal: bool = True
|
||||
mel_bins: int | None = None
|
||||
resamp_with_conv: bool = True
|
||||
attn_type: str = None
|
||||
give_pre_end: bool = False
|
||||
tanh_out: bool = False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if self.attn_resolutions is not None:
|
||||
result["attn_resolutions"] = list(self.attn_resolutions)
|
||||
return result
|
||||
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
# Import here to avoid circular imports
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
|
||||
# Convert causality_axis string to enum
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
|
||||
# Convert norm_type string to enum
|
||||
if isinstance(self.norm_type, str):
|
||||
self.norm_type = NormType(self.norm_type)
|
||||
|
||||
# Convert attn_type string to enum
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
@dataclass
|
||||
class VocoderModelConfig(BaseModelConfig):
|
||||
resblock_kernel_sizes: Optional[List[int]] = None
|
||||
upsample_rates: Optional[List[int]] = None
|
||||
upsample_kernel_sizes: Optional[List[int]] = None
|
||||
resblock_dilation_sizes: Optional[List[List[int]]] = None
|
||||
upsample_initial_channel: int = 1024
|
||||
stereo: bool = True
|
||||
resblock: str = "1"
|
||||
output_sample_rate: int = 24000
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
if self.resblock_kernel_sizes is None:
|
||||
self.resblock_kernel_sizes = [3, 7, 11]
|
||||
if self.upsample_rates is None:
|
||||
self.upsample_rates = [6, 5, 2, 2, 2]
|
||||
if self.upsample_kernel_sizes is None:
|
||||
self.upsample_kernel_sizes = [16, 15, 8, 4, 4]
|
||||
if self.resblock_dilation_sizes is None:
|
||||
self.resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoDecoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
out_ch: int = 2
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4)
|
||||
num_res_blocks: int = 2
|
||||
attn_resolutions: Optional[List[int]] = None
|
||||
resolution: int = 256
|
||||
z_channels: int = 8
|
||||
norm_type: Enum = None
|
||||
causality_axis: Enum = None
|
||||
dropout: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderModelConfig(BaseModelConfig):
|
||||
convolution_dimensions: int = 3
|
||||
in_channels : int = 3,
|
||||
out_channels: int = 128,
|
||||
patch_size: int = 4,
|
||||
norm_layer: Enum = None,
|
||||
latent_log_var: Enum = None,
|
||||
encoder_spatial_padding_mode: Enum = None,
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2})
|
||||
])
|
||||
|
||||
def __post_init__(self):
|
||||
from mlx_video.models.ltx.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType
|
||||
|
||||
if self.norm_layer is None:
|
||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||
if self.latent_log_var is None:
|
||||
self.latent_log_var = LogVarianceType.UNIFORM
|
||||
if self.encoder_spatial_padding_mode is None:
|
||||
self.encoder_spatial_padding_mode = PaddingModeType.ZEROS
|
||||
|
||||
if isinstance(self.norm_layer, str):
|
||||
self.norm_layer = NormLayerType(self.norm_layer)
|
||||
if isinstance(self.latent_log_var, str):
|
||||
self.latent_log_var = LogVarianceType(self.latent_log_var)
|
||||
if isinstance(self.encoder_spatial_padding_mode, str):
|
||||
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if self.encoder_blocks is not None:
|
||||
result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
|
||||
return result
|
||||
@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from pathlib import Path
|
||||
|
||||
from mlx_video.models.ltx.config import (
|
||||
LTXModelConfig,
|
||||
LTXModelType,
|
||||
@@ -52,11 +52,10 @@ class TransformerArgsPreprocessor:
|
||||
self,
|
||||
timestep: mx.array,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
|
||||
|
||||
# Reshape to (batch, tokens, dim)
|
||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||
@@ -71,9 +70,6 @@ class TransformerArgsPreprocessor:
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# Context is already processed through embeddings connector in text encoder
|
||||
# Here we just apply the caption projection
|
||||
context = self.caption_projection(context)
|
||||
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||
return context, attention_mask
|
||||
@@ -118,21 +114,16 @@ class TransformerArgsPreprocessor:
|
||||
|
||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||
x = self.patchify_proj(modality.latent)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
|
||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||
|
||||
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
||||
if modality.positional_embeddings is not None:
|
||||
pe = modality.positional_embeddings
|
||||
else:
|
||||
pe = self._prepare_positional_embeddings(
|
||||
positions=modality.positions,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=self.max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
pe = self._prepare_positional_embeddings(
|
||||
positions=modality.positions,
|
||||
inner_dim=self.inner_dim,
|
||||
max_pos=self.max_pos,
|
||||
use_middle_indices_grid=self.use_middle_indices_grid,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
)
|
||||
|
||||
return TransformerArgs(
|
||||
x=x,
|
||||
@@ -207,7 +198,6 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
timestep=modality.timesteps,
|
||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||
batch_size=transformer_args.x.shape[0],
|
||||
hidden_dtype=transformer_args.x.dtype,
|
||||
)
|
||||
|
||||
return replace(
|
||||
@@ -222,16 +212,15 @@ class MultiModalTransformerArgsPreprocessor:
|
||||
timestep: mx.array,
|
||||
timestep_scale_multiplier: int,
|
||||
batch_size: int,
|
||||
hidden_dtype: mx.Dtype = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
timestep = timestep * timestep_scale_multiplier
|
||||
|
||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
|
||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||
|
||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
|
||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||
|
||||
return scale_shift_timestep, gate_timestep
|
||||
@@ -293,8 +282,6 @@ class LTXModel(nn.Module):
|
||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
||||
|
||||
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=config.audio_caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
@@ -397,9 +384,8 @@ class LTXModel(nn.Module):
|
||||
video_config = config.get_video_config()
|
||||
audio_config = config.get_audio_config()
|
||||
|
||||
|
||||
self.transformer_blocks = {
|
||||
idx: BasicAVTransformerBlock(
|
||||
self.transformer_blocks = [
|
||||
BasicAVTransformerBlock(
|
||||
idx=idx,
|
||||
video=video_config,
|
||||
audio=audio_config,
|
||||
@@ -407,7 +393,7 @@ class LTXModel(nn.Module):
|
||||
norm_eps=config.norm_eps,
|
||||
)
|
||||
for idx in range(config.num_layers)
|
||||
}
|
||||
]
|
||||
|
||||
def _process_transformer_blocks(
|
||||
self,
|
||||
@@ -415,7 +401,7 @@ class LTXModel(nn.Module):
|
||||
audio: Optional[TransformerArgs],
|
||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||
"""Process through all transformer blocks."""
|
||||
for block in self.transformer_blocks.values():
|
||||
for block in self.transformer_blocks:
|
||||
video, audio = block(video=video, audio=audio)
|
||||
return video, audio
|
||||
|
||||
@@ -497,50 +483,19 @@ class LTXModel(nn.Module):
|
||||
|
||||
def sanitize(self, weights: dict) -> dict:
|
||||
sanitized = {}
|
||||
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
||||
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
||||
continue
|
||||
|
||||
# Remove 'model.diffusion_model.' prefix
|
||||
new_key = new_key.replace("model.diffusion_model.", "")
|
||||
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
||||
|
||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
||||
|
||||
# Handle common remappings
|
||||
# transformer_blocks.X -> transformer_blocks[X]
|
||||
if "transformer_blocks." in new_key:
|
||||
# Keep as-is for now, MLX handles this
|
||||
pass
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: [Path, List[Path]], config: LTXModelConfig, strict: bool = True) -> None:
|
||||
model = cls(config)
|
||||
|
||||
weights = {}
|
||||
if isinstance(model_path, Path):
|
||||
model_path = [model_path]
|
||||
for weight_file in model_path:
|
||||
weights.update(mx.load(str(weight_file)))
|
||||
|
||||
|
||||
sanitized = model.sanitize(weights)
|
||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||
|
||||
model.load_weights(list(sanitized.items()), strict=strict)
|
||||
mx.eval(model.parameters())
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class X0Model(nn.Module):
|
||||
|
||||
|
||||
@@ -428,11 +428,14 @@ def _precompute_freqs_cis_double_precision(
|
||||
num_attention_heads: int,
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies with higher precision using float32.
|
||||
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
|
||||
|
||||
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
|
||||
Uses float32 for computation precision (sufficient for RoPE).
|
||||
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
|
||||
computation (log-spaced values), then converts to float32 for the final tensor.
|
||||
This provides better numerical precision in the frequency generation phase.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Warn if positions are bfloat16 - this causes quality degradation
|
||||
if indices_grid.dtype == mx.bfloat16:
|
||||
import warnings
|
||||
@@ -443,21 +446,27 @@ def _precompute_freqs_cis_double_precision(
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
|
||||
# Cast to float32 for position computation
|
||||
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||
|
||||
n_pos_dims = indices_grid_f32.shape[1]
|
||||
n_elem = 2 * n_pos_dims
|
||||
|
||||
# Compute log-spaced frequencies in float32
|
||||
log_start = math.log(1.0) / math.log(theta)
|
||||
log_end = math.log(theta) / math.log(theta)
|
||||
# Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
|
||||
# This is the critical precision step - PyTorch uses np.float64 here
|
||||
log_start = np.log(1.0) / np.log(theta)
|
||||
log_end = np.log(theta) / np.log(theta) # = 1.0
|
||||
num_indices = dim // n_elem
|
||||
if num_indices == 0:
|
||||
num_indices = 1
|
||||
|
||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
|
||||
# Use numpy float64 for the linspace computation (matches PyTorch)
|
||||
pow_indices = np.power(
|
||||
theta,
|
||||
np.linspace(log_start, log_end, num_indices, dtype=np.float64),
|
||||
)
|
||||
# Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
|
||||
freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
|
||||
|
||||
# Handle middle indices grid
|
||||
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
|
||||
@@ -15,13 +15,14 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
@@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Per-channel statistics for denormalization (loaded from weights)
|
||||
self.latents_mean = mx.zeros((in_channels,))
|
||||
self.latents_std = mx.ones((in_channels,))
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||
|
||||
# Initial conv: 128 -> 1024
|
||||
class ConvInWrapper(nn.Module):
|
||||
@@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module):
|
||||
)
|
||||
self.last_scale_shift_table = mx.zeros((2, 128))
|
||||
|
||||
def denormalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision (statistics may be in bfloat16)
|
||||
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
return (x * std + mean).astype(dtype)
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
# Build decoder weights dict with key remapping
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
if not key.startswith("vae.") or key.startswith("vae.encoder."):
|
||||
continue
|
||||
|
||||
if key.startswith("vae.per_channel_statistics."):
|
||||
# Map per-channel statistics (use exact key matching)
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
|
||||
if key.startswith("vae.decoder."):
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
weights = mx.load(str(model_path))
|
||||
|
||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||
if timestep_conditioning is None:
|
||||
try:
|
||||
with safe_open(str(model_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||
else:
|
||||
timestep_conditioning = False
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||
timestep_conditioning = False
|
||||
|
||||
model = cls(timestep_conditioning=timestep_conditioning)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
@@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module):
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
if debug:
|
||||
debug_stats("Input", sample)
|
||||
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
if debug:
|
||||
debug_stats("After noise", sample)
|
||||
|
||||
|
||||
if debug:
|
||||
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
|
||||
sample = self.denormalize(sample)
|
||||
if debug:
|
||||
debug_stats("After denormalize", sample)
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
@@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_in", x)
|
||||
|
||||
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
@@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
if debug:
|
||||
block_type = type(block).__name__
|
||||
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
|
||||
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
if debug:
|
||||
debug_stats("After pixel_norm", x)
|
||||
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
@@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
if debug:
|
||||
debug_stats("After timestep modulation", x)
|
||||
|
||||
|
||||
x = self.act(x)
|
||||
if debug:
|
||||
debug_stats("After activation", x)
|
||||
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_out", x)
|
||||
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
if debug:
|
||||
debug_stats("After unpatchify", x)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
@@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module):
|
||||
chunked_conv=use_chunked_conv,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
||||
from pathlib import Path
|
||||
import json
|
||||
from safetensors import safe_open
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE decoder from {weights_path}...")
|
||||
|
||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||
if timestep_conditioning is None:
|
||||
try:
|
||||
with safe_open(str(weights_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||
else:
|
||||
timestep_conditioning = False
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||
timestep_conditioning = False
|
||||
|
||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.decoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
elif has_decoder_prefix:
|
||||
prefix = "decoder."
|
||||
stats_prefix = ""
|
||||
else:
|
||||
prefix = ""
|
||||
stats_prefix = ""
|
||||
|
||||
# Load per-channel statistics for denormalization
|
||||
# Note: use std-of-means (not mean-of-stds) for proper denormalization
|
||||
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
|
||||
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
|
||||
|
||||
if mean_key in weights:
|
||||
decoder.latents_mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
|
||||
if std_key in weights:
|
||||
decoder.latents_std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
|
||||
|
||||
# Build decoder weights dict with key remapping
|
||||
decoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
decoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(decoder_weights)} decoder weights")
|
||||
|
||||
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
|
||||
print(f" Found {len(ts_keys)} timestep conditioning weights")
|
||||
|
||||
# Load weights
|
||||
decoder.load_weights(list(decoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE decoder loaded successfully")
|
||||
return decoder
|
||||
|
||||
@@ -5,152 +5,9 @@ Used for I2V (image-to-video) conditioning by encoding the input image
|
||||
to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Any, Optional
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
|
||||
|
||||
|
||||
def load_vae_encoder(model_path: str) -> VideoEncoder:
|
||||
"""Load VAE encoder from safetensors file.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model weights (safetensors file or directory)
|
||||
|
||||
Returns:
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
from safetensors import safe_open
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE encoder from {weights_path}...")
|
||||
|
||||
# Read config from safetensors metadata
|
||||
encoder_blocks = []
|
||||
norm_layer = NormLayerType.PIXEL_NORM
|
||||
latent_log_var = LogVarianceType.UNIFORM
|
||||
patch_size = 4
|
||||
|
||||
try:
|
||||
with safe_open(str(weights_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
|
||||
# Parse encoder blocks
|
||||
raw_blocks = vae_config.get("encoder_blocks", [])
|
||||
for block in raw_blocks:
|
||||
if isinstance(block, list) and len(block) == 2:
|
||||
name, params = block
|
||||
encoder_blocks.append((name, params))
|
||||
|
||||
# Parse other config
|
||||
norm_str = vae_config.get("norm_layer", "pixel_norm")
|
||||
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
|
||||
|
||||
var_str = vae_config.get("latent_log_var", "uniform")
|
||||
if var_str == "uniform":
|
||||
latent_log_var = LogVarianceType.UNIFORM
|
||||
elif var_str == "per_channel":
|
||||
latent_log_var = LogVarianceType.PER_CHANNEL
|
||||
elif var_str == "constant":
|
||||
latent_log_var = LogVarianceType.CONSTANT
|
||||
else:
|
||||
latent_log_var = LogVarianceType.NONE
|
||||
|
||||
patch_size = vae_config.get("patch_size", 4)
|
||||
|
||||
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}")
|
||||
# Use default config
|
||||
encoder_blocks = [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
|
||||
|
||||
# Create encoder
|
||||
encoder = VideoEncoder(
|
||||
convolution_dimensions=3,
|
||||
in_channels=3,
|
||||
out_channels=128,
|
||||
encoder_blocks=encoder_blocks,
|
||||
patch_size=patch_size,
|
||||
norm_layer=norm_layer,
|
||||
latent_log_var=latent_log_var,
|
||||
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
|
||||
)
|
||||
|
||||
# Load weights
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.encoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
else:
|
||||
prefix = "encoder."
|
||||
stats_prefix = "per_channel_statistics."
|
||||
|
||||
# Load per-channel statistics for normalization
|
||||
mean_key = f"{stats_prefix}mean-of-means"
|
||||
std_key = f"{stats_prefix}std-of-means"
|
||||
|
||||
if mean_key in weights:
|
||||
encoder.per_channel_statistics.mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
|
||||
if std_key in weights:
|
||||
encoder.per_channel_statistics.std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {weights[std_key].shape}")
|
||||
|
||||
# Build encoder weights dict with key remapping
|
||||
encoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
encoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(encoder_weights)} encoder weights")
|
||||
|
||||
# Load weights
|
||||
encoder.load_weights(list(encoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE encoder loaded successfully")
|
||||
return encoder
|
||||
|
||||
|
||||
def encode_image(
|
||||
|
||||
@@ -273,9 +273,10 @@ class VideoEncoder(nn.Module):
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build encoder blocks - use dict with int keys for MLX parameter tracking
|
||||
# Build encoder blocks
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.down_blocks = {}
|
||||
for i, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
@@ -287,7 +288,7 @@ class VideoEncoder(nn.Module):
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
self.down_blocks[i] = block
|
||||
self.down_blocks[idx] = block
|
||||
|
||||
# Output normalization and convolution
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
@@ -341,7 +342,8 @@ class VideoEncoder(nn.Module):
|
||||
sample = self.conv_in(sample, causal=True)
|
||||
|
||||
# Process through encoder blocks
|
||||
for down_block in self.down_blocks.values():
|
||||
for i in range(len(self.down_blocks)):
|
||||
down_block = self.down_blocks[i]
|
||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||
sample = down_block(sample, causal=True)
|
||||
else:
|
||||
@@ -440,8 +442,9 @@ class VideoDecoder(nn.Module):
|
||||
)
|
||||
|
||||
# Build decoder blocks (reversed order)
|
||||
self.up_blocks = []
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.up_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
@@ -454,7 +457,7 @@ class VideoDecoder(nn.Module):
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
self.up_blocks.append(block)
|
||||
self.up_blocks[idx] = block
|
||||
|
||||
# Output normalization
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
@@ -509,7 +512,8 @@ class VideoDecoder(nn.Module):
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
# Process through decoder blocks
|
||||
for up_block in self.up_blocks:
|
||||
for i in range(len(self.up_blocks)):
|
||||
up_block = self.up_blocks[i]
|
||||
if isinstance(up_block, UNetMidBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
elif isinstance(up_block, ResnetBlock3D):
|
||||
|
||||
Reference in New Issue
Block a user