add audio
This commit is contained in:
@@ -216,7 +216,8 @@ class ConnectorAttention(nn.Module):
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=True)
|
||||
self.to_k = nn.Linear(dim, inner_dim, bias=True)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias=True)
|
||||
self.to_out = [nn.Linear(inner_dim, dim, bias=True)]
|
||||
# Direct attribute for MLX parameter tracking (not a list)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=True)
|
||||
|
||||
# Standard RMSNorm (not Gemma-style) on full inner_dim
|
||||
self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6)
|
||||
@@ -239,21 +240,51 @@ class ConnectorAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
|
||||
if pe is not None:
|
||||
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
|
||||
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
|
||||
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
|
||||
|
||||
# Reshape to (B, H, T, D) for SPLIT RoPE
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
|
||||
if pe is not None:
|
||||
# pe: tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2)
|
||||
# Apply SPLIT RoPE: operates on first half of head dimensions
|
||||
q = self._apply_split_rope(q, pe[0], pe[1])
|
||||
k = self._apply_split_rope(k, pe[0], pe[1])
|
||||
|
||||
# No mask needed for connector - after register replacement, all positions are valid
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||
|
||||
return self.to_out[0](out)
|
||||
return self.to_out(out)
|
||||
|
||||
def _apply_split_rope(
    self,
    x: mx.array,
    cos_freq: mx.array,
    sin_freq: mx.array,
) -> mx.array:
    """Rotate ``x`` with SPLIT-layout rotary position embeddings.

    The last axis of ``x`` is cut into a lower and an upper half; element
    ``i`` of the lower half is rotated together with element ``i`` of the
    upper half, using the angle encoded by ``cos_freq`` / ``sin_freq``.

    Args:
        x: Input tensor, expected as (B, H, T, D).
        cos_freq: Cosine table, expected as (1, H, T, D//2).
        sin_freq: Sine table, expected as (1, H, T, D//2).

    Returns:
        Tensor shaped like ``x`` with both halves rotated and re-joined
        along the last axis.
    """
    split_at = x.shape[-1] // 2
    lower, upper = x[..., :split_at], x[..., split_at:]

    # Standard 2-D rotation applied to each (lower[i], upper[i]) pair.
    rotated_lower = lower * cos_freq - upper * sin_freq
    rotated_upper = upper * cos_freq + lower * sin_freq

    return mx.concatenate([rotated_lower, rotated_upper], axis=-1)
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
@@ -272,15 +303,15 @@ class ConnectorFeedForward(nn.Module):
|
||||
def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0):
    """Build the connector feed-forward: Linear -> GELU -> Dropout -> Linear.

    Args:
        dim: Input and output feature width.
        mult: Expansion factor; the hidden width is ``dim * mult``.
        dropout: Probability handed to ``nn.Dropout`` after the activation.
    """
    super().__init__()
    inner_dim = dim * mult
    # Named attributes (instead of a ``net`` list) so parameter paths are
    # ``ff.proj_in.*`` / ``ff.proj_out.*`` -- the names the checkpoint loader
    # produces when it renames ``ff.net.0.proj.*`` and ``ff.net.2.*``.
    # The old ``net`` list is intentionally gone: keeping it would run a
    # second stack whose weights the loader never fills.
    self.proj_in = nn.Linear(dim, inner_dim, bias=True)
    self.dropout = nn.Dropout(dropout)
    self.proj_out = nn.Linear(inner_dim, dim, bias=True)

def __call__(self, x: mx.array) -> mx.array:
    """Apply the feed-forward to ``x``.

    Args:
        x: Input whose last axis has width ``dim``.

    Returns:
        Tensor with the same last-axis width as the input.
    """
    # NOTE(review): the loader comment labels ``ff.net.0.proj`` a GEGLU
    # projection, while this applies a plain GELU to an ``inner_dim``-wide
    # projection -- confirm the checkpoint weight is (inner_dim, dim).
    x = nn.gelu(self.proj_in(x))
    x = self.dropout(x)
    return self.proj_out(x)
|
||||
|
||||
|
||||
@@ -326,6 +357,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
num_layers: int = 2,
|
||||
num_learnable_registers: int = 128,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
@@ -333,60 +365,69 @@ class Embeddings1DConnector(nn.Module):
|
||||
self.head_dim = head_dim
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = positional_embedding_max_pos or [4096]
|
||||
|
||||
self.transformer_1d_blocks = [
|
||||
ConnectorTransformerBlock(dim, num_heads, head_dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
# Use dict with int keys for MLX to track parameters (lists are not tracked)
|
||||
self.transformer_1d_blocks = {
|
||||
i: ConnectorTransformerBlock(dim, num_heads, head_dim)
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
if num_learnable_registers > 0:
|
||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
    """Compute SPLIT-layout RoPE cos/sin tables for the connector.

    All frequency math is done in NumPy float64; the tables are converted
    to float32 MLX arrays and then cast to ``dtype`` only at the end.
    The original comments state this mirrors the PyTorch reference
    (``max_pos=[4096]``, double-precision RoPE); that cannot be verified
    from this file.

    Args:
        seq_len: Number of positions to generate tables for.
        dtype: MLX dtype of the returned arrays.

    Returns:
        Tuple ``(cos, sin)``, each with shape
        ``(1, num_heads, seq_len, head_dim // 2)``.

    Raises:
        ValueError: If the number of frequencies per position cannot be
            laid out as ``num_heads`` groups of ``head_dim // 2``.
    """
    import numpy as np

    dim = self.num_heads * self.head_dim  # inner_dim (3840 with the defaults)
    theta = self.positional_embedding_theta
    max_pos = self.positional_embedding_max_pos
    n_elem = 2 * len(max_pos)  # 2 when max_pos has a single entry
    num_indices = dim // n_elem
    half_head = self.head_dim // 2

    # The reshape below needs exactly num_heads * (head_dim // 2) frequencies;
    # fail with a readable message instead of NumPy's generic reshape error.
    if num_indices != self.num_heads * half_head:
        raise ValueError(
            f"Cannot lay out {num_indices} RoPE frequencies as "
            f"{self.num_heads} heads x {half_head} (head_dim // 2); "
            f"check positional_embedding_max_pos={max_pos!r} and head_dim={self.head_dim}."
        )

    start = 1.0
    end = theta

    # Exponents run linearly from log_theta(start) to log_theta(end), so
    # theta ** exponent is log-spaced between start and end.
    log_start = np.log(start) / np.log(theta)
    log_end = np.log(end) / np.log(theta)
    lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
    indices = np.power(theta, lin_space) * (np.pi / 2)

    # Positions are normalised by max_pos[0] and then mapped with
    # ``p * 2 - 1`` before being multiplied by the frequencies.
    positions = np.arange(seq_len, dtype=np.float64)
    fractional_positions = positions / max_pos[0]
    scaled_positions = fractional_positions * 2 - 1  # (seq_len,)

    # Outer product -> (seq_len, num_indices)
    freqs = scaled_positions[:, None] * indices[None, :]

    cos_freq = np.cos(freqs)
    sin_freq = np.sin(freqs)

    # (T, H * D//2) -> (T, H, D//2) -> (H, T, D//2) -> (1, H, T, D//2)
    cos_freq = cos_freq.reshape(seq_len, self.num_heads, half_head)
    sin_freq = sin_freq.reshape(seq_len, self.num_heads, half_head)
    cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...]
    sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...]

    # Hand over to MLX as float32, then cast to the requested dtype.
    cos_full = mx.array(cos_freq.astype(np.float32))
    sin_full = mx.array(sin_freq.astype(np.float32))

    return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||
|
||||
@@ -462,8 +503,8 @@ class Embeddings1DConnector(nn.Module):
|
||||
freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype)
|
||||
|
||||
# Process through transformer blocks
|
||||
for block in self.transformer_1d_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask, freqs_cis)
|
||||
for i in range(len(self.transformer_1d_blocks)):
|
||||
hidden_states = self.transformer_1d_blocks[i](hidden_states, attention_mask, freqs_cis)
|
||||
|
||||
# Final RMS norm
|
||||
hidden_states = rms_norm(hidden_states)
|
||||
@@ -535,15 +576,28 @@ class GemmaFeaturesExtractor(nn.Module):
|
||||
|
||||
|
||||
|
||||
class AudioEmbeddingsConnector(nn.Module):
    """Projects video embeddings to audio cross-attention dimension.

    Wraps a single biased ``nn.Linear`` mapping ``input_dim`` features to
    ``output_dim`` features (3840 -> 2048 with the defaults).
    """

    def __init__(self, input_dim: int = 3840, output_dim: int = 2048):
        """Create the projection layer.

        Args:
            input_dim: Width of the incoming feature axis.
            output_dim: Width of the projected feature axis.
        """
        super().__init__()
        # The attribute name ``linear`` is part of the parameter path.
        self.linear = nn.Linear(input_dim, output_dim, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``x`` passed through the linear projection."""
        projected = self.linear(x)
        return projected
|
||||
|
||||
|
||||
class LTX2TextEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int = 3840,
|
||||
audio_dim: int = 2048,
|
||||
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.audio_dim = audio_dim
|
||||
self.num_layers = num_layers
|
||||
self.language_model = None
|
||||
|
||||
@@ -560,14 +614,26 @@ class LTX2TextEncoder(nn.Module):
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], # Match PyTorch
|
||||
)
|
||||
|
||||
# Audio embeddings connector: separate 2-layer transformer (same architecture as video)
|
||||
# Both connectors process the feature extractor output independently
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
dim=hidden_dim,
|
||||
num_heads=30,
|
||||
head_dim=128,
|
||||
num_layers=2,
|
||||
num_learnable_registers=128,
|
||||
positional_embedding_max_pos=[4096], # Match PyTorch
|
||||
)
|
||||
|
||||
self.processor = None
|
||||
|
||||
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
||||
|
||||
if Path(text_encoder_path / "text_encoder").is_dir():
|
||||
text_encoder_path = str(text_encoder_path / "text_encoder")
|
||||
if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir():
|
||||
text_encoder_path = str(Path(text_encoder_path) / "text_encoder")
|
||||
|
||||
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||
|
||||
@@ -594,10 +660,14 @@ class LTX2TextEncoder(nn.Module):
|
||||
# Map weight names to our structure
|
||||
mapped_weights = {}
|
||||
for key, value in connector_weights.items():
|
||||
# transformer_1d_blocks.X.attn1.* -> transformer_1d_blocks.X.attn1.*
|
||||
# transformer_1d_blocks.X.ff.net.0.proj.* -> transformer_1d_blocks.X.ff.net.0.proj.*
|
||||
# transformer_1d_blocks.X.ff.net.2.* -> transformer_1d_blocks.X.ff.net.2.*
|
||||
mapped_weights[key] = value
|
||||
new_key = key
|
||||
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
mapped_weights[new_key] = value
|
||||
|
||||
self.video_embeddings_connector.load_weights(
|
||||
list(mapped_weights.items()), strict=False
|
||||
@@ -607,6 +677,34 @@ class LTX2TextEncoder(nn.Module):
|
||||
if "learnable_registers" in connector_weights:
|
||||
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
||||
|
||||
# Load audio_embeddings_connector weights (same structure as video connector)
|
||||
audio_connector_weights = {}
|
||||
for key, value in transformer_weights.items():
|
||||
if key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "")
|
||||
audio_connector_weights[new_key] = value
|
||||
|
||||
if audio_connector_weights:
|
||||
# Map weight names to our structure (same as video connector)
|
||||
mapped_audio_weights = {}
|
||||
for key, value in audio_connector_weights.items():
|
||||
new_key = key
|
||||
# Map ff.net.0.proj -> ff.proj_in (GEGLU projection)
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.")
|
||||
# Map ff.net.2 -> ff.proj_out (output Linear)
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.")
|
||||
# Map to_out.0 -> to_out (Sequential -> direct)
|
||||
new_key = new_key.replace(".to_out.0.", ".to_out.")
|
||||
mapped_audio_weights[new_key] = value
|
||||
|
||||
self.audio_embeddings_connector.load_weights(
|
||||
list(mapped_audio_weights.items()), strict=False
|
||||
)
|
||||
|
||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||
if "learnable_registers" in audio_connector_weights:
|
||||
self.audio_embeddings_connector.learnable_registers = audio_connector_weights["learnable_registers"]
|
||||
|
||||
# Load tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer_path = model_path / "tokenizer"
|
||||
@@ -623,8 +721,20 @@ class LTX2TextEncoder(nn.Module):
|
||||
self,
|
||||
prompt: str,
|
||||
max_length: int = 1024,
|
||||
return_audio_embeddings: bool = True,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Encode text prompt to video and audio embeddings.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt to encode
|
||||
max_length: Maximum token length (default 1024 to match official PyTorch)
|
||||
return_audio_embeddings: If True, returns (video_emb, audio_emb).
|
||||
If False, returns (video_emb, attention_mask).
|
||||
|
||||
Returns:
|
||||
Tuple of (video_embeddings, audio_embeddings) if return_audio_embeddings=True
|
||||
Tuple of (video_embeddings, attention_mask) otherwise
|
||||
"""
|
||||
if self.processor is None:
|
||||
raise RuntimeError("Model not loaded. Call load() first.")
|
||||
|
||||
@@ -649,16 +759,33 @@ class LTX2TextEncoder(nn.Module):
|
||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
|
||||
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||
video_embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||
|
||||
return embeddings, attention_mask
|
||||
if return_audio_embeddings:
|
||||
# Process features through audio connector independently (same input as video)
|
||||
audio_embeddings, _ = self.audio_embeddings_connector(features, additive_mask)
|
||||
return video_embeddings, audio_embeddings
|
||||
else:
|
||||
return video_embeddings, attention_mask
|
||||
|
||||
def __call__(
    self,
    prompt: str,
    max_length: int = 1024,
    return_audio_embeddings: bool = True,
) -> Tuple[mx.array, mx.array]:
    """Encode a text prompt; thin wrapper around :meth:`encode`.

    Args:
        prompt: Text prompt to encode.
        max_length: Maximum token length, forwarded to ``encode``.
        return_audio_embeddings: If True, the result is
            ``(video_embeddings, audio_embeddings)``; if False, it is
            ``(video_embeddings, attention_mask)``.

    Returns:
        Whatever ``encode`` returns for the given flag.
    """
    # Forward the flag explicitly: an earlier two-argument call dropped it,
    # so callers could never opt out of the audio branch through __call__.
    return self.encode(
        prompt,
        max_length=max_length,
        return_audio_embeddings=return_audio_embeddings,
    )
|
||||
|
||||
@functools.cached_property
|
||||
def default_t2v_system_prompt(self) -> str:
|
||||
@@ -833,7 +960,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
|
||||
|
||||
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
    """Construct an ``LTX2TextEncoder`` with default sizes and load it.

    Args:
        model_path: Location handed to ``LTX2TextEncoder.load`` as
            ``model_path``.

    Returns:
        The encoder after ``load`` has been called on it.
    """
    # The constructor takes only dimension settings (hidden_dim, audio_dim,
    # num_layers); the path belongs to load(), not to __init__.
    encoder = LTX2TextEncoder()
    # NOTE(review): load() later evaluates ``model_path / "tokenizer"``;
    # confirm it converts a plain str to a Path before that point.
    encoder.load(model_path=model_path)
    return encoder
|
||||
|
||||
|
||||
Reference in New Issue
Block a user