add audio

This commit is contained in:
Prince Canuma
2026-01-16 01:15:22 +01:00
parent 81daf3f67d
commit a658911f98
19 changed files with 2335 additions and 54 deletions

View File

@@ -216,7 +216,8 @@ class ConnectorAttention(nn.Module):
self.to_q = nn.Linear(dim, inner_dim, bias=True)
self.to_k = nn.Linear(dim, inner_dim, bias=True)
self.to_v = nn.Linear(dim, inner_dim, bias=True)
self.to_out = [nn.Linear(inner_dim, dim, bias=True)]
# Direct attribute for MLX parameter tracking (not a list)
self.to_out = nn.Linear(inner_dim, dim, bias=True)
# Standard RMSNorm (not Gemma-style) on full inner_dim
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
@@ -239,21 +240,51 @@ class ConnectorAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
# Reshape to (B, H, T, D) for SPLIT RoPE
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
if pe is not None:
# pe: tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2)
# Apply SPLIT RoPE: rotates the first half of each head's dimensions against the second half
q = self._apply_split_rope(q, pe[0], pe[1])
k = self._apply_split_rope(k, pe[0], pe[1])
# No mask needed for connector - after register replacement, all positions are valid
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
return self.to_out[0](out)
return self.to_out(out)
def _apply_split_rope(
    self,
    x: mx.array,
    cos_freq: mx.array,
    sin_freq: mx.array,
) -> mx.array:
    """Rotate the two halves of ``x``'s last axis against each other (SPLIT RoPE).

    The last axis is cut at its midpoint. Element ``i`` of the lower half is
    paired with element ``i`` of the upper half, and each pair is rotated
    using the matching entries of ``cos_freq`` and ``sin_freq``.

    Args:
        x: Input tensor, expected as (B, H, T, D).
        cos_freq: Cosine table, expected as (1, H, T, D//2).
        sin_freq: Sine table, expected as (1, H, T, D//2).

    Returns:
        The rotated lower half followed by the rotated upper half,
        concatenated along the last axis.
    """
    midpoint = x.shape[-1] // 2
    lower, upper = x[..., :midpoint], x[..., midpoint:]

    # Plain 2-D rotation of every (lower[i], upper[i]) pair:
    #   lower' = lower * cos - upper * sin
    #   upper' = upper * cos + lower * sin
    rotated_lower = lower * cos_freq - upper * sin_freq
    rotated_upper = upper * cos_freq + lower * sin_freq
    return mx.concatenate([rotated_lower, rotated_upper], axis=-1)
class GEGLU(nn.Module):
@@ -272,15 +303,15 @@ class ConnectorFeedForward(nn.Module):
def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0):
super().__init__()
inner_dim = dim * mult
self.net = [
GEGLU(dim, inner_dim),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias=True),
]
# Use explicit named attributes to match weight key structure (proj_in, proj_out)
self.proj_in = nn.Linear(dim, inner_dim, bias=True)
self.dropout = nn.Dropout(dropout)
self.proj_out = nn.Linear(inner_dim, dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
for layer in self.net:
x = layer(x)
x = nn.gelu(self.proj_in(x))
x = self.dropout(x)
x = self.proj_out(x)
return x
@@ -326,6 +357,7 @@ class Embeddings1DConnector(nn.Module):
num_layers: int = 2,
num_learnable_registers: int = 128,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list = None,
):
super().__init__()
self.dim = dim
@@ -333,60 +365,69 @@ class Embeddings1DConnector(nn.Module):
self.head_dim = head_dim
self.num_learnable_registers = num_learnable_registers
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos or [4096]
self.transformer_1d_blocks = [
ConnectorTransformerBlock(dim, num_heads, head_dim)
for _ in range(num_layers)
]
# Use dict with int keys for MLX to track parameters (lists are not tracked)
self.transformer_1d_blocks = {
i: ConnectorTransformerBlock(dim, num_heads, head_dim)
for i in range(num_layers)
}
if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (INTERLEAVED type).
"""Compute RoPE frequencies for connector (SPLIT type matching PyTorch).
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2).
"""
import numpy as np
dim = self.num_heads * self.head_dim # inner_dim = 3840
theta = self.positional_embedding_theta
max_pos = [1] # Default for connector
max_pos = self.positional_embedding_max_pos # [4096] from PyTorch
n_elem = 2 * len(max_pos) # = 2
start = 1.0
end = theta
num_indices = dim // n_elem # 1920
# Use numpy float64 for precision
# Use numpy float64 for precision (double_precision_rope=True in PyTorch)
log_start = np.log(start) / np.log(theta) # = 0
log_end = np.log(end) / np.log(theta) # = 1
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float64)
# Generate positions and compute freqs (matches generate_freqs)
positions = np.arange(seq_len, dtype=np.float64)
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
# Scale positions by max_pos (PyTorch uses max_pos=[4096])
fractional_positions = positions / max_pos[0]
scaled_positions = fractional_positions * 2 - 1 # Shape: (seq_len,)
# freqs = indices * scaled_positions (outer product)
# Shape: (seq_len, num_indices)
freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
cos_freq = np.cos(freqs)
# Compute cos/sin
cos_freq = np.cos(freqs) # (seq_len, 1920)
sin_freq = np.sin(freqs)
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
cos_full = np.repeat(cos_freq, 2, axis=-1)
sin_full = np.repeat(sin_freq, 2, axis=-1)
# For SPLIT RoPE: pad to head_dim//2 = 64 per head, then reshape to (1, H, T, D//2)
# Current: (T, 1920) -> need (1, 30, T, 64)
# 30 heads * 64 = 1920, so no padding needed
# Add batch dimension and convert to MLX: (1, seq_len, dim)
cos_full = mx.array(cos_full[None, :, :].astype(np.float32))
sin_full = mx.array(sin_full[None, :, :].astype(np.float32))
# Reshape: (T, 1920) -> (T, 30, 64) -> (1, 30, T, 64)
cos_freq = cos_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
sin_freq = sin_freq.reshape(seq_len, self.num_heads, self.head_dim // 2)
# Transpose to (1, H, T, D//2)
cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...]
sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...]
# Convert to MLX
cos_full = mx.array(cos_freq.astype(np.float32))
sin_full = mx.array(sin_freq.astype(np.float32))
return cos_full.astype(dtype), sin_full.astype(dtype)
@@ -462,8 +503,8 @@ class Embeddings1DConnector(nn.Module):
freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype)
# Process through transformer blocks
for block in self.transformer_1d_blocks:
hidden_states = block(hidden_states, attention_mask, freqs_cis)
for i in range(len(self.transformer_1d_blocks)):
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
# Final RMS norm
hidden_states = rms_norm(hidden_states)
@@ -535,15 +576,28 @@ class GemmaFeaturesExtractor(nn.Module):
class AudioEmbeddingsConnector(nn.Module):
    """Single biased linear layer mapping ``input_dim`` features to ``output_dim``.

    Intended, per its original description, to project video embeddings to
    the audio cross-attention dimension.
    """

    def __init__(self, input_dim: int = 3840, output_dim: int = 2048):
        super().__init__()
        # The only learnable component: a dense projection with bias.
        self.linear = nn.Linear(input_dim, output_dim, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the linear projection to ``x`` and return the result."""
        projected = self.linear(x)
        return projected
class LTX2TextEncoder(nn.Module):
def __init__(
self,
hidden_dim: int = 3840,
audio_dim: int = 2048,
num_layers: int = 49, # 48 transformer layers + 1 embedding
):
super().__init__()
self.hidden_dim = hidden_dim
self.audio_dim = audio_dim
self.num_layers = num_layers
self.language_model = None
@@ -560,14 +614,26 @@ class LTX2TextEncoder(nn.Module):
head_dim=128,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[4096], # Match PyTorch
)
# Audio embeddings connector: separate 2-layer transformer (same architecture as video)
# Both connectors process the feature extractor output independently
self.audio_embeddings_connector = Embeddings1DConnector(
dim=hidden_dim,
num_heads=30,
head_dim=128,
num_layers=2,
num_learnable_registers=128,
positional_embedding_max_pos=[4096], # Match PyTorch
)
self.processor = None
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
if Path(text_encoder_path / "text_encoder").is_dir():
text_encoder_path = str(text_encoder_path / "text_encoder")
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
@@ -594,10 +660,14 @@ class LTX2TextEncoder(nn.Module):
# Map weight names to our structure
mapped_weights = {}
for key, value in connector_weights.items():
# transformer_1d_blocks.X.attn1.* -> transformer_1d_blocks.X.attn1.*
# transformer_1d_blocks.X.ff.net.0.proj.* -> transformer_1d_blocks.X.ff.net.0.proj.*
# transformer_1d_blocks.X.ff.net.2.* -> transformer_1d_blocks.X.ff.net.2.*
mapped_weights[key] = value
new_key = key
# Map ff.net.0.proj -> ff.proj_in (input Linear, followed by GELU)
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
# Map ff.net.2 -> ff.proj_out (output Linear)
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
# Map to_out.0 -> to_out (Sequential -> direct)
new_key = new_key.replace(".to_out.0.", ".to_out.")
mapped_weights[new_key] = value
self.video_embeddings_connector.load_weights(
list(mapped_weights.items()), strict=False
@@ -607,6 +677,34 @@ class LTX2TextEncoder(nn.Module):
if "learnable_registers" in connector_weights:
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
# Load audio_embeddings_connector weights (same structure as video connector)
audio_connector_weights = {}
for key, value in transformer_weights.items():
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
audio_connector_weights[new_key] = value
if audio_connector_weights:
# Map weight names to our structure (same as video connector)
mapped_audio_weights = {}
for key, value in audio_connector_weights.items():
new_key = key
# Map ff.net.0.proj -> ff.proj_in (input Linear, followed by GELU)
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
# Map ff.net.2 -> ff.proj_out (output Linear)
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
# Map to_out.0 -> to_out (Sequential -> direct)
new_key = new_key.replace(".to_out.0.", ".to_out.")
mapped_audio_weights[new_key] = value
self.audio_embeddings_connector.load_weights(
list(mapped_audio_weights.items()), strict=False
)
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in audio_connector_weights:
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
# Load tokenizer
from transformers import AutoTokenizer
tokenizer_path = model_path / "tokenizer"
@@ -623,8 +721,20 @@ class LTX2TextEncoder(nn.Module):
self,
prompt: str,
max_length: int = 1024,
return_audio_embeddings: bool = True,
) -> Tuple[mx.array, mx.array]:
"""Encode text prompt to video and audio embeddings.
Args:
prompt: Text prompt to encode
max_length: Maximum token length (default 1024 to match official PyTorch)
return_audio_embeddings: If True, returns (video_emb, audio_emb).
If False, returns (video_emb, attention_mask).
Returns:
Tuple of (video_embeddings, audio_embeddings) if return_audio_embeddings=True
Tuple of (video_embeddings, attention_mask) otherwise
"""
if self.processor is None:
raise RuntimeError("Model not loaded. Call load() first.")
@@ -649,16 +759,33 @@ class LTX2TextEncoder(nn.Module):
additive_mask = (attention_mask - 1).astype(features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
video_embeddings, _ = self.video_embeddings_connector(features, additive_mask)
return embeddings, attention_mask
if return_audio_embeddings:
# Process features through audio connector independently (same input as video)
audio_embeddings, _ = self.audio_embeddings_connector(features, additive_mask)
return video_embeddings, audio_embeddings
else:
return video_embeddings, attention_mask
def __call__(
self,
prompt: str,
max_length: int = 1024,
return_audio_embeddings: bool = True,
) -> Tuple[mx.array, mx.array]:
return self.encode(prompt, max_length)
"""Encode text prompt.
Args:
prompt: Text prompt to encode
max_length: Maximum token length (default 1024 to match official PyTorch)
return_audio_embeddings: If True, returns (video_emb, audio_emb).
If False, returns (video_emb, attention_mask).
Returns:
Tuple of embeddings based on return_audio_embeddings flag
"""
return self.encode(prompt, max_length, return_audio_embeddings)
@functools.cached_property
def default_t2v_system_prompt(self) -> str:
@@ -833,7 +960,7 @@ class LTX2TextEncoder(nn.Module):
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder(model_path=model_path)
encoder.load()
encoder = LTX2TextEncoder()
encoder.load(model_path=model_path)
return encoder