Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -1,11 +1,59 @@
|
|||||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||||
from mlx_video.convert import load_transformer_weights, load_vae_weights
|
from mlx_video.convert import (
|
||||||
import os
|
load_transformer_weights,
|
||||||
|
load_vae_weights,
|
||||||
|
load_audio_vae_weights,
|
||||||
|
load_vocoder_weights,
|
||||||
|
sanitize_audio_vae_weights,
|
||||||
|
sanitize_vocoder_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio VAE components
|
||||||
|
from mlx_video.models.ltx.audio_vae import (
|
||||||
|
AudioEncoder,
|
||||||
|
AudioDecoder,
|
||||||
|
Vocoder,
|
||||||
|
AudioProcessor,
|
||||||
|
decode_audio,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patchifiers
|
||||||
|
from mlx_video.components.patchifiers import (
|
||||||
|
VideoLatentPatchifier,
|
||||||
|
AudioPatchifier,
|
||||||
|
VideoLatentShape,
|
||||||
|
AudioLatentShape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conditioning
|
||||||
|
from mlx_video.conditioning import (
|
||||||
|
VideoConditionByKeyframeIndex,
|
||||||
|
VideoConditionByLatentIndex,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Models
|
||||||
"LTXModel",
|
"LTXModel",
|
||||||
"LTXModelConfig",
|
"LTXModelConfig",
|
||||||
|
# Weight loading
|
||||||
"load_transformer_weights",
|
"load_transformer_weights",
|
||||||
"load_vae_weights",
|
"load_vae_weights",
|
||||||
|
"load_audio_vae_weights",
|
||||||
|
"load_vocoder_weights",
|
||||||
|
"sanitize_audio_vae_weights",
|
||||||
|
"sanitize_vocoder_weights",
|
||||||
|
# Audio VAE
|
||||||
|
"AudioEncoder",
|
||||||
|
"AudioDecoder",
|
||||||
|
"Vocoder",
|
||||||
|
"AudioProcessor",
|
||||||
|
"decode_audio",
|
||||||
|
# Patchifiers
|
||||||
|
"VideoLatentPatchifier",
|
||||||
|
"AudioPatchifier",
|
||||||
|
"VideoLatentShape",
|
||||||
|
"AudioLatentShape",
|
||||||
|
# Conditioning
|
||||||
|
"VideoConditionByKeyframeIndex",
|
||||||
|
"VideoConditionByLatentIndex",
|
||||||
]
|
]
|
||||||
|
|
||||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
|
||||||
|
|||||||
@@ -355,6 +355,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr
|
|||||||
"""
|
"""
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
|
||||||
|
if "audio_vae." in weights:
|
||||||
|
return weights
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
|
|
||||||
@@ -364,9 +367,9 @@ def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.arr
|
|||||||
elif key.startswith("audio_vae.per_channel_statistics."):
|
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||||
# Map per-channel statistics
|
# Map per-channel statistics
|
||||||
if "mean-of-means" in key:
|
if "mean-of-means" in key:
|
||||||
new_key = "per_channel_statistics._mean_of_means"
|
new_key = "per_channel_statistics.mean_of_means"
|
||||||
elif "std-of-means" in key:
|
elif "std-of-means" in key:
|
||||||
new_key = "per_channel_statistics._std_of_means"
|
new_key = "per_channel_statistics.std_of_means"
|
||||||
else:
|
else:
|
||||||
continue # Skip other statistics keys
|
continue # Skip other statistics keys
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from .attention import AttentionType, AttnBlock, make_attn
|
from .attention import AttentionType, AttnBlock, make_attn
|
||||||
from .audio_vae import AudioDecoder, decode_audio
|
from .audio_vae import AudioDecoder, decode_audio
|
||||||
from .causal_conv_2d import CausalConv2d, make_conv2d
|
from .causal_conv_2d import CausalConv2d, make_conv2d
|
||||||
from .causality_axis import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
from .downsample import Downsample, build_downsampling_path
|
from .downsample import Downsample, build_downsampling_path
|
||||||
from .normalization import NormType, PixelNorm, build_normalization_layer
|
from .normalization import NormType, PixelNorm, build_normalization_layer
|
||||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
"""Audio VAE encoder and decoder for LTX-2."""
|
"""Audio VAE encoder and decoder for LTX-2."""
|
||||||
|
|
||||||
from typing import Set, Tuple
|
from typing import Dict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
from mlx_vlm.models.base import check_array_shape
|
||||||
|
from ..config import AudioDecoderModelConfig
|
||||||
from .attention import AttentionType, make_attn
|
from .attention import AttentionType, make_attn
|
||||||
from .causal_conv_2d import make_conv2d
|
from .causal_conv_2d import make_conv2d
|
||||||
from .causality_axis import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
from .downsample import build_downsampling_path
|
|
||||||
from .normalization import NormType, build_normalization_layer
|
from .normalization import NormType, build_normalization_layer
|
||||||
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
from .ops import AudioLatentShape, AudioPatchifier, PerChannelStatistics
|
||||||
from .resnet import ResnetBlock
|
from .resnet import ResnetBlock
|
||||||
@@ -67,22 +68,7 @@ class AudioDecoder(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
config: AudioDecoderModelConfig,
|
||||||
ch: int = 128,
|
|
||||||
out_ch: int = 2,
|
|
||||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
|
||||||
num_res_blocks: int = 2,
|
|
||||||
attn_resolutions: Set[int] = None,
|
|
||||||
resolution: int = 256,
|
|
||||||
z_channels: int = 8,
|
|
||||||
norm_type: NormType = NormType.PIXEL,
|
|
||||||
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
|
|
||||||
dropout: float = 0.0,
|
|
||||||
mid_block_add_attention: bool = True,
|
|
||||||
sample_rate: int = 16000,
|
|
||||||
mel_hop_length: int = 160,
|
|
||||||
is_causal: bool = True,
|
|
||||||
mel_bins: int | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the AudioDecoder.
|
Initialize the AudioDecoder.
|
||||||
@@ -105,86 +91,132 @@ class AudioDecoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if attn_resolutions is None:
|
|
||||||
attn_resolutions = {8, 16, 32}
|
|
||||||
|
|
||||||
# Internal behavioral defaults
|
|
||||||
resamp_with_conv = True
|
|
||||||
attn_type = AttentionType.VANILLA
|
|
||||||
|
|
||||||
# Per-channel statistics for denormalizing latents
|
# Per-channel statistics for denormalizing latents
|
||||||
# Uses ch (base channel count) to match the patchified latent dimension
|
# Uses ch (base channel count) to match the patchified latent dimension
|
||||||
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
# Input latent shape: (B, z_channels, T, latent_mel_bins) = (B, 8, T, 16)
|
||||||
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
|
# After patchify: (B, T, z_channels * latent_mel_bins) = (B, T, 128)
|
||||||
# ch=128 matches this dimension, so use ch for per_channel_statistics
|
# ch=128 matches this dimension, so use ch for per_channel_statistics
|
||||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=ch)
|
self.per_channel_statistics = PerChannelStatistics(latent_channels=config.ch)
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = config.sample_rate
|
||||||
self.mel_hop_length = mel_hop_length
|
self.mel_hop_length = config.mel_hop_length
|
||||||
self.is_causal = is_causal
|
self.is_causal = config.is_causal
|
||||||
self.mel_bins = mel_bins
|
self.mel_bins = config.mel_bins
|
||||||
|
|
||||||
self.patchifier = AudioPatchifier(
|
self.patchifier = AudioPatchifier(
|
||||||
patch_size=1,
|
patch_size=1,
|
||||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||||
sample_rate=sample_rate,
|
sample_rate=config.sample_rate,
|
||||||
hop_length=mel_hop_length,
|
hop_length=config.mel_hop_length,
|
||||||
is_causal=is_causal,
|
is_causal=config.is_causal,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ch = ch
|
self.ch = config.ch
|
||||||
self.temb_ch = 0
|
self.temb_ch = 0
|
||||||
self.num_resolutions = len(ch_mult)
|
self.num_resolutions = len(config.ch_mult)
|
||||||
self.num_res_blocks = num_res_blocks
|
self.num_res_blocks = config.num_res_blocks
|
||||||
self.resolution = resolution
|
self.resolution = config.resolution
|
||||||
self.out_ch = out_ch
|
self.out_ch = config.out_ch
|
||||||
self.give_pre_end = False
|
self.give_pre_end = config.give_pre_end
|
||||||
self.tanh_out = False
|
self.tanh_out = config.tanh_out
|
||||||
self.norm_type = norm_type
|
self.norm_type = config.norm_type
|
||||||
self.z_channels = z_channels
|
self.z_channels = config.z_channels
|
||||||
self.channel_multipliers = ch_mult
|
self.channel_multipliers = config.ch_mult
|
||||||
self.attn_resolutions = attn_resolutions
|
self.attn_resolutions = config.attn_resolutions
|
||||||
self.causality_axis = causality_axis
|
self.causality_axis = config.causality_axis
|
||||||
self.attn_type = attn_type
|
self.attn_type = config.attn_type
|
||||||
|
|
||||||
base_block_channels = ch * self.channel_multipliers[-1]
|
base_block_channels = config.ch * self.channel_multipliers[-1]
|
||||||
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
base_resolution = config.resolution // (2 ** (self.num_resolutions - 1))
|
||||||
self.z_shape = (1, z_channels, base_resolution, base_resolution)
|
self.z_shape = (1, config.z_channels, base_resolution, base_resolution)
|
||||||
|
|
||||||
self.conv_in = make_conv2d(
|
self.conv_in = make_conv2d(
|
||||||
z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
config.z_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||||
)
|
)
|
||||||
|
|
||||||
self.mid = build_mid_block(
|
self.mid = build_mid_block(
|
||||||
channels=base_block_channels,
|
channels=base_block_channels,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout,
|
dropout=config.dropout,
|
||||||
norm_type=self.norm_type,
|
norm_type=self.norm_type,
|
||||||
causality_axis=self.causality_axis,
|
causality_axis=self.causality_axis,
|
||||||
attn_type=self.attn_type,
|
attn_type=self.attn_type,
|
||||||
add_attention=mid_block_add_attention,
|
add_attention=config.mid_block_add_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.up, final_block_channels = build_upsampling_path(
|
self.up, final_block_channels = build_upsampling_path(
|
||||||
ch=ch,
|
ch=config.ch,
|
||||||
ch_mult=ch_mult,
|
ch_mult=config.ch_mult,
|
||||||
num_resolutions=self.num_resolutions,
|
num_resolutions=self.num_resolutions,
|
||||||
num_res_blocks=num_res_blocks,
|
num_res_blocks=config.num_res_blocks,
|
||||||
resolution=resolution,
|
resolution=config.resolution,
|
||||||
temb_channels=self.temb_ch,
|
temb_channels=self.temb_ch,
|
||||||
dropout=dropout,
|
dropout=config.dropout,
|
||||||
norm_type=self.norm_type,
|
norm_type=self.norm_type,
|
||||||
causality_axis=self.causality_axis,
|
causality_axis=self.causality_axis,
|
||||||
attn_type=self.attn_type,
|
attn_type=self.attn_type,
|
||||||
attn_resolutions=attn_resolutions,
|
attn_resolutions=config.attn_resolutions,
|
||||||
resamp_with_conv=resamp_with_conv,
|
resamp_with_conv=config.resamp_with_conv,
|
||||||
initial_block_channels=base_block_channels,
|
initial_block_channels=base_block_channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
|
||||||
self.conv_out = make_conv2d(
|
self.conv_out = make_conv2d(
|
||||||
final_block_channels, out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
final_block_channels, config.out_ch, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Dictionary of weights with PyTorch naming
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with MLX-compatible naming for audio VAE decoder
|
||||||
|
"""
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Handle audio_vae.decoder weights
|
||||||
|
if key.startswith("audio_vae.decoder."):
|
||||||
|
new_key = key.replace("audio_vae.decoder.", "")
|
||||||
|
elif key.startswith("audio_vae.per_channel_statistics."):
|
||||||
|
# Map per-channel statistics
|
||||||
|
if "mean-of-means" in key:
|
||||||
|
new_key = "per_channel_statistics.mean_of_means"
|
||||||
|
elif "std-of-means" in key:
|
||||||
|
new_key = "per_channel_statistics.std_of_means"
|
||||||
|
else:
|
||||||
|
continue # Skip other statistics keys
|
||||||
|
else:
|
||||||
|
continue # Skip non-decoder keys
|
||||||
|
|
||||||
|
# Handle Conv2d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, H, W)
|
||||||
|
# MLX: (out_channels, H, W, in_channels)
|
||||||
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||||
|
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_path: Path) -> "AudioDecoder":
|
||||||
|
"""Load audio VAE decoder from pretrained model."""
|
||||||
|
from mlx_video.models.ltx.config import AudioDecoderModelConfig
|
||||||
|
import json
|
||||||
|
|
||||||
|
config = AudioDecoderModelConfig.from_dict(json.load(open(model_path / "config.json")))
|
||||||
|
decoder = cls(config)
|
||||||
|
weights = mx.load(str(model_path / "model.safetensors"))
|
||||||
|
# weights = decoder.sanitize(weights)
|
||||||
|
decoder.load_weights(list(weights.items()), strict=True)
|
||||||
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def __call__(self, sample: mx.array) -> mx.array:
|
def __call__(self, sample: mx.array) -> mx.array:
|
||||||
"""
|
"""
|
||||||
Decode latent features back to audio spectrograms.
|
Decode latent features back to audio spectrograms.
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Tuple, Union
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .causality_axis import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
|
|
||||||
|
|
||||||
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
|
def _pair(x: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
"""Causality axis enum for specifying causal convolution dimensions."""
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class CausalityAxis(Enum):
|
|
||||||
"""Enum for specifying the causality axis in causal convolutions."""
|
|
||||||
|
|
||||||
NONE = None
|
|
||||||
WIDTH = "width"
|
|
||||||
HEIGHT = "height"
|
|
||||||
WIDTH_COMPATIBILITY = "width-compatibility"
|
|
||||||
@@ -6,7 +6,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .attention import AttentionType, make_attn
|
from .attention import AttentionType, make_attn
|
||||||
from .causality_axis import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
from .normalization import NormType
|
from .normalization import NormType
|
||||||
from .resnet import ResnetBlock
|
from .resnet import ResnetBlock
|
||||||
|
|
||||||
|
|||||||
@@ -27,21 +27,21 @@ class PerChannelStatistics(nn.Module):
|
|||||||
self.latent_channels = latent_channels
|
self.latent_channels = latent_channels
|
||||||
# Initialize buffers - will be loaded from weights
|
# Initialize buffers - will be loaded from weights
|
||||||
# Using underscores for MLX compatibility with weight loading
|
# Using underscores for MLX compatibility with weight loading
|
||||||
self._std_of_means = mx.ones((latent_channels,))
|
self.std_of_means = mx.ones((latent_channels,))
|
||||||
self._mean_of_means = mx.zeros((latent_channels,))
|
self.mean_of_means = mx.zeros((latent_channels,))
|
||||||
|
|
||||||
def un_normalize(self, x: mx.array) -> mx.array:
|
def un_normalize(self, x: mx.array) -> mx.array:
|
||||||
"""Denormalize latent representation."""
|
"""Denormalize latent representation."""
|
||||||
# Broadcast statistics to match x shape
|
# Broadcast statistics to match x shape
|
||||||
# x shape: (B, C, ...) or (B, ..., C)
|
# x shape: (B, C, ...) or (B, ..., C)
|
||||||
std = self._std_of_means.astype(x.dtype)
|
std = self.std_of_means.astype(x.dtype)
|
||||||
mean = self._mean_of_means.astype(x.dtype)
|
mean = self.mean_of_means.astype(x.dtype)
|
||||||
return (x * std) + mean
|
return (x * std) + mean
|
||||||
|
|
||||||
def normalize(self, x: mx.array) -> mx.array:
|
def normalize(self, x: mx.array) -> mx.array:
|
||||||
"""Normalize latent representation."""
|
"""Normalize latent representation."""
|
||||||
std = self._std_of_means.astype(x.dtype)
|
std = self.std_of_means.astype(x.dtype)
|
||||||
mean = self._mean_of_means.astype(x.dtype)
|
mean = self.mean_of_means.astype(x.dtype)
|
||||||
return (x - mean) / std
|
return (x - mean) / std
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .causal_conv_2d import make_conv2d
|
from .causal_conv_2d import make_conv2d
|
||||||
from .causality_axis import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
from .normalization import NormType, build_normalization_layer
|
from .normalization import NormType, build_normalization_layer
|
||||||
|
|
||||||
LRELU_SLOPE = 0.1
|
LRELU_SLOPE = 0.1
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import mlx.nn as nn
|
|||||||
|
|
||||||
from .attention import AttentionType, make_attn
|
from .attention import AttentionType, make_attn
|
||||||
from .causal_conv_2d import make_conv2d
|
from .causal_conv_2d import make_conv2d
|
||||||
from .causality_axis import CausalityAxis
|
from ..config import CausalityAxis
|
||||||
from .normalization import NormType
|
from .normalization import NormType
|
||||||
from .resnet import ResnetBlock
|
from .resnet import ResnetBlock
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
"""Vocoder for converting mel spectrograms to audio waveforms."""
|
"""Vocoder for converting mel spectrograms to audio waveforms."""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import Dict
|
||||||
|
from pathlib import Path
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
from mlx_vlm.models.base import check_array_shape
|
||||||
|
from ..config import VocoderModelConfig
|
||||||
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
|
from .resnet import LRELU_SLOPE, ResBlock1, ResBlock2, leaky_relu
|
||||||
|
|
||||||
|
|
||||||
@@ -27,44 +28,29 @@ class Vocoder(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
resblock_kernel_sizes: List[int] | None = None,
|
config: VocoderModelConfig
|
||||||
upsample_rates: List[int] | None = None,
|
|
||||||
upsample_kernel_sizes: List[int] | None = None,
|
|
||||||
resblock_dilation_sizes: List[List[int]] | None = None,
|
|
||||||
upsample_initial_channel: int = 1024,
|
|
||||||
stereo: bool = True,
|
|
||||||
resblock: str = "1",
|
|
||||||
output_sample_rate: int = 24000,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Initialize default values if not provided
|
|
||||||
if resblock_kernel_sizes is None:
|
|
||||||
resblock_kernel_sizes = [3, 7, 11]
|
|
||||||
if upsample_rates is None:
|
|
||||||
upsample_rates = [6, 5, 2, 2, 2]
|
|
||||||
if upsample_kernel_sizes is None:
|
|
||||||
upsample_kernel_sizes = [16, 15, 8, 4, 4]
|
|
||||||
if resblock_dilation_sizes is None:
|
|
||||||
resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
|
||||||
|
|
||||||
self.output_sample_rate = output_sample_rate
|
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
|
||||||
self.num_upsamples = len(upsample_rates)
|
|
||||||
self.upsample_rates = upsample_rates
|
|
||||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
|
||||||
self.upsample_initial_channel = upsample_initial_channel
|
|
||||||
|
|
||||||
in_channels = 128 if stereo else 64
|
self.output_sample_rate = config.output_sample_rate
|
||||||
self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, kernel_size=7, stride=1, padding=3)
|
self.num_kernels = len(config.resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(config.upsample_rates)
|
||||||
|
self.upsample_rates = config.upsample_rates
|
||||||
|
self.upsample_kernel_sizes = config.upsample_kernel_sizes
|
||||||
|
self.upsample_initial_channel = config.upsample_initial_channel
|
||||||
|
|
||||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
in_channels = 128 if config.stereo else 64
|
||||||
|
self.conv_pre = nn.Conv1d(in_channels, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3)
|
||||||
|
|
||||||
|
resblock_class = ResBlock1 if config.resblock == "1" else ResBlock2
|
||||||
|
|
||||||
# Upsampling layers using ConvTranspose1d
|
# Upsampling layers using ConvTranspose1d
|
||||||
self.ups = {}
|
self.ups = {}
|
||||||
for i, (stride, kernel_size) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
for i, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
|
||||||
in_ch = upsample_initial_channel // (2**i)
|
in_ch = config.upsample_initial_channel // (2**i)
|
||||||
out_ch = upsample_initial_channel // (2 ** (i + 1))
|
out_ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||||
self.ups[i] = nn.ConvTranspose1d(
|
self.ups[i] = nn.ConvTranspose1d(
|
||||||
in_ch,
|
in_ch,
|
||||||
out_ch,
|
out_ch,
|
||||||
@@ -77,16 +63,67 @@ class Vocoder(nn.Module):
|
|||||||
self.resblocks = {}
|
self.resblocks = {}
|
||||||
block_idx = 0
|
block_idx = 0
|
||||||
for i in range(len(self.ups)):
|
for i in range(len(self.ups)):
|
||||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
ch = config.upsample_initial_channel // (2 ** (i + 1))
|
||||||
for kernel_size, dilations in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
for kernel_size, dilations in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
|
||||||
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
self.resblocks[block_idx] = resblock_class(ch, kernel_size, tuple(dilations))
|
||||||
block_idx += 1
|
block_idx += 1
|
||||||
|
|
||||||
out_channels = 2 if stereo else 1
|
out_channels = 2 if config.stereo else 1
|
||||||
final_channels = upsample_initial_channel // (2**self.num_upsamples)
|
final_channels = config.upsample_initial_channel // (2**self.num_upsamples)
|
||||||
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
|
self.conv_post = nn.Conv1d(final_channels, out_channels, kernel_size=7, stride=1, padding=3)
|
||||||
|
|
||||||
self.upsample_factor = math.prod(upsample_rates)
|
self.upsample_factor = math.prod(config.upsample_rates)
|
||||||
|
|
||||||
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
if "vocoder." not in weights:
|
||||||
|
return weights
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Handle vocoder weights
|
||||||
|
if key.startswith("vocoder."):
|
||||||
|
new_key = key.replace("vocoder.", "")
|
||||||
|
|
||||||
|
# Handle ModuleList indices -> dict keys
|
||||||
|
# PyTorch: ups.0, ups.1, ... -> ups.0, ups.1, ...
|
||||||
|
# PyTorch: resblocks.0, resblocks.1, ... -> resblocks.0, resblocks.1, ...
|
||||||
|
|
||||||
|
# Handle Conv1d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, kernel)
|
||||||
|
# MLX: (out_channels, kernel, in_channels)
|
||||||
|
if "weight" in new_key and value.ndim == 3:
|
||||||
|
if "ups" in new_key:
|
||||||
|
# ConvTranspose1d: PyTorch (in_ch, out_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||||
|
value = value if check_array_shape(value) else mx.transpose(value, (1, 2, 0))
|
||||||
|
else:
|
||||||
|
# Conv1d: PyTorch (out_ch, in_ch, kernel) -> MLX (out_ch, kernel, in_ch)
|
||||||
|
value = value if check_array_shape(value) else mx.transpose(value, (0, 2, 1))
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "Vocoder":
|
||||||
|
"""Load vocoder from pretrained model."""
|
||||||
|
from mlx_video.models.ltx.config import VocoderModelConfig
|
||||||
|
import json
|
||||||
|
|
||||||
|
config_dict = {}
|
||||||
|
with open(model_path / "config.json", "r") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
|
||||||
|
config = VocoderModelConfig.from_dict(config_dict)
|
||||||
|
model = cls(config)
|
||||||
|
weights = mx.load(str(model_path / "model.safetensors"))
|
||||||
|
|
||||||
|
# weights = vocoder.sanitize(weights)
|
||||||
|
model.load_weights(list(weights.items()), strict=strict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import inspect
|
import inspect
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional, Tuple, Set
|
||||||
|
|
||||||
|
|
||||||
class LTXModelType(Enum):
|
class LTXModelType(Enum):
|
||||||
@@ -180,3 +180,141 @@ class LTXModelConfig(BaseModelConfig):
|
|||||||
d_head=self.audio_attention_head_dim,
|
d_head=self.audio_attention_head_dim,
|
||||||
context_dim=self.audio_cross_attention_dim,
|
context_dim=self.audio_cross_attention_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CausalityAxis(Enum):
|
||||||
|
"""Enum for specifying the causality axis in causal convolutions."""
|
||||||
|
|
||||||
|
NONE = None
|
||||||
|
WIDTH = "width"
|
||||||
|
HEIGHT = "height"
|
||||||
|
WIDTH_COMPATIBILITY = "width-compatibility"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AudioDecoderModelConfig(BaseModelConfig):
|
||||||
|
ch: int = 128
|
||||||
|
out_ch: int = 2
|
||||||
|
ch_mult: Tuple[int, ...] = (1, 2, 4)
|
||||||
|
num_res_blocks: int = 2
|
||||||
|
attn_resolutions: Optional[List[int]] = None
|
||||||
|
resolution: int = 256
|
||||||
|
z_channels: int = 8
|
||||||
|
norm_type: Enum = None
|
||||||
|
causality_axis: Enum = None
|
||||||
|
dropout: float = 0.0
|
||||||
|
mid_block_add_attention: bool = True
|
||||||
|
sample_rate: int = 16000
|
||||||
|
mel_hop_length: int = 160
|
||||||
|
is_causal: bool = True
|
||||||
|
mel_bins: int | None = None
|
||||||
|
resamp_with_conv: bool = True
|
||||||
|
attn_type: str = None
|
||||||
|
give_pre_end: bool = False
|
||||||
|
tanh_out: bool = False
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if self.attn_resolutions is not None:
|
||||||
|
result["attn_resolutions"] = list(self.attn_resolutions)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Convert string enum values to proper enum types."""
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from .audio_vae.normalization import NormType
|
||||||
|
from .audio_vae.attention import AttentionType
|
||||||
|
|
||||||
|
# Convert causality_axis string to enum
|
||||||
|
if isinstance(self.causality_axis, str):
|
||||||
|
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||||
|
|
||||||
|
# Convert norm_type string to enum
|
||||||
|
if isinstance(self.norm_type, str):
|
||||||
|
self.norm_type = NormType(self.norm_type)
|
||||||
|
|
||||||
|
# Convert attn_type string to enum
|
||||||
|
if isinstance(self.attn_type, str):
|
||||||
|
self.attn_type = AttentionType(self.attn_type)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VocoderModelConfig(BaseModelConfig):
|
||||||
|
resblock_kernel_sizes: Optional[List[int]] = None
|
||||||
|
upsample_rates: Optional[List[int]] = None
|
||||||
|
upsample_kernel_sizes: Optional[List[int]] = None
|
||||||
|
resblock_dilation_sizes: Optional[List[List[int]]] = None
|
||||||
|
upsample_initial_channel: int = 1024
|
||||||
|
stereo: bool = True
|
||||||
|
resblock: str = "1"
|
||||||
|
output_sample_rate: int = 24000
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
|
||||||
|
if self.resblock_kernel_sizes is None:
|
||||||
|
self.resblock_kernel_sizes = [3, 7, 11]
|
||||||
|
if self.upsample_rates is None:
|
||||||
|
self.upsample_rates = [6, 5, 2, 2, 2]
|
||||||
|
if self.upsample_kernel_sizes is None:
|
||||||
|
self.upsample_kernel_sizes = [16, 15, 8, 4, 4]
|
||||||
|
if self.resblock_dilation_sizes is None:
|
||||||
|
self.resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoDecoderModelConfig(BaseModelConfig):
|
||||||
|
ch: int = 128
|
||||||
|
out_ch: int = 2
|
||||||
|
ch_mult: Tuple[int, ...] = (1, 2, 4)
|
||||||
|
num_res_blocks: int = 2
|
||||||
|
attn_resolutions: Optional[List[int]] = None
|
||||||
|
resolution: int = 256
|
||||||
|
z_channels: int = 8
|
||||||
|
norm_type: Enum = None
|
||||||
|
causality_axis: Enum = None
|
||||||
|
dropout: float = 0.0
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoEncoderModelConfig(BaseModelConfig):
|
||||||
|
convolution_dimensions: int = 3
|
||||||
|
in_channels : int = 3,
|
||||||
|
out_channels: int = 128,
|
||||||
|
patch_size: int = 4,
|
||||||
|
norm_layer: Enum = None,
|
||||||
|
latent_log_var: Enum = None,
|
||||||
|
encoder_spatial_padding_mode: Enum = None,
|
||||||
|
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
|
||||||
|
("compress_space_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 6}),
|
||||||
|
("compress_time_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 6}),
|
||||||
|
("compress_all_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 2}),
|
||||||
|
("compress_all_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 2})
|
||||||
|
])
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
from mlx_video.models.ltx.video_vae.resnet import NormLayerType
|
||||||
|
from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
|
||||||
|
from mlx_video.models.ltx.video_vae.convolution import PaddingModeType
|
||||||
|
|
||||||
|
if self.norm_layer is None:
|
||||||
|
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||||
|
if self.latent_log_var is None:
|
||||||
|
self.latent_log_var = LogVarianceType.UNIFORM
|
||||||
|
if self.encoder_spatial_padding_mode is None:
|
||||||
|
self.encoder_spatial_padding_mode = PaddingModeType.ZEROS
|
||||||
|
|
||||||
|
if isinstance(self.norm_layer, str):
|
||||||
|
self.norm_layer = NormLayerType(self.norm_layer)
|
||||||
|
if isinstance(self.latent_log_var, str):
|
||||||
|
self.latent_log_var = LogVarianceType(self.latent_log_var)
|
||||||
|
if isinstance(self.encoder_spatial_padding_mode, str):
|
||||||
|
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
result = super().to_dict()
|
||||||
|
if self.encoder_blocks is not None:
|
||||||
|
result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
|
||||||
|
return result
|
||||||
@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from pathlib import Path
|
|
||||||
from mlx_video.models.ltx.config import (
|
from mlx_video.models.ltx.config import (
|
||||||
LTXModelConfig,
|
LTXModelConfig,
|
||||||
LTXModelType,
|
LTXModelType,
|
||||||
@@ -52,11 +52,10 @@ class TransformerArgsPreprocessor:
|
|||||||
self,
|
self,
|
||||||
timestep: mx.array,
|
timestep: mx.array,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
hidden_dtype: mx.Dtype = None,
|
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
timestep = timestep * self.timestep_scale_multiplier
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
|
||||||
|
|
||||||
# Reshape to (batch, tokens, dim)
|
# Reshape to (batch, tokens, dim)
|
||||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||||
@@ -71,9 +70,6 @@ class TransformerArgsPreprocessor:
|
|||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
# Context is already processed through embeddings connector in text encoder
|
|
||||||
# Here we just apply the caption projection
|
|
||||||
context = self.caption_projection(context)
|
context = self.caption_projection(context)
|
||||||
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
context = mx.reshape(context, (batch_size, -1, x.shape[-1]))
|
||||||
return context, attention_mask
|
return context, attention_mask
|
||||||
@@ -118,14 +114,9 @@ class TransformerArgsPreprocessor:
|
|||||||
|
|
||||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||||
x = self.patchify_proj(modality.latent)
|
x = self.patchify_proj(modality.latent)
|
||||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
|
||||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||||
|
|
||||||
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
|
|
||||||
if modality.positional_embeddings is not None:
|
|
||||||
pe = modality.positional_embeddings
|
|
||||||
else:
|
|
||||||
pe = self._prepare_positional_embeddings(
|
pe = self._prepare_positional_embeddings(
|
||||||
positions=modality.positions,
|
positions=modality.positions,
|
||||||
inner_dim=self.inner_dim,
|
inner_dim=self.inner_dim,
|
||||||
@@ -207,7 +198,6 @@ class MultiModalTransformerArgsPreprocessor:
|
|||||||
timestep=modality.timesteps,
|
timestep=modality.timesteps,
|
||||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||||
batch_size=transformer_args.x.shape[0],
|
batch_size=transformer_args.x.shape[0],
|
||||||
hidden_dtype=transformer_args.x.dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return replace(
|
return replace(
|
||||||
@@ -222,16 +212,15 @@ class MultiModalTransformerArgsPreprocessor:
|
|||||||
timestep: mx.array,
|
timestep: mx.array,
|
||||||
timestep_scale_multiplier: int,
|
timestep_scale_multiplier: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
hidden_dtype: mx.Dtype = None,
|
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
timestep = timestep * timestep_scale_multiplier
|
timestep = timestep * timestep_scale_multiplier
|
||||||
|
|
||||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||||
|
|
||||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
|
||||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||||
|
|
||||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
|
||||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||||
|
|
||||||
return scale_shift_timestep, gate_timestep
|
return scale_shift_timestep, gate_timestep
|
||||||
@@ -293,8 +282,6 @@ class LTXModel(nn.Module):
|
|||||||
def _init_audio(self, config: LTXModelConfig) -> None:
|
def _init_audio(self, config: LTXModelConfig) -> None:
|
||||||
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
self.audio_patchify_proj = nn.Linear(config.audio_in_channels, self.audio_inner_dim, bias=True)
|
||||||
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
self.audio_adaln_single = AdaLayerNormSingle(self.audio_inner_dim)
|
||||||
|
|
||||||
# Audio caption projection: receives pre-processed embeddings from text encoder's audio_embeddings_connector
|
|
||||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
in_features=config.audio_caption_channels,
|
in_features=config.audio_caption_channels,
|
||||||
hidden_size=self.audio_inner_dim,
|
hidden_size=self.audio_inner_dim,
|
||||||
@@ -397,9 +384,8 @@ class LTXModel(nn.Module):
|
|||||||
video_config = config.get_video_config()
|
video_config = config.get_video_config()
|
||||||
audio_config = config.get_audio_config()
|
audio_config = config.get_audio_config()
|
||||||
|
|
||||||
|
self.transformer_blocks = [
|
||||||
self.transformer_blocks = {
|
BasicAVTransformerBlock(
|
||||||
idx: BasicAVTransformerBlock(
|
|
||||||
idx=idx,
|
idx=idx,
|
||||||
video=video_config,
|
video=video_config,
|
||||||
audio=audio_config,
|
audio=audio_config,
|
||||||
@@ -407,7 +393,7 @@ class LTXModel(nn.Module):
|
|||||||
norm_eps=config.norm_eps,
|
norm_eps=config.norm_eps,
|
||||||
)
|
)
|
||||||
for idx in range(config.num_layers)
|
for idx in range(config.num_layers)
|
||||||
}
|
]
|
||||||
|
|
||||||
def _process_transformer_blocks(
|
def _process_transformer_blocks(
|
||||||
self,
|
self,
|
||||||
@@ -415,7 +401,7 @@ class LTXModel(nn.Module):
|
|||||||
audio: Optional[TransformerArgs],
|
audio: Optional[TransformerArgs],
|
||||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||||
"""Process through all transformer blocks."""
|
"""Process through all transformer blocks."""
|
||||||
for block in self.transformer_blocks.values():
|
for block in self.transformer_blocks:
|
||||||
video, audio = block(video=video, audio=audio)
|
video, audio = block(video=video, audio=audio)
|
||||||
return video, audio
|
return video, audio
|
||||||
|
|
||||||
@@ -497,50 +483,19 @@ class LTXModel(nn.Module):
|
|||||||
|
|
||||||
def sanitize(self, weights: dict) -> dict:
|
def sanitize(self, weights: dict) -> dict:
|
||||||
sanitized = {}
|
sanitized = {}
|
||||||
|
|
||||||
for key, value in weights.items():
|
for key, value in weights.items():
|
||||||
new_key = key
|
new_key = key
|
||||||
# Skip non-transformer weights (VAE, vocoder, audio_vae, connectors)
|
|
||||||
if not key.startswith("model.diffusion_model.") or "audio_embeddings_connector" in key or "video_embeddings_connector" in key:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Remove 'model.diffusion_model.' prefix
|
|
||||||
new_key = new_key.replace("model.diffusion_model.", "")
|
|
||||||
|
|
||||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
|
||||||
|
|
||||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
|
||||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
|
||||||
new_key = new_key.replace(".audio_ff.net.0.proj.", ".audio_ff.proj_in.")
|
|
||||||
new_key = new_key.replace(".audio_ff.net.2.", ".audio_ff.proj_out.")
|
|
||||||
|
|
||||||
new_key = new_key.replace(".linear_1.", ".linear1.")
|
|
||||||
new_key = new_key.replace(".linear_2.", ".linear2.")
|
|
||||||
|
|
||||||
|
# Handle common remappings
|
||||||
|
# transformer_blocks.X -> transformer_blocks[X]
|
||||||
|
if "transformer_blocks." in new_key:
|
||||||
|
# Keep as-is for now, MLX handles this
|
||||||
|
pass
|
||||||
|
|
||||||
sanitized[new_key] = value
|
sanitized[new_key] = value
|
||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(cls, model_path: [Path, List[Path]], config: LTXModelConfig, strict: bool = True) -> None:
|
|
||||||
model = cls(config)
|
|
||||||
|
|
||||||
weights = {}
|
|
||||||
if isinstance(model_path, Path):
|
|
||||||
model_path = [model_path]
|
|
||||||
for weight_file in model_path:
|
|
||||||
weights.update(mx.load(str(weight_file)))
|
|
||||||
|
|
||||||
|
|
||||||
sanitized = model.sanitize(weights)
|
|
||||||
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
|
||||||
|
|
||||||
model.load_weights(list(sanitized.items()), strict=strict)
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class X0Model(nn.Module):
|
class X0Model(nn.Module):
|
||||||
|
|
||||||
|
|||||||
@@ -428,11 +428,14 @@ def _precompute_freqs_cis_double_precision(
|
|||||||
num_attention_heads: int,
|
num_attention_heads: int,
|
||||||
rope_type: LTXRopeType,
|
rope_type: LTXRopeType,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
"""Compute RoPE frequencies with higher precision using float32.
|
"""Compute RoPE frequencies with higher precision using float64 for frequency grid.
|
||||||
|
|
||||||
This version stays entirely in MLX/GPU, avoiding expensive NumPy round-trips.
|
Matches PyTorch's approach: uses NumPy float64 for the critical frequency grid
|
||||||
Uses float32 for computation precision (sufficient for RoPE).
|
computation (log-spaced values), then converts to float32 for the final tensor.
|
||||||
|
This provides better numerical precision in the frequency generation phase.
|
||||||
"""
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Warn if positions are bfloat16 - this causes quality degradation
|
# Warn if positions are bfloat16 - this causes quality degradation
|
||||||
if indices_grid.dtype == mx.bfloat16:
|
if indices_grid.dtype == mx.bfloat16:
|
||||||
import warnings
|
import warnings
|
||||||
@@ -443,21 +446,27 @@ def _precompute_freqs_cis_double_precision(
|
|||||||
stacklevel=2
|
stacklevel=2
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cast to float32 for computation (stay on GPU, no NumPy/CPU conversion)
|
# Cast to float32 for position computation
|
||||||
indices_grid_f32 = indices_grid.astype(mx.float32)
|
indices_grid_f32 = indices_grid.astype(mx.float32)
|
||||||
|
|
||||||
n_pos_dims = indices_grid_f32.shape[1]
|
n_pos_dims = indices_grid_f32.shape[1]
|
||||||
n_elem = 2 * n_pos_dims
|
n_elem = 2 * n_pos_dims
|
||||||
|
|
||||||
# Compute log-spaced frequencies in float32
|
# Compute log-spaced frequencies in float64 (matching PyTorch's generate_freq_grid_np)
|
||||||
log_start = math.log(1.0) / math.log(theta)
|
# This is the critical precision step - PyTorch uses np.float64 here
|
||||||
log_end = math.log(theta) / math.log(theta)
|
log_start = np.log(1.0) / np.log(theta)
|
||||||
|
log_end = np.log(theta) / np.log(theta) # = 1.0
|
||||||
num_indices = dim // n_elem
|
num_indices = dim // n_elem
|
||||||
if num_indices == 0:
|
if num_indices == 0:
|
||||||
num_indices = 1
|
num_indices = 1
|
||||||
|
|
||||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
# Use numpy float64 for the linspace computation (matches PyTorch)
|
||||||
freq_indices = mx.power(mx.array(theta, dtype=mx.float32), lin_space) * (math.pi / 2)
|
pow_indices = np.power(
|
||||||
|
theta,
|
||||||
|
np.linspace(log_start, log_end, num_indices, dtype=np.float64),
|
||||||
|
)
|
||||||
|
# Convert to float32 tensor (matches PyTorch: torch.tensor(..., dtype=torch.float32))
|
||||||
|
freq_indices = mx.array(pow_indices * (math.pi / 2), dtype=mx.float32)
|
||||||
|
|
||||||
# Handle middle indices grid
|
# Handle middle indices grid
|
||||||
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
# Input shape: (B, n_dims, T, 2) for middle indices or (B, n_dims, T, 1) otherwise
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
|
from mlx_video.models.ltx.video_vae.encoder import encode_image
|
||||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
|
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder
|
||||||
from mlx_video.models.ltx.video_vae.tiling import (
|
from mlx_video.models.ltx.video_vae.tiling import (
|
||||||
TilingConfig,
|
TilingConfig,
|
||||||
SpatialTilingConfig,
|
SpatialTilingConfig,
|
||||||
|
|||||||
@@ -15,13 +15,14 @@ Architecture (from PyTorch weights):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import Optional, Dict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||||
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
|
||||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||||
|
|
||||||
@@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
self.decode_timestep = 0.05
|
self.decode_timestep = 0.05
|
||||||
|
|
||||||
# Per-channel statistics for denormalization (loaded from weights)
|
# Per-channel statistics for denormalization (loaded from weights)
|
||||||
self.latents_mean = mx.zeros((in_channels,))
|
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||||
self.latents_std = mx.ones((in_channels,))
|
|
||||||
|
|
||||||
# Initial conv: 128 -> 1024
|
# Initial conv: 128 -> 1024
|
||||||
class ConvInWrapper(nn.Module):
|
class ConvInWrapper(nn.Module):
|
||||||
@@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
self.last_scale_shift_table = mx.zeros((2, 128))
|
self.last_scale_shift_table = mx.zeros((2, 128))
|
||||||
|
|
||||||
def denormalize(self, x: mx.array) -> mx.array:
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
"""Denormalize latents using per-channel statistics."""
|
# Build decoder weights dict with key remapping
|
||||||
dtype = x.dtype
|
sanitized = {}
|
||||||
# Cast to float32 for precision (statistics may be in bfloat16)
|
for key, value in weights.items():
|
||||||
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
new_key = key
|
||||||
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
|
||||||
return (x * std + mean).astype(dtype)
|
if not key.startswith("vae.") or key.startswith("vae.encoder."):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key.startswith("vae.per_channel_statistics."):
|
||||||
|
# Map per-channel statistics (use exact key matching)
|
||||||
|
if key == "vae.per_channel_statistics.mean-of-means":
|
||||||
|
new_key = "per_channel_statistics.mean"
|
||||||
|
elif key == "vae.per_channel_statistics.std-of-means":
|
||||||
|
new_key = "per_channel_statistics.std"
|
||||||
|
else:
|
||||||
|
continue # Skip other statistics keys
|
||||||
|
|
||||||
|
if key.startswith("vae.decoder."):
|
||||||
|
new_key = key.replace("vae.decoder.", "")
|
||||||
|
|
||||||
|
|
||||||
|
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||||
|
if ".conv.weight" in key and value.ndim == 5:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
if ".conv.bias" in key:
|
||||||
|
pass # bias doesn't need transpose
|
||||||
|
|
||||||
|
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||||
|
|
||||||
|
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||||
|
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||||
|
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
|
||||||
|
from safetensors import safe_open
|
||||||
|
import json
|
||||||
|
weights = mx.load(str(model_path))
|
||||||
|
|
||||||
|
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||||
|
if timestep_conditioning is None:
|
||||||
|
try:
|
||||||
|
with safe_open(str(model_path), framework="numpy") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
if metadata and "config" in metadata:
|
||||||
|
configs = json.loads(metadata["config"])
|
||||||
|
vae_config = configs.get("vae", {})
|
||||||
|
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||||
|
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||||
|
else:
|
||||||
|
timestep_conditioning = False
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||||
|
timestep_conditioning = False
|
||||||
|
|
||||||
|
model = cls(timestep_conditioning=timestep_conditioning)
|
||||||
|
weights = model.sanitize(weights)
|
||||||
|
model.load_weights(list(weights.items()), strict=strict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||||
"""Apply pixel normalization."""
|
"""Apply pixel normalization."""
|
||||||
@@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
chunked_conv: bool = False,
|
chunked_conv: bool = False,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
|
|
||||||
def debug_stats(name, t):
|
|
||||||
if debug:
|
|
||||||
mx.eval(t)
|
|
||||||
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
|
||||||
|
|
||||||
batch_size = sample.shape[0]
|
batch_size = sample.shape[0]
|
||||||
|
|
||||||
if debug:
|
|
||||||
debug_stats("Input", sample)
|
|
||||||
|
|
||||||
# Add noise if timestep conditioning is enabled
|
# Add noise if timestep conditioning is enabled
|
||||||
if self.timestep_conditioning:
|
if self.timestep_conditioning:
|
||||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||||
if debug:
|
|
||||||
debug_stats("After noise", sample)
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
|
sample = self.per_channel_statistics.un_normalize(sample)
|
||||||
sample = self.denormalize(sample)
|
|
||||||
if debug:
|
|
||||||
debug_stats("After denormalize", sample)
|
|
||||||
|
|
||||||
if timestep is None and self.timestep_conditioning:
|
if timestep is None and self.timestep_conditioning:
|
||||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||||
@@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||||
|
|
||||||
x = self.conv_in(sample, causal=causal)
|
x = self.conv_in(sample, causal=causal)
|
||||||
if debug:
|
|
||||||
debug_stats("After conv_in", x)
|
|
||||||
|
|
||||||
for i, block in self.up_blocks.items():
|
for i, block in self.up_blocks.items():
|
||||||
if isinstance(block, ResBlockGroup):
|
if isinstance(block, ResBlockGroup):
|
||||||
@@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||||
else:
|
else:
|
||||||
x = block(x, causal=causal)
|
x = block(x, causal=causal)
|
||||||
if debug:
|
|
||||||
block_type = type(block).__name__
|
|
||||||
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
|
|
||||||
|
|
||||||
x = self.pixel_norm(x)
|
x = self.pixel_norm(x)
|
||||||
if debug:
|
|
||||||
debug_stats("After pixel_norm", x)
|
|
||||||
|
|
||||||
if self.timestep_conditioning and scaled_timestep is not None:
|
if self.timestep_conditioning and scaled_timestep is not None:
|
||||||
embedded_timestep = self.last_time_embedder(
|
embedded_timestep = self.last_time_embedder(
|
||||||
@@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
scale = ada_values[:, 1]
|
scale = ada_values[:, 1]
|
||||||
|
|
||||||
x = x * (1 + scale) + shift
|
x = x * (1 + scale) + shift
|
||||||
if debug:
|
|
||||||
debug_stats("After timestep modulation", x)
|
|
||||||
|
|
||||||
x = self.act(x)
|
x = self.act(x)
|
||||||
if debug:
|
|
||||||
debug_stats("After activation", x)
|
|
||||||
|
|
||||||
x = self.conv_out(x, causal=causal)
|
x = self.conv_out(x, causal=causal)
|
||||||
if debug:
|
|
||||||
debug_stats("After conv_out", x)
|
|
||||||
|
|
||||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
if debug:
|
|
||||||
debug_stats("After unpatchify", x)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
chunked_conv=use_chunked_conv,
|
chunked_conv=use_chunked_conv,
|
||||||
on_frames_ready=on_frames_ready,
|
on_frames_ready=on_frames_ready,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
|
||||||
from pathlib import Path
|
|
||||||
import json
|
|
||||||
from safetensors import safe_open
|
|
||||||
|
|
||||||
model_path = Path(model_path)
|
|
||||||
|
|
||||||
# Try to find the weights file
|
|
||||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
|
||||||
weights_path = model_path
|
|
||||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
|
||||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
|
||||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
|
||||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
|
||||||
|
|
||||||
print(f"Loading VAE decoder from {weights_path}...")
|
|
||||||
|
|
||||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
|
||||||
if timestep_conditioning is None:
|
|
||||||
try:
|
|
||||||
with safe_open(str(weights_path), framework="numpy") as f:
|
|
||||||
metadata = f.metadata()
|
|
||||||
if metadata and "config" in metadata:
|
|
||||||
configs = json.loads(metadata["config"])
|
|
||||||
vae_config = configs.get("vae", {})
|
|
||||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
|
||||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
|
||||||
else:
|
|
||||||
timestep_conditioning = False
|
|
||||||
except Exception as e:
|
|
||||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
|
||||||
timestep_conditioning = False
|
|
||||||
|
|
||||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
|
||||||
|
|
||||||
weights = mx.load(str(weights_path))
|
|
||||||
|
|
||||||
# Determine prefix based on weight keys
|
|
||||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
|
||||||
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
|
|
||||||
|
|
||||||
if has_vae_prefix:
|
|
||||||
prefix = "vae.decoder."
|
|
||||||
stats_prefix = "vae.per_channel_statistics."
|
|
||||||
elif has_decoder_prefix:
|
|
||||||
prefix = "decoder."
|
|
||||||
stats_prefix = ""
|
|
||||||
else:
|
|
||||||
prefix = ""
|
|
||||||
stats_prefix = ""
|
|
||||||
|
|
||||||
# Load per-channel statistics for denormalization
|
|
||||||
# Note: use std-of-means (not mean-of-stds) for proper denormalization
|
|
||||||
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
|
|
||||||
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
|
|
||||||
|
|
||||||
if mean_key in weights:
|
|
||||||
decoder.latents_mean = weights[mean_key]
|
|
||||||
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
|
|
||||||
if std_key in weights:
|
|
||||||
decoder.latents_std = weights[std_key]
|
|
||||||
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
|
|
||||||
|
|
||||||
# Build decoder weights dict with key remapping
|
|
||||||
decoder_weights = {}
|
|
||||||
for key, value in weights.items():
|
|
||||||
if not key.startswith(prefix):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Remove prefix
|
|
||||||
new_key = key[len(prefix):]
|
|
||||||
|
|
||||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
|
||||||
if ".conv.weight" in key and value.ndim == 5:
|
|
||||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
|
||||||
if ".conv.bias" in key:
|
|
||||||
pass # bias doesn't need transpose
|
|
||||||
|
|
||||||
|
|
||||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
|
||||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
|
||||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
|
||||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
|
||||||
|
|
||||||
decoder_weights[new_key] = value
|
|
||||||
|
|
||||||
print(f" Found {len(decoder_weights)} decoder weights")
|
|
||||||
|
|
||||||
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
|
|
||||||
print(f" Found {len(ts_keys)} timestep conditioning weights")
|
|
||||||
|
|
||||||
# Load weights
|
|
||||||
decoder.load_weights(list(decoder_weights.items()), strict=False)
|
|
||||||
|
|
||||||
print("VAE decoder loaded successfully")
|
|
||||||
return decoder
|
|
||||||
|
|||||||
@@ -5,152 +5,9 @@ Used for I2V (image-to-video) conditioning by encoding the input image
|
|||||||
to latent space, which can then be used to condition video generation.
|
to latent space, which can then be used to condition video generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Tuple, Any, Optional
|
|
||||||
import json
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||||
|
|
||||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
|
|
||||||
|
|
||||||
|
|
||||||
def load_vae_encoder(model_path: str) -> VideoEncoder:
|
|
||||||
"""Load VAE encoder from safetensors file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to the model weights (safetensors file or directory)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Loaded VideoEncoder instance
|
|
||||||
"""
|
|
||||||
from safetensors import safe_open
|
|
||||||
|
|
||||||
model_path = Path(model_path)
|
|
||||||
|
|
||||||
# Try to find the weights file
|
|
||||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
|
||||||
weights_path = model_path
|
|
||||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
|
||||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
|
||||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
|
||||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
|
||||||
else:
|
|
||||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
|
||||||
|
|
||||||
print(f"Loading VAE encoder from {weights_path}...")
|
|
||||||
|
|
||||||
# Read config from safetensors metadata
|
|
||||||
encoder_blocks = []
|
|
||||||
norm_layer = NormLayerType.PIXEL_NORM
|
|
||||||
latent_log_var = LogVarianceType.UNIFORM
|
|
||||||
patch_size = 4
|
|
||||||
|
|
||||||
try:
|
|
||||||
with safe_open(str(weights_path), framework="numpy") as f:
|
|
||||||
metadata = f.metadata()
|
|
||||||
if metadata and "config" in metadata:
|
|
||||||
configs = json.loads(metadata["config"])
|
|
||||||
vae_config = configs.get("vae", {})
|
|
||||||
|
|
||||||
# Parse encoder blocks
|
|
||||||
raw_blocks = vae_config.get("encoder_blocks", [])
|
|
||||||
for block in raw_blocks:
|
|
||||||
if isinstance(block, list) and len(block) == 2:
|
|
||||||
name, params = block
|
|
||||||
encoder_blocks.append((name, params))
|
|
||||||
|
|
||||||
# Parse other config
|
|
||||||
norm_str = vae_config.get("norm_layer", "pixel_norm")
|
|
||||||
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
|
|
||||||
|
|
||||||
var_str = vae_config.get("latent_log_var", "uniform")
|
|
||||||
if var_str == "uniform":
|
|
||||||
latent_log_var = LogVarianceType.UNIFORM
|
|
||||||
elif var_str == "per_channel":
|
|
||||||
latent_log_var = LogVarianceType.PER_CHANNEL
|
|
||||||
elif var_str == "constant":
|
|
||||||
latent_log_var = LogVarianceType.CONSTANT
|
|
||||||
else:
|
|
||||||
latent_log_var = LogVarianceType.NONE
|
|
||||||
|
|
||||||
patch_size = vae_config.get("patch_size", 4)
|
|
||||||
|
|
||||||
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" Could not read config from metadata: {e}")
|
|
||||||
# Use default config
|
|
||||||
encoder_blocks = [
|
|
||||||
("res_x", {"num_layers": 4}),
|
|
||||||
("compress_space_res", {"multiplier": 2}),
|
|
||||||
("res_x", {"num_layers": 6}),
|
|
||||||
("compress_time_res", {"multiplier": 2}),
|
|
||||||
("res_x", {"num_layers": 6}),
|
|
||||||
("compress_all_res", {"multiplier": 2}),
|
|
||||||
("res_x", {"num_layers": 2}),
|
|
||||||
("compress_all_res", {"multiplier": 2}),
|
|
||||||
("res_x", {"num_layers": 2}),
|
|
||||||
]
|
|
||||||
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
|
|
||||||
|
|
||||||
# Create encoder
|
|
||||||
encoder = VideoEncoder(
|
|
||||||
convolution_dimensions=3,
|
|
||||||
in_channels=3,
|
|
||||||
out_channels=128,
|
|
||||||
encoder_blocks=encoder_blocks,
|
|
||||||
patch_size=patch_size,
|
|
||||||
norm_layer=norm_layer,
|
|
||||||
latent_log_var=latent_log_var,
|
|
||||||
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load weights
|
|
||||||
weights = mx.load(str(weights_path))
|
|
||||||
|
|
||||||
# Determine prefix based on weight keys
|
|
||||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
|
||||||
|
|
||||||
if has_vae_prefix:
|
|
||||||
prefix = "vae.encoder."
|
|
||||||
stats_prefix = "vae.per_channel_statistics."
|
|
||||||
else:
|
|
||||||
prefix = "encoder."
|
|
||||||
stats_prefix = "per_channel_statistics."
|
|
||||||
|
|
||||||
# Load per-channel statistics for normalization
|
|
||||||
mean_key = f"{stats_prefix}mean-of-means"
|
|
||||||
std_key = f"{stats_prefix}std-of-means"
|
|
||||||
|
|
||||||
if mean_key in weights:
|
|
||||||
encoder.per_channel_statistics.mean = weights[mean_key]
|
|
||||||
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
|
|
||||||
if std_key in weights:
|
|
||||||
encoder.per_channel_statistics.std = weights[std_key]
|
|
||||||
print(f" Loaded latent std: shape {weights[std_key].shape}")
|
|
||||||
|
|
||||||
# Build encoder weights dict with key remapping
|
|
||||||
encoder_weights = {}
|
|
||||||
for key, value in weights.items():
|
|
||||||
if not key.startswith(prefix):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Remove prefix
|
|
||||||
new_key = key[len(prefix):]
|
|
||||||
|
|
||||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
|
||||||
if ".weight" in key and value.ndim == 5:
|
|
||||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
|
||||||
|
|
||||||
encoder_weights[new_key] = value
|
|
||||||
|
|
||||||
print(f" Found {len(encoder_weights)} encoder weights")
|
|
||||||
|
|
||||||
# Load weights
|
|
||||||
encoder.load_weights(list(encoder_weights.items()), strict=False)
|
|
||||||
|
|
||||||
print("VAE encoder loaded successfully")
|
|
||||||
return encoder
|
|
||||||
|
|
||||||
|
|
||||||
def encode_image(
|
def encode_image(
|
||||||
|
|||||||
@@ -273,9 +273,10 @@ class VideoEncoder(nn.Module):
|
|||||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build encoder blocks - use dict with int keys for MLX parameter tracking
|
# Build encoder blocks
|
||||||
|
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||||
self.down_blocks = {}
|
self.down_blocks = {}
|
||||||
for i, (block_name, block_params) in enumerate(encoder_blocks):
|
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||||
|
|
||||||
block, feature_channels = _make_encoder_block(
|
block, feature_channels = _make_encoder_block(
|
||||||
@@ -287,7 +288,7 @@ class VideoEncoder(nn.Module):
|
|||||||
norm_num_groups=self._norm_num_groups,
|
norm_num_groups=self._norm_num_groups,
|
||||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||||
)
|
)
|
||||||
self.down_blocks[i] = block
|
self.down_blocks[idx] = block
|
||||||
|
|
||||||
# Output normalization and convolution
|
# Output normalization and convolution
|
||||||
if norm_layer == NormLayerType.GROUP_NORM:
|
if norm_layer == NormLayerType.GROUP_NORM:
|
||||||
@@ -341,7 +342,8 @@ class VideoEncoder(nn.Module):
|
|||||||
sample = self.conv_in(sample, causal=True)
|
sample = self.conv_in(sample, causal=True)
|
||||||
|
|
||||||
# Process through encoder blocks
|
# Process through encoder blocks
|
||||||
for down_block in self.down_blocks.values():
|
for i in range(len(self.down_blocks)):
|
||||||
|
down_block = self.down_blocks[i]
|
||||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||||
sample = down_block(sample, causal=True)
|
sample = down_block(sample, causal=True)
|
||||||
else:
|
else:
|
||||||
@@ -440,8 +442,9 @@ class VideoDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build decoder blocks (reversed order)
|
# Build decoder blocks (reversed order)
|
||||||
self.up_blocks = []
|
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
self.up_blocks = {}
|
||||||
|
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||||
|
|
||||||
block, feature_channels = _make_decoder_block(
|
block, feature_channels = _make_decoder_block(
|
||||||
@@ -454,7 +457,7 @@ class VideoDecoder(nn.Module):
|
|||||||
norm_num_groups=self._norm_num_groups,
|
norm_num_groups=self._norm_num_groups,
|
||||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||||
)
|
)
|
||||||
self.up_blocks.append(block)
|
self.up_blocks[idx] = block
|
||||||
|
|
||||||
# Output normalization
|
# Output normalization
|
||||||
if norm_layer == NormLayerType.GROUP_NORM:
|
if norm_layer == NormLayerType.GROUP_NORM:
|
||||||
@@ -509,7 +512,8 @@ class VideoDecoder(nn.Module):
|
|||||||
sample = self.conv_in(sample, causal=self.causal)
|
sample = self.conv_in(sample, causal=self.causal)
|
||||||
|
|
||||||
# Process through decoder blocks
|
# Process through decoder blocks
|
||||||
for up_block in self.up_blocks:
|
for i in range(len(self.up_blocks)):
|
||||||
|
up_block = self.up_blocks[i]
|
||||||
if isinstance(up_block, UNetMidBlock3D):
|
if isinstance(up_block, UNetMidBlock3D):
|
||||||
sample = up_block(sample, causal=self.causal)
|
sample = up_block(sample, causal=self.causal)
|
||||||
elif isinstance(up_block, ResnetBlock3D):
|
elif isinstance(up_block, ResnetBlock3D):
|
||||||
|
|||||||
Reference in New Issue
Block a user