Add audio-to-video conditioning

This commit is contained in:
Prince Canuma
2026-03-16 01:42:11 +01:00
parent f53b9e0807
commit 6f6105b715
7 changed files with 623 additions and 62 deletions

View File

@@ -2,7 +2,7 @@
import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, List, Optional, Tuple, Set
from typing import Any, List, Optional, Tuple
class LTXModelType(Enum):
@@ -252,6 +252,47 @@ class AudioDecoderModelConfig(BaseModelConfig):
if isinstance(self.attn_type, str):
self.attn_type = AttentionType(self.attn_type)
@dataclass
class AudioEncoderModelConfig(BaseModelConfig):
    """Configuration fields for the audio encoder model.

    ``causality_axis``, ``norm_type`` and ``attn_type`` may be supplied as
    plain strings; ``__post_init__`` replaces such strings with the
    corresponding enum members. All three default to ``None``, in which
    case they are left untouched.
    """

    ch: int = 128
    in_channels: int = 2
    ch_mult: Tuple[int, ...] = (1, 2, 4)
    num_res_blocks: int = 2
    attn_resolutions: Optional[List[int]] = None
    resolution: int = 256
    z_channels: int = 8
    double_z: bool = True
    n_fft: int = 1024
    # Enum member, a string to be converted in __post_init__, or None.
    norm_type: Optional[Enum] = None
    # Enum member, a string to be converted in __post_init__, or None.
    causality_axis: Optional[Enum] = None
    dropout: float = 0.0
    mid_block_add_attention: bool = True
    sample_rate: int = 16000
    mel_hop_length: int = 160
    is_causal: bool = True
    mel_bins: int = 64
    resamp_with_conv: bool = True
    # A string here is replaced by an AttentionType member in __post_init__.
    attn_type: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        """Return the base-class dict with ``attn_resolutions`` as a list.

        The ``attn_resolutions`` entry is only overwritten when the field is
        not ``None``; otherwise the base-class result is returned unchanged.
        """
        result = super().to_dict()
        if self.attn_resolutions is not None:
            # Copy into a fresh list so the returned dict never aliases the
            # sequence stored on the config.
            result["attn_resolutions"] = list(self.attn_resolutions)
        return result

    def __post_init__(self) -> None:
        """Convert string enum values to proper enum types."""
        # Imported lazily inside the method, as in the original code.
        from .audio_vae.normalization import NormType
        from .audio_vae.attention import AttentionType

        # NOTE(review): CausalityAxis is not imported in this method, unlike
        # NormType and AttentionType; presumably it is available at module
        # scope -- confirm, otherwise a string causality_axis raises NameError.
        if isinstance(self.causality_axis, str):
            self.causality_axis = CausalityAxis(self.causality_axis)
        if isinstance(self.norm_type, str):
            self.norm_type = NormType(self.norm_type)
        if isinstance(self.attn_type, str):
            self.attn_type = AttentionType(self.attn_type)
@dataclass
class VocoderModelConfig(BaseModelConfig):
resblock_kernel_sizes: Optional[List[int]] = None