Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, List, Optional, Tuple, Set
|
||||
|
||||
|
||||
class LTXModelType(Enum):
|
||||
@@ -180,3 +180,141 @@ class LTXModelConfig(BaseModelConfig):
|
||||
d_head=self.audio_attention_head_dim,
|
||||
context_dim=self.audio_cross_attention_dim,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class CausalityAxis(Enum):
    """Axis along which a causal convolution is causal.

    Members can be looked up by value, e.g. ``CausalityAxis("width")``;
    ``CausalityAxis(None)`` resolves to ``NONE``.
    """

    # No causality axis selected.
    NONE = None

    # String-valued members; the values are the spellings used in configs.
    WIDTH = "width"
    HEIGHT = "height"
    WIDTH_COMPATIBILITY = "width-compatibility"
|
||||
|
||||
|
||||
@dataclass
class AudioDecoderModelConfig(BaseModelConfig):
    """Configuration values for the audio decoder model.

    ``causality_axis``, ``norm_type`` and ``attn_type`` may be supplied as
    plain strings; ``__post_init__`` converts them to their enum types.
    """

    ch: int = 128
    out_ch: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    # String values are coerced to NormType in __post_init__.
    norm_type: Enum = None
    # String values are coerced to CausalityAxis in __post_init__.
    causality_axis: Enum = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: int | None = None
    resamp_with_conv: bool = True
    # String values are coerced to AttentionType in __post_init__.
    attn_type: str = None
    give_pre_end: bool = False
    tanh_out: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Return the base-class dict with ``attn_resolutions`` as a list.

        The ``attn_resolutions`` entry is only overwritten when the field is
        not ``None``.
        """
        serialized = super().to_dict()
        resolutions = self.attn_resolutions
        if resolutions is None:
            return serialized
        serialized["attn_resolutions"] = list(resolutions)
        return serialized

    def __post_init__(self):
        """Coerce string-valued enum fields into their enum members."""
        # Imported locally to avoid circular imports.
        from .audio_vae.normalization import NormType
        from .audio_vae.attention import AttentionType

        # (field name, enum class) pairs, processed in this order.
        coercions = (
            ("causality_axis", CausalityAxis),
            ("norm_type", NormType),
            ("attn_type", AttentionType),
        )
        for attr_name, enum_cls in coercions:
            raw = getattr(self, attr_name)
            # Non-string values (None or an enum member) are left untouched.
            if isinstance(raw, str):
                setattr(self, attr_name, enum_cls(raw))
|
||||
|
||||
@dataclass
class VocoderModelConfig(BaseModelConfig):
    """Configuration values for the vocoder model.

    The list-valued fields default to ``None`` and are filled with fresh
    default lists in ``__post_init__``.
    """

    resblock_kernel_sizes: Optional[List[int]] = None
    upsample_rates: Optional[List[int]] = None
    upsample_kernel_sizes: Optional[List[int]] = None
    resblock_dilation_sizes: Optional[List[List[int]]] = None
    upsample_initial_channel: int = 1024
    stereo: bool = True
    resblock: str = "1"
    output_sample_rate: int = 24000

    def __post_init__(self):
        """Replace any list field left as ``None`` with its default value."""
        # Built on every call so each instance gets its own list objects.
        fallbacks = {
            "resblock_kernel_sizes": [3, 7, 11],
            "upsample_rates": [6, 5, 2, 2, 2],
            "upsample_kernel_sizes": [16, 15, 8, 4, 4],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        }
        for attr_name, fallback in fallbacks.items():
            if getattr(self, attr_name) is None:
                setattr(self, attr_name, fallback)
|
||||
|
||||
|
||||
@dataclass
class VideoDecoderModelConfig(BaseModelConfig):
    """Configuration values for the video decoder model.

    NOTE(review): the fields and defaults below are identical to the first
    ten fields of AudioDecoderModelConfig (including ``out_ch = 2``) --
    confirm these are the intended values for video.
    """

    ch: int = 128
    out_ch: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    # No string-to-enum coercion is performed for these two fields here.
    norm_type: Enum = None
    causality_axis: Enum = None
    dropout: float = 0.0
|
||||
|
||||
@dataclass
class VideoEncoderModelConfig(BaseModelConfig):
    """Configuration values for the video encoder model.

    ``norm_layer``, ``latent_log_var`` and ``encoder_spatial_padding_mode``
    accept an enum member, its string value, or ``None``. ``None`` is
    replaced by a default member and strings are converted to enum members
    in ``__post_init__``.
    """

    convolution_dimensions: int = 3
    # NOTE: these defaults must not carry a trailing comma. ``x: int = 3,``
    # makes the default the tuple ``(3,)`` rather than the integer ``3``.
    in_channels: int = 3
    out_channels: int = 128
    patch_size: int = 4
    # Resolved to NormLayerType (default PIXEL_NORM) in __post_init__.
    norm_layer: Optional[Enum] = None
    # Resolved to LogVarianceType (default UNIFORM) in __post_init__.
    latent_log_var: Optional[Enum] = None
    # Resolved to PaddingModeType (default ZEROS) in __post_init__.
    encoder_spatial_padding_mode: Optional[Enum] = None
    # Sequence of (block name, block kwargs) pairs; a fresh list is built
    # for every instance.
    encoder_blocks: List[tuple] = field(
        default_factory=lambda: [
            ("res_x", {"num_layers": 4}),
            ("compress_space_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 6}),
            ("compress_time_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 6}),
            ("compress_all_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 2}),
            ("compress_all_res", {"multiplier": 2}),
            ("res_x", {"num_layers": 2}),
        ]
    )

    def __post_init__(self):
        """Apply enum defaults and convert string values to enum members.

        Raises:
            ValueError: if a string value does not match any member of the
                corresponding enum.
        """
        # Imported locally, matching the original placement inside the method.
        from mlx_video.models.ltx.video_vae.resnet import NormLayerType
        from mlx_video.models.ltx.video_vae.video_vae import LogVarianceType
        from mlx_video.models.ltx.video_vae.convolution import PaddingModeType

        # Fill in defaults for fields that were left unset.
        if self.norm_layer is None:
            self.norm_layer = NormLayerType.PIXEL_NORM
        if self.latent_log_var is None:
            self.latent_log_var = LogVarianceType.UNIFORM
        if self.encoder_spatial_padding_mode is None:
            self.encoder_spatial_padding_mode = PaddingModeType.ZEROS

        # Convert string values (e.g. loaded from a JSON config) to enums.
        if isinstance(self.norm_layer, str):
            self.norm_layer = NormLayerType(self.norm_layer)
        if isinstance(self.latent_log_var, str):
            self.latent_log_var = LogVarianceType(self.latent_log_var)
        if isinstance(self.encoder_spatial_padding_mode, str):
            self.encoder_spatial_padding_mode = PaddingModeType(
                self.encoder_spatial_padding_mode
            )

    def to_dict(self) -> dict[str, Any]:
        """Return the base-class dict with each encoder block as a list.

        The ``encoder_blocks`` entry is only overwritten when the field is
        not ``None``.
        """
        result = super().to_dict()
        if self.encoder_blocks is not None:
            result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
        return result
|
||||
Reference in New Issue
Block a user