import inspect from dataclasses import dataclass, field from enum import Enum from typing import Any, List, Optional class LTXModelType(Enum): AudioVideo = "ltx av model" VideoOnly = "ltx video only model" AudioOnly = "ltx audio only model" def is_video_enabled(self) -> bool: return self in (LTXModelType.AudioVideo, LTXModelType.VideoOnly) def is_audio_enabled(self) -> bool: return self in (LTXModelType.AudioVideo, LTXModelType.AudioOnly) class LTXRopeType(Enum): INTERLEAVED = "interleaved" SPLIT = "split" TWO_D = "2d" class AttentionType(Enum): DEFAULT = "default" @dataclass class BaseModelConfig: @classmethod def from_dict(cls, params: dict[str, Any]) -> "BaseModelConfig": """Create config from dictionary, filtering only valid parameters.""" return cls( **{ k: v for k, v in params.items() if k in inspect.signature(cls).parameters } ) def to_dict(self) -> dict[str, Any]: """Export config to dictionary.""" result = {} for k, v in self.__dict__.items(): if v is not None: if isinstance(v, Enum): result[k] = v.value elif hasattr(v, 'to_dict'): result[k] = v.to_dict() else: result[k] = v return result @dataclass class TransformerConfig(BaseModelConfig): dim: int heads: int d_head: int context_dim: int @dataclass class VideoVAEConfig(BaseModelConfig): convolution_dimensions: int = 3 in_channels: int = 3 out_channels: int = 128 latent_channels: int = 128 patch_size: int = 4 encoder_blocks: List[tuple] = field(default_factory=lambda: [ ("res_x", {"num_layers": 4}), ("compress_space_res", {"multiplier": 2}), ("res_x", {"num_layers": 6}), ("compress_time_res", {"multiplier": 2}), ("res_x", {"num_layers": 6}), ("compress_all_res", {"multiplier": 2}), ("res_x", {"num_layers": 2}), ("compress_all_res", {"multiplier": 2}), ("res_x", {"num_layers": 2}), ]) decoder_blocks: List[tuple] = field(default_factory=lambda: [ ("res_x", {"num_layers": 5, "inject_noise": False}), ("compress_all", {"residual": True, "multiplier": 2}), ("res_x", {"num_layers": 5, "inject_noise": False}), ("compress_all", {"residual": True, "multiplier": 2}), ("res_x", {"num_layers": 5, "inject_noise": False}), ("compress_all", {"residual": True, "multiplier": 2}), ("res_x", {"num_layers": 5, "inject_noise": False}), ]) @dataclass class LTXModelConfig(BaseModelConfig): # Model type model_type: LTXModelType = LTXModelType.AudioVideo # Video transformer config num_attention_heads: int = 32 attention_head_dim: int = 128 in_channels: int = 128 out_channels: int = 128 num_layers: int = 48 cross_attention_dim: int = 4096 caption_channels: int = 3840 # Audio transformer config audio_num_attention_heads: int = 32 audio_attention_head_dim: int = 64 audio_in_channels: int = 128 audio_out_channels: int = 128 audio_cross_attention_dim: int = 2048 audio_caption_channels: int = 3840 # Input dim for audio text embeddings (same as video) # Positional embedding config positional_embedding_theta: float = 10000.0 positional_embedding_max_pos: Optional[List[int]] = None audio_positional_embedding_max_pos: Optional[List[int]] = None use_middle_indices_grid: bool = True rope_type: LTXRopeType = LTXRopeType.INTERLEAVED double_precision_rope: bool = False # Timestep config timestep_scale_multiplier: int = 1000 av_ca_timestep_scale_multiplier: int = 1 # Normalization norm_eps: float = 1e-6 # Attention type attention_type: AttentionType = AttentionType.DEFAULT # VAE config vae_config: Optional[VideoVAEConfig] = None def __post_init__(self): """Set default values after initialization.""" if self.positional_embedding_max_pos is None: self.positional_embedding_max_pos = [20, 2048, 2048] if self.audio_positional_embedding_max_pos is None: self.audio_positional_embedding_max_pos = [20] # Convert string enum values if loading from dict if isinstance(self.model_type, str): self.model_type = LTXModelType(self.model_type) if isinstance(self.rope_type, str): self.rope_type = LTXRopeType(self.rope_type) if isinstance(self.attention_type, str): self.attention_type = AttentionType(self.attention_type) @property def inner_dim(self) -> int: """Video inner dimension.""" return self.num_attention_heads * self.attention_head_dim @property def audio_inner_dim(self) -> int: """Audio inner dimension.""" return self.audio_num_attention_heads * self.audio_attention_head_dim def get_video_config(self) -> Optional[TransformerConfig]: """Get video transformer configuration.""" if not self.model_type.is_video_enabled(): return None return TransformerConfig( dim=self.inner_dim, heads=self.num_attention_heads, d_head=self.attention_head_dim, context_dim=self.cross_attention_dim, ) def get_audio_config(self) -> Optional[TransformerConfig]: """Get audio transformer configuration.""" if not self.model_type.is_audio_enabled(): return None return TransformerConfig( dim=self.audio_inner_dim, heads=self.audio_num_attention_heads, d_head=self.audio_attention_head_dim, context_dim=self.audio_cross_attention_dim, )