format
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
@@ -22,9 +21,11 @@ class LTXRopeType(Enum):
|
||||
SPLIT = "split"
|
||||
TWO_D = "2d"
|
||||
|
||||
|
||||
class AttentionType(Enum):
|
||||
DEFAULT = "default"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
|
||||
@@ -46,7 +47,7 @@ class BaseModelConfig:
|
||||
if v is not None:
|
||||
if isinstance(v, Enum):
|
||||
result[k] = v.value
|
||||
elif hasattr(v, 'to_dict'):
|
||||
elif hasattr(v, "to_dict"):
|
||||
result[k] = v.to_dict()
|
||||
else:
|
||||
result[k] = v
|
||||
@@ -68,26 +69,30 @@ class VideoVAEConfig(BaseModelConfig):
|
||||
out_channels: int = 128
|
||||
latent_channels: int = 128
|
||||
patch_size: int = 4
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
])
|
||||
decoder_blocks: List[tuple] = field(default_factory=lambda: [
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
])
|
||||
encoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
)
|
||||
decoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
("compress_all", {"residual": True, "multiplier": 2}),
|
||||
("res_x", {"num_layers": 5, "inject_noise": False}),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -111,7 +116,9 @@ class LTXModelConfig(BaseModelConfig):
|
||||
audio_in_channels: int = 128
|
||||
audio_out_channels: int = 128
|
||||
audio_cross_attention_dim: int = 2048
|
||||
audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video)
|
||||
audio_caption_channels: int = (
|
||||
3840 # Input dim for audio text embeddings (same as video)
|
||||
)
|
||||
|
||||
# Positional embedding config
|
||||
positional_embedding_theta: float = 10000.0
|
||||
@@ -196,7 +203,6 @@ class LTXModelConfig(BaseModelConfig):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class CausalityAxis(Enum):
|
||||
"""Enum for specifying the causality axis in causal convolutions."""
|
||||
|
||||
@@ -237,21 +243,22 @@ class AudioDecoderModelConfig(BaseModelConfig):
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
# Import here to avoid circular imports
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
|
||||
from .audio_vae.normalization import NormType
|
||||
|
||||
# Convert causality_axis string to enum
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
|
||||
|
||||
# Convert norm_type string to enum
|
||||
if isinstance(self.norm_type, str):
|
||||
self.norm_type = NormType(self.norm_type)
|
||||
|
||||
|
||||
# Convert attn_type string to enum
|
||||
if isinstance(self.attn_type, str):
|
||||
self.attn_type = AttentionType(self.attn_type)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioEncoderModelConfig(BaseModelConfig):
|
||||
ch: int = 128
|
||||
@@ -282,8 +289,8 @@ class AudioEncoderModelConfig(BaseModelConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
"""Convert string enum values to proper enum types."""
|
||||
from .audio_vae.normalization import NormType
|
||||
from .audio_vae.attention import AttentionType
|
||||
from .audio_vae.normalization import NormType
|
||||
|
||||
if isinstance(self.causality_axis, str):
|
||||
self.causality_axis = CausalityAxis(self.causality_axis)
|
||||
@@ -334,6 +341,7 @@ class VideoDecoderModelConfig(BaseModelConfig):
|
||||
dropout: float = 0.0
|
||||
timestep_conditioning: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderModelConfig(BaseModelConfig):
|
||||
convolution_dimensions: int = 3
|
||||
@@ -343,21 +351,24 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
||||
norm_layer: Enum = None
|
||||
latent_log_var: Enum = None
|
||||
encoder_spatial_padding_mode: Enum = None
|
||||
encoder_blocks: List[tuple] = field(default_factory=lambda: [("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2})
|
||||
])
|
||||
encoder_blocks: List[tuple] = field(
|
||||
default_factory=lambda: [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
from mlx_video.models.ltx_2.video_vae.resnet import NormLayerType
|
||||
from mlx_video.models.ltx_2.video_vae.video_vae import LogVarianceType
|
||||
from mlx_video.models.ltx_2.video_vae.convolution import PaddingModeType
|
||||
|
||||
if self.norm_layer is None:
|
||||
self.norm_layer = NormLayerType.PIXEL_NORM
|
||||
@@ -371,10 +382,12 @@ class VideoEncoderModelConfig(BaseModelConfig):
|
||||
if isinstance(self.latent_log_var, str):
|
||||
self.latent_log_var = LogVarianceType(self.latent_log_var)
|
||||
if isinstance(self.encoder_spatial_padding_mode, str):
|
||||
self.encoder_spatial_padding_mode = PaddingModeType(self.encoder_spatial_padding_mode)
|
||||
self.encoder_spatial_padding_mode = PaddingModeType(
|
||||
self.encoder_spatial_padding_mode
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = super().to_dict()
|
||||
if self.encoder_blocks is not None:
|
||||
result["encoder_blocks"] = [list(block) for block in self.encoder_blocks]
|
||||
return result
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user