Refactor weight loading and sanitization processes for audio models

This commit is contained in:
Prince Canuma
2026-01-23 17:31:25 +01:00
parent 2681f75d2f
commit 02bfa228d9
18 changed files with 510 additions and 498 deletions

View File

@@ -2,7 +2,7 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple, Set
class LTXModelType(Enum):
@@ -180,3 +180,141 @@ class LTXModelConfig(BaseModelConfig):
d_head=self.audio_attention_head_dim,
context_dim=self.audio_cross_attention_dim,
)
class CausalityAxis(Enum):
    """Selects the causality axis used by causal convolutions.

    Every member except ``NONE`` has a string value, so a string read from a
    config can be turned into a member with ``CausalityAxis(<string>)``.
    """

    # The only member whose value is not a string.
    NONE = None
    WIDTH = "width"
    HEIGHT = "height"
    WIDTH_COMPATIBILITY = "width-compatibility"
@dataclass
class AudioDecoderModelConfig(BaseModelConfig):
    """Configuration values for the audio decoder.

    ``causality_axis``, ``norm_type`` and ``attn_type`` may be given as plain
    strings; ``__post_init__`` replaces any such string with the matching
    enum member. Values of any other type are left untouched.
    """

    ch: int = 128
    out_ch: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    # Enum-valued fields; a string value is converted in __post_init__.
    norm_type: Enum = None
    causality_axis: Enum = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: int | None = None
    resamp_with_conv: bool = True
    # Declared as str, but becomes an AttentionType member after __post_init__.
    attn_type: str = None
    give_pre_end: bool = False
    tanh_out: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return the parent's dict, with ``attn_resolutions`` as a new list.

        When ``attn_resolutions`` is ``None`` the parent's result is returned
        unchanged.
        """
        payload = super().to_dict()
        resolutions = self.attn_resolutions
        if resolutions is None:
            return payload
        payload["attn_resolutions"] = list(resolutions)
        return payload

    def __post_init__(self):
        """Replace string-valued enum fields with their enum members."""
        # Imported lazily to avoid circular imports at module load time.
        from .audio_vae.normalization import NormType
        from .audio_vae.attention import AttentionType

        # (attribute name, enum class), processed in this order.
        string_backed_enums = (
            ("causality_axis", CausalityAxis),
            ("norm_type", NormType),
            ("attn_type", AttentionType),
        )
        for attr_name, enum_cls in string_backed_enums:
            raw_value = getattr(self, attr_name)
            if isinstance(raw_value, str):
                setattr(self, attr_name, enum_cls(raw_value))
@dataclass
class VocoderModelConfig(BaseModelConfig):
    """Configuration values for the vocoder.

    The four list-valued fields default to ``None`` and are filled with
    built-in defaults by ``__post_init__`` when left unset.
    """

    resblock_kernel_sizes: Optional[List[int]] = None
    upsample_rates: Optional[List[int]] = None
    upsample_kernel_sizes: Optional[List[int]] = None
    resblock_dilation_sizes: Optional[List[List[int]]] = None
    upsample_initial_channel: int = 1024
    stereo: bool = True
    resblock: str = "1"
    output_sample_rate: int = 24000

    def __post_init__(self):
        """Fill every list field that is still ``None`` with its default."""
        # Built on every call so each instance receives its own list objects.
        fallbacks = {
            "resblock_kernel_sizes": [3, 7, 11],
            "upsample_rates": [6, 5, 2, 2, 2],
            "upsample_kernel_sizes": [16, 15, 8, 4, 4],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        }
        for attr_name, default_value in fallbacks.items():
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, default_value)
@dataclass
class VideoDecoderModelConfig(BaseModelConfig):
    """Configuration values for the video decoder.

    This class only declares fields; it adds no conversion or validation of
    its own.
    """

    # NOTE(review): these fields and defaults duplicate the leading fields of
    # AudioDecoderModelConfig (including out_ch=2) -- confirm they are the
    # intended values for video.
    ch: int = 128
    out_ch: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    # Stored as given: unlike AudioDecoderModelConfig, strings are not
    # converted to enum members here.
    norm_type: Enum = None
    causality_axis: Enum = None
    dropout: float = 0.0
@dataclass
class VideoEncoderModelConfig(BaseModelConfig):
    """Configuration values for the video encoder.

    ``norm_layer``, ``latent_log_var`` and ``encoder_spatial_padding_mode``
    accept an enum member, the member's string value, or ``None``.
    ``__post_init__`` replaces ``None`` with a default member and converts
    strings to the matching member.
    """

    # Field defaults must NOT end in a comma: ``name: int = 3,`` makes the
    # default the tuple ``(3,)``, and ``name: Enum = None,`` makes it
    # ``(None,)``, which defeats the ``is None`` handling in __post_init__.
    convolution_dimensions: int = 3
    in_channels: int = 3
    out_channels: int = 128
    patch_size: int = 4
    norm_layer: Optional[Enum] = None
    latent_log_var: Optional[Enum] = None
    encoder_spatial_padding_mode: Optional[Enum] = None
    # Ordered (block_name, block_params) pairs; a factory is used so each
    # instance gets its own list.
    encoder_blocks: List[tuple] = field(
        default_factory=lambda: [
            ("res_x", {"num_layers": 4}),
            ("compress_space_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 6}),
            ("compress_time_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 6}),
            ("compress_all_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 2}),
            ("compress_all_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 2}),
        ]
    )

    def __post_init__(self):
        """Apply default enum members and convert string values to enums.

        Raises:
            ValueError: if a string value does not match any member of the
                corresponding enum.
        """
        # Imported lazily, as in the rest of this module's config classes.
        from mlx_video.models.ltx.video_vae.resnet import NormLayerType
        from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
        from mlx_video.models.ltx.video_vae.convolution import PaddingModeType

        # (attribute name, enum class, member used when the value is None).
        enum_fields = (
            ("norm_layer", NormLayerType, NormLayerType.PIXEL_NORM),
            ("latent_log_var", LogVarianceType, LogVarianceType.UNIFORM),
            (
                "encoder_spatial_padding_mode",
                PaddingModeType,
                PaddingModeType.ZEROS,
            ),
        )
        for attr_name, enum_cls, fallback in enum_fields:
            current = getattr(self, attr_name)
            if current is None:
                setattr(self, attr_name, fallback)
            elif isinstance(current, str):
                setattr(self, attr_name, enum_cls(current))

    def to_dict(self) -> dict[str, Any]:
        """Return the parent's dict, with each encoder block as a list.

        When ``encoder_blocks`` is ``None`` the parent's result is returned
        unchanged.
        """
        result = super().to_dict()
        if self.encoder_blocks is not None:
            result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
        return result