import functools import logging import math import re from pathlib import Path from typing import Dict, List, Optional, Tuple import mlx.core as mx import mlx.nn as nn from mlx_vlm.models.gemma3.config import TextConfig from mlx_vlm.models.gemma3.language import Gemma3Model from rich.console import Console from rich.progress import ( BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn, TimeRemainingColumn, ) from mlx_video.utils import apply_quantization, rms_norm # Path to system prompts PROMPTS_DIR = Path(__file__).parent / "prompts" def _load_system_prompt(prompt_name: str) -> str: """Load a system prompt from the prompts directory.""" prompt_path = PROMPTS_DIR / prompt_name if prompt_path.exists(): with open(prompt_path, "r") as f: return f.read() raise FileNotFoundError(f"System prompt not found: {prompt_path}") class LanguageModel(nn.Module): def __init__(self, config: TextConfig): super().__init__() # Create config matching LTX-2 text encoder requirements self.config = config # Create the Gemma3Model from mlx-vlm self.model = Gemma3Model(self.config) def _create_causal_mask_with_padding( self, seq_len: int, attention_mask: Optional[mx.array], dtype: mx.Dtype, ) -> mx.array: causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_)) if attention_mask is not None: batch_size = attention_mask.shape[0] padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len) combined = causal_mask[None, :, :] & padding_mask[:, None, :] min_val = ( mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9 ) mask = mx.where( combined, mx.zeros(combined.shape, dtype=dtype), mx.full(combined.shape, min_val, dtype=dtype), ) return mask[:, None, :, :] else: # No padding mask, just causal min_val = ( mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9 ) mask = mx.where( causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype), mx.full((seq_len, seq_len), min_val, dtype=dtype), ) return mask[None, None, :, :] # (1, 1, seq, seq) def __call__( self, inputs: mx.array, input_embeddings: Optional[mx.array] = None, attention_mask: Optional[mx.array] = None, output_hidden_states: bool = False, cache: Optional[List[mx.array]] = None, ) -> Tuple[mx.array, List[mx.array]]: """Forward pass returning hidden states. Args: input_ids: Input token IDs of shape (batch, seq_len) attention_mask: Optional attention mask (1 for valid, 0 for padding) output_hidden_states: Whether to return all hidden states Returns: Tuple of (final_hidden_states, list_of_all_hidden_states) """ batch_size, seq_len = inputs.shape # Get embeddings h = ( input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs) ) # Apply Gemma scaling h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype) mx.eval(h) all_hidden_states = [h] if output_hidden_states else [] # Set up cache (all None for non-cached inference) if cache is None: cache = [None] * len(self.model.layers) full_causal_mask = self._create_causal_mask_with_padding( seq_len, attention_mask, h.dtype ) sliding_mask = full_causal_mask num_layers = len(self.model.layers) for i, layer in enumerate(self.model.layers): is_global = ( i % self.config.sliding_window_pattern == self.config.sliding_window_pattern - 1 ) # Select appropriate mask for this layer if is_global: local_mask = full_causal_mask else: local_mask = sliding_mask h = layer(h, local_mask, cache[i]) mx.eval(h) if output_hidden_states and i < num_layers - 1: all_hidden_states.append(h) # Apply final norm hidden_states = self.model.norm(h) mx.eval(hidden_states) # Append the final normalized output as the last hidden state if output_hidden_states: all_hidden_states.append(hidden_states) return hidden_states, all_hidden_states else: # Return logits return self.model.embed_tokens.as_linear(hidden_states) def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: prefix = "language_model." sanitized = {} for key, value in weights.items(): if key.startswith(prefix): if hasattr(value, "dtype") and value.dtype == mx.float32: sanitized[key[len(prefix) :]] = value.astype(mx.bfloat16) else: sanitized[key[len(prefix) :]] = value return sanitized @property def layers(self) -> List[nn.Module]: return self.model.layers def make_cache(self): from mlx_vlm.models.cache import KVCache, RotatingKVCache caches = [] for i in range(len(self.layers)): if ( i % self.config.sliding_window_pattern == self.config.sliding_window_pattern - 1 ): caches.append(KVCache()) else: caches.append(RotatingKVCache(max_size=self.config.sliding_window)) return caches @classmethod def from_pretrained(cls, model_path: str): import json weight_files = sorted(Path(model_path).glob("*.safetensors")) config_file = Path(model_path) / "config.json" config_dict = {} if config_file.exists(): with open(config_file, "r") as f: config_dict = json.load(f) language_model = cls( config=TextConfig.from_dict(config_dict["text_config"]) ) else: raise ValueError(f"Config file not found at {model_path}") quantization = config_dict.get("quantization", None) weights = {} for i, wf in enumerate(weight_files): weights.update(mx.load(str(wf))) if hasattr(language_model, "sanitize"): weights = language_model.sanitize(weights=weights) apply_quantization( model=language_model, weights=weights, quantization=quantization ) language_model.load_weights(list(weights.items()), strict=False) return language_model class ConnectorAttention(nn.Module): def __init__( self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False, ): super().__init__() self.num_heads = num_heads self.head_dim = head_dim inner_dim = num_heads * head_dim self.scale = 1.0 / math.sqrt(head_dim) self.to_q = nn.Linear(dim, inner_dim, bias=True) self.to_k = nn.Linear(dim, inner_dim, bias=True) self.to_v = nn.Linear(dim, inner_dim, bias=True) self.to_out = nn.Linear(inner_dim, dim, bias=True) self.q_norm = nn.RMSNorm(inner_dim, eps=1e-6) self.k_norm = nn.RMSNorm(inner_dim, eps=1e-6) if has_gate_logits: self.to_gate_logits = nn.Linear(dim, num_heads, bias=True) def __call__( self, x: mx.array, attention_mask: Optional[mx.array] = None, pe: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: batch_size, seq_len, _ = x.shape # Compute per-head gate early (from original input) gate = None if hasattr(self, "to_gate_logits"): gate = 2.0 * mx.sigmoid(self.to_gate_logits(x)) # (B, seq, heads) # Project to Q, K, V q = self.to_q(x) k = self.to_k(x) v = self.to_v(x) # QK normalization on full inner_dim BEFORE reshape q = self.q_norm(q) k = self.k_norm(k) # Reshape to (B, H, T, D) for SPLIT RoPE q = mx.reshape( q, (batch_size, seq_len, self.num_heads, self.head_dim) ).transpose(0, 2, 1, 3) k = mx.reshape( k, (batch_size, seq_len, self.num_heads, self.head_dim) ).transpose(0, 2, 1, 3) v = mx.reshape( v, (batch_size, seq_len, self.num_heads, self.head_dim) ).transpose(0, 2, 1, 3) if pe is not None: q = self._apply_split_rope(q, pe[0], pe[1]) k = self._apply_split_rope(k, pe[0], pe[1]) out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None) out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) # Apply per-head gating if gate is not None: out = mx.reshape(out, (batch_size, seq_len, self.num_heads, self.head_dim)) out = out * gate[..., None] out = mx.reshape(out, (batch_size, seq_len, -1)) return self.to_out(out) def _apply_split_rope( self, x: mx.array, cos_freq: mx.array, sin_freq: mx.array, ) -> mx.array: """Apply SPLIT RoPE to input tensor. Args: x: Input tensor of shape (B, H, T, D) cos_freq: Cosine frequencies of shape (1, H, T, D//2) sin_freq: Sine frequencies of shape (1, H, T, D//2) Returns: Tensor with SPLIT rotary embeddings applied """ input_dtype = x.dtype # Cast to float32 for precision, then cast back x = x.astype(mx.float32) cos_freq = cos_freq.astype(mx.float32) sin_freq = sin_freq.astype(mx.float32) # Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2) half_dim = x.shape[-1] // 2 x1 = x[..., :half_dim] x2 = x[..., half_dim:] # Apply rotation: SPLIT pattern # out1 = x1 * cos - x2 * sin # out2 = x2 * cos + x1 * sin out1 = x1 * cos_freq - x2 * sin_freq out2 = x2 * cos_freq + x1 * sin_freq return mx.concatenate([out1, out2], axis=-1).astype(input_dtype) class GEGLU(nn.Module): """GELU-gated linear unit.""" def __init__(self, in_dim: int, out_dim: int): super().__init__() self.proj = nn.Linear(in_dim, out_dim, bias=True) def __call__(self, x: mx.array) -> mx.array: return nn.gelu(self.proj(x)) class ConnectorFeedForward(nn.Module): def __init__(self, dim: int = 3840, mult: int = 4, dropout: float = 0.0): super().__init__() inner_dim = dim * mult # Use explicit named attributes to match weight key structure (proj_in, proj_out) self.proj_in = nn.Linear(dim, inner_dim, bias=True) self.dropout = nn.Dropout(dropout) self.proj_out = nn.Linear(inner_dim, dim, bias=True) def __call__(self, x: mx.array) -> mx.array: x = nn.gelu_approx(self.proj_in(x)) x = self.dropout(x) x = self.proj_out(x) return x class ConnectorTransformerBlock(nn.Module): def __init__( self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, has_gate_logits: bool = False, ): super().__init__() self.attn1 = ConnectorAttention( dim, num_heads, head_dim, has_gate_logits=has_gate_logits ) self.ff = ConnectorFeedForward(dim) def __call__( self, x: mx.array, attention_mask: Optional[mx.array] = None, pe: Optional[mx.array] = None, ) -> mx.array: # Pre-norm + attention + residual norm_x = rms_norm(x) if norm_x.ndim == 4: norm_x = mx.squeeze(norm_x, axis=1) attn_out = self.attn1(norm_x, attention_mask, pe) x = x + attn_out if x.ndim == 4: x = mx.squeeze(x, axis=1) # Pre-norm + FFN + residual norm_x = rms_norm(x) ff_out = self.ff(norm_x) x = x + ff_out if x.ndim == 4: x = mx.squeeze(x, axis=1) return x class Embeddings1DConnector(nn.Module): def __init__( self, dim: int = 3840, num_heads: int = 30, head_dim: int = 128, num_layers: int = 2, num_learnable_registers: int = 128, positional_embedding_theta: float = 10000.0, positional_embedding_max_pos: list = None, has_gate_logits: bool = False, ): super().__init__() self.dim = dim self.num_heads = num_heads self.head_dim = head_dim self.num_learnable_registers = num_learnable_registers self.positional_embedding_theta = positional_embedding_theta self.positional_embedding_max_pos = positional_embedding_max_pos or [1] self.transformer_1d_blocks = { i: ConnectorTransformerBlock( dim, num_heads, head_dim, has_gate_logits=has_gate_logits ) for i in range(num_layers) } if num_learnable_registers > 0: self.learnable_registers = mx.zeros((num_learnable_registers, dim)) def _precompute_freqs_cis( self, seq_len: int, dtype: mx.Dtype ) -> Tuple[mx.array, mx.array]: """Compute RoPE frequencies for connector (SPLIT type matching PyTorch). Returns tuple of (cos, sin) each with shape (1, num_heads, seq_len, head_dim//2). """ import numpy as np dim = self.num_heads * self.head_dim # inner_dim theta = self.positional_embedding_theta max_pos = self.positional_embedding_max_pos # [1] = PyTorch default n_elem = 2 * len(max_pos) # = 2 start = 1.0 end = theta num_indices = dim // n_elem # generate_freq_grid_np: compute indices in float64 then cast to float32 # (matches PyTorch: double_precision_rope generates in np.float64, # but returns torch.float32) log_start = np.log(start) / np.log(theta) # = 0 log_end = np.log(end) / np.log(theta) # = 1 lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64) indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32) # generate_freqs: positions and freqs in float32 (matching PyTorch) positions = np.arange(seq_len, dtype=np.float32) fractional_positions = positions / max_pos[0] scaled_positions = fractional_positions * 2 - 1 # Shape: (seq_len,) # freqs = scaled_positions * indices (outer product) in float32 freqs = scaled_positions[:, None] * indices[None, :] # split_freqs_cis: cos/sin in float32 (matching PyTorch) expected_freqs = dim // 2 pad_size = expected_freqs - freqs.shape[-1] cos_freq = np.cos(freqs) # (seq_len, num_indices) sin_freq = np.sin(freqs) if pad_size > 0: cos_padding = np.ones((seq_len, pad_size), dtype=np.float32) sin_padding = np.zeros((seq_len, pad_size), dtype=np.float32) cos_freq = np.concatenate([cos_padding, cos_freq], axis=-1) sin_freq = np.concatenate([sin_padding, sin_freq], axis=-1) # Reshape: (T, dim//2) -> (T, H, D//2) -> (1, H, T, D//2) cos_freq = cos_freq.reshape(seq_len, self.num_heads, self.head_dim // 2) sin_freq = sin_freq.reshape(seq_len, self.num_heads, self.head_dim // 2) cos_freq = np.transpose(cos_freq, (1, 0, 2))[np.newaxis, ...] sin_freq = np.transpose(sin_freq, (1, 0, 2))[np.newaxis, ...] # Convert to MLX cos_full = mx.array(cos_freq) sin_full = mx.array(sin_freq) return cos_full.astype(dtype), sin_full.astype(dtype) def _replace_padded_with_registers( self, hidden_states: mx.array, attention_mask: mx.array, ) -> Tuple[mx.array, mx.array]: batch_size, seq_len, dim = hidden_states.shape dtype = hidden_states.dtype # Binary mask: 1 for valid tokens, 0 for padded # attention_mask is additive: 0 for valid, large negative for padded mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype( mx.int32 ) # (batch, seq) # Tile registers to match sequence length, cast to hidden_states dtype num_tiles = seq_len // self.num_learnable_registers registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype( dtype ) # (seq_len, dim) # Process each batch item (PyTorch uses advanced indexing) result_list = [] for b in range(batch_size): mask_b = mask_binary[b] # (seq,) hs_b = hidden_states[b] # (seq, dim) # Count valid tokens num_valid = int(mx.sum(mask_b)) # Extract valid tokens (where mask is 1) # Since we have left-padded input, valid tokens are at the end valid_tokens = hs_b[seq_len - num_valid :] # (num_valid, dim) # Pad with zeros on the right to get back to seq_len pad_length = seq_len - num_valid if pad_length > 0: padding = mx.zeros((pad_length, dim), dtype=dtype) adjusted = mx.concatenate( [valid_tokens, padding], axis=0 ) # (seq_len, dim) else: adjusted = valid_tokens # Create flipped mask: 1s at front (where valid tokens now are), 0s at back flipped_mask = mx.concatenate( [ mx.ones((num_valid,), dtype=mx.int32), mx.zeros((pad_length,), dtype=mx.int32), ], axis=0, ) # (seq,) # Combine: valid tokens at front, registers at back flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1) combined = ( flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers ) result_list.append(combined) hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim) # Reset attention mask to all zeros (no masking after register replacement) attention_mask = mx.zeros_like(attention_mask) return hidden_states, attention_mask def __call__( self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, ) -> Tuple[mx.array, mx.array]: # Replace padded tokens with learnable registers if self.num_learnable_registers > 0 and attention_mask is not None: hidden_states, attention_mask = self._replace_padded_with_registers( hidden_states, attention_mask ) # Compute RoPE frequencies seq_len = hidden_states.shape[1] freqs_cis = self._precompute_freqs_cis(seq_len, hidden_states.dtype) # Process through transformer blocks for i in range(len(self.transformer_1d_blocks)): hidden_states = self.transformer_1d_blocks[i]( hidden_states, attention_mask, freqs_cis ) # Final RMS norm hidden_states = rms_norm(hidden_states) return hidden_states, attention_mask def norm_and_concat_hidden_states( hidden_states: List[mx.array], attention_mask: mx.array, padding_side: str = "left", ) -> mx.array: # Stack hidden states: (batch, seq, dim, num_layers) stacked = mx.stack(hidden_states, axis=-1) dtype = stacked.dtype b, t, d, num_layers = stacked.shape # Compute sequence lengths from attention mask sequence_lengths = mx.sum(attention_mask, axis=-1) # (batch,) # Build mask based on padding side token_indices = mx.arange(t)[None, :] # (1, T) if padding_side == "right": mask = token_indices < sequence_lengths[:, None] # (B, T) else: # left padding start_indices = t - sequence_lengths[:, None] # (B, 1) mask = token_indices >= start_indices # (B, T) mask = mask[:, :, None, None] # (B, T, 1, 1) eps = mx.array(1e-6, dtype=dtype) # Compute masked mean per layer - ensure dtype consistency masked = mx.where(mask, stacked, mx.zeros_like(stacked)) denom = (sequence_lengths * d).reshape(b, 1, 1, 1).astype(dtype) mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps) # Compute masked min/max per layer x_for_min = mx.where( mask, stacked, mx.full(stacked.shape, float("inf"), dtype=dtype) ) x_for_max = mx.where( mask, stacked, mx.full(stacked.shape, float("-inf"), dtype=dtype) ) x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True) x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True) range_val = x_max - x_min # Normalize: 8 * (x - mean) / range normed = 8 * (stacked - mean) / (range_val + eps) # Flatten layers into feature dimension: (B, T, D*L) normed = mx.reshape(normed, (b, t, -1)) # Zero out padded positions mask_flat = mx.broadcast_to(mask[:, :, :, 0], (b, t, d * num_layers)) normed = mx.where(mask_flat, normed, mx.zeros_like(normed)) return normed def norm_and_concat_per_token_rms( encoded_text: mx.array, attention_mask: mx.array, ) -> mx.array: """Per-token RMSNorm normalization for V2 feature extraction (LTX-2.3). Args: encoded_text: (B, T, D, L) stacked hidden states attention_mask: (B, T) binary mask (1=valid, 0=padding) Returns: (B, T, D*L) normalized tensor with padding zeroed out. """ b, t, d, num_layers = encoded_text.shape dtype = encoded_text.dtype # Per-token RMSNorm across hidden dimension: variance = mean(x^2) over dim D variance = mx.mean( encoded_text.astype(mx.float32) ** 2, axis=2, keepdims=True ) # (B, T, 1, L) normed = encoded_text.astype(mx.float32) * mx.rsqrt(variance + 1e-6) normed = normed.astype(dtype) # Flatten layers: (B, T, D*L) normed = mx.reshape(normed, (b, t, d * num_layers)) # Zero out padded positions mask_3d = attention_mask[:, :, None].astype(mx.bool_) # (B, T, 1) normed = mx.where(mask_3d, normed, mx.zeros_like(normed)) return normed def _rescale_norm(x: mx.array, target_dim: int, source_dim: int) -> mx.array: """Rescale normalization: x * sqrt(target_dim / source_dim).""" return x * math.sqrt(target_dim / source_dim) class GemmaFeaturesExtractor(nn.Module): """V1 feature extractor (LTX-2): 8 * (x - mean) / range normalization.""" def __init__( self, input_dim: int = 188160, output_dim: int = 3840, bias: bool = False ): super().__init__() self.aggregate_embed = nn.Linear(input_dim, output_dim, bias=bias) def __call__(self, x: mx.array) -> mx.array: return self.aggregate_embed(x) class GemmaFeaturesExtractorV2(nn.Module): """V2 feature extractor (LTX-2.3): per-token RMSNorm + rescale normalization.""" def __init__( self, flat_dim: int, embedding_dim: int, video_output_dim: int, audio_output_dim: int, bias: bool = True, ): super().__init__() self.embedding_dim = embedding_dim # Gemma hidden_dim (3840), used for rescale self.video_aggregate_embed = nn.Linear(flat_dim, video_output_dim, bias=bias) self.audio_aggregate_embed = nn.Linear(flat_dim, audio_output_dim, bias=bias) def __call__( self, hidden_states: List[mx.array], attention_mask: mx.array, mode: str = "video", ) -> mx.array: """Extract features with per-token RMSNorm + rescale. Args: hidden_states: List of hidden states from all Gemma layers attention_mask: Binary attention mask (B, T) mode: "video" or "audio" to select which aggregate embed to use Returns: Projected features """ # Stack hidden states: (B, T, D, L) encoded = mx.stack(hidden_states, axis=-1) # Per-token RMSNorm + flatten normed = norm_and_concat_per_token_rms(encoded, attention_mask) normed = normed.astype(encoded.dtype) if mode == "video": target_dim = self.video_aggregate_embed.weight.shape[0] return self.video_aggregate_embed( _rescale_norm(normed, target_dim, self.embedding_dim) ) else: target_dim = self.audio_aggregate_embed.weight.shape[0] return self.audio_aggregate_embed( _rescale_norm(normed, target_dim, self.embedding_dim) ) class AudioEmbeddingsConnector(nn.Module): """Projects video embeddings to audio cross-attention dimension.""" def __init__(self, input_dim: int = 3840, output_dim: int = 2048): super().__init__() self.linear = nn.Linear(input_dim, output_dim, bias=True) def __call__(self, x: mx.array) -> mx.array: return self.linear(x) class LTX2TextEncoder(nn.Module): def __init__( self, hidden_dim: int = 3840, audio_dim: int = 2048, num_layers: int = 49, # 48 transformer layers + 1 embedding has_prompt_adaln: bool = False, ): super().__init__() self.hidden_dim = hidden_dim self.audio_dim = audio_dim self.num_layers = num_layers self.has_prompt_adaln = has_prompt_adaln self.language_model = None feature_input_dim = hidden_dim * num_layers if has_prompt_adaln: # LTX-2.3: V2 feature extractor with per-token RMSNorm + rescale video_output_dim = 4096 audio_output_dim = 2048 self.feature_extractor_v2 = GemmaFeaturesExtractorV2( flat_dim=feature_input_dim, # 3840 * 49 = 188160 (concatenated) embedding_dim=hidden_dim, # 3840 (Gemma hidden_dim, for rescale) video_output_dim=video_output_dim, audio_output_dim=audio_output_dim, bias=True, ) # Deeper connectors with matching dims and gate_logits # connector_positional_embedding_max_pos=[4096] from LTX-2.3 safetensors # config (nested under config.transformer.connector_positional_embedding_max_pos) self.video_embeddings_connector = Embeddings1DConnector( dim=video_output_dim, num_heads=32, head_dim=128, num_layers=8, num_learnable_registers=128, positional_embedding_max_pos=[4096], has_gate_logits=True, ) self.audio_embeddings_connector = Embeddings1DConnector( dim=audio_output_dim, num_heads=32, head_dim=64, num_layers=8, num_learnable_registers=128, positional_embedding_max_pos=[4096], has_gate_logits=True, ) else: # LTX-2: shared feature extractor, 3840-dim connectors self.feature_extractor = GemmaFeaturesExtractor( feature_input_dim, hidden_dim ) self.video_embeddings_connector = Embeddings1DConnector( dim=hidden_dim, num_heads=30, head_dim=128, num_layers=2, num_learnable_registers=128, positional_embedding_max_pos=[1], ) self.audio_embeddings_connector = Embeddings1DConnector( dim=hidden_dim, num_heads=30, head_dim=128, num_layers=2, num_learnable_registers=128, positional_embedding_max_pos=[1], ) self.processor = None def load( self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it", ): if Path(str(text_encoder_path)).joinpath("text_encoder").is_dir(): text_encoder_path = str(Path(text_encoder_path) / "text_encoder") self.language_model = LanguageModel.from_pretrained(text_encoder_path) # Load transformer weights for feature extractor and connector. # These weights are stored differently depending on the repo format: # 1. Monolithic (Lightricks/LTX-2): single ltx-2-19b-*.safetensors at root # with raw PyTorch key names (model.diffusion_model.* prefix) # 2. Reformatted (prince-canuma/LTX-2-distilled): separate text_projections/ # directory with pre-sanitized keys (no prefix, already renamed) transformer_weights = {} is_reformatted = False # Try reformatted layout first: text_projections/ subdirectory text_proj_dir = model_path / "text_projections" if text_proj_dir.is_dir(): is_reformatted = True for sf in text_proj_dir.glob("*.safetensors"): transformer_weights.update(mx.load(str(sf))) # Fall back to monolithic layout: ltx-2-19*.safetensors at root if not transformer_weights: transformer_files = list(model_path.glob("ltx-2-19*.safetensors")) if transformer_files: transformer_weights = mx.load(str(transformer_files[0])) if transformer_weights: self._load_feature_extractors(transformer_weights, is_reformatted) self._load_connector( "video_embeddings_connector", transformer_weights, is_reformatted ) self._load_connector( "audio_embeddings_connector", transformer_weights, is_reformatted ) else: print( "WARNING: No transformer weights found for text projection connectors. " "Text conditioning will use uninitialized weights!" ) # Load tokenizer from transformers import AutoTokenizer tokenizer_path = model_path / "tokenizer" if tokenizer_path.exists(): self.processor = AutoTokenizer.from_pretrained( str(tokenizer_path), trust_remote_code=True ) else: try: self.processor = AutoTokenizer.from_pretrained( text_encoder_path, trust_remote_code=True ) except Exception: self.processor = AutoTokenizer.from_pretrained( "google/gemma-3-12b-it", trust_remote_code=True ) # Set left padding to match official LTX-2 text encoder self.processor.padding_side = "left" print("Text encoder loaded successfully") def _load_feature_extractors(self, weights: dict, is_reformatted: bool): """Load feature extractor weights for both LTX-2 and LTX-2.3.""" if self.has_prompt_adaln: # LTX-2.3: V2 feature extractor with separate video/audio aggregate embeds for attr, prefix in [ ("video_aggregate_embed", "video_aggregate_embed"), ("audio_aggregate_embed", "audio_aggregate_embed"), ]: w_key = f"{prefix}.weight" b_key = f"{prefix}.bias" if w_key in weights: submodule = getattr(self.feature_extractor_v2, attr) submodule.weight = weights[w_key] if b_key in weights: submodule.bias = weights[b_key] else: # LTX-2: single aggregate_embed agg_key = ( "aggregate_embed.weight" if is_reformatted else "text_embedding_projection.aggregate_embed.weight" ) if agg_key in weights: self.feature_extractor.aggregate_embed.weight = weights[agg_key] def _load_connector(self, name: str, weights: dict, is_reformatted: bool): """Load a connector's weights (video or audio).""" connector = getattr(self, name) # Extract connector-specific weights connector_weights = {} if is_reformatted: prefix = f"{name}." for key, value in weights.items(): if key.startswith(prefix): connector_weights[key[len(prefix) :]] = value else: mono_prefix = f"model.diffusion_model.{name}." for key, value in weights.items(): if key.startswith(mono_prefix): connector_weights[key[len(mono_prefix) :]] = value if not connector_weights: return # Sanitize key names (only needed for monolithic/raw PyTorch keys) mapped = {} for key, value in connector_weights.items(): new_key = key if not is_reformatted: new_key = new_key.replace(".ff.net.0.proj.", ".ff.proj_in.") new_key = new_key.replace(".ff.net.2.", ".ff.proj_out.") new_key = new_key.replace(".to_out.0.", ".to_out.") mapped[new_key] = value connector.load_weights(list(mapped.items()), strict=False) # Manually load learnable_registers (plain mx.array, not tracked as parameter) if "learnable_registers" in connector_weights: connector.learnable_registers = connector_weights["learnable_registers"] def encode( self, prompt: str, max_length: int = 1024, return_audio_embeddings: bool = True, ) -> Tuple[mx.array, mx.array]: """Encode text prompt to video and audio embeddings. Args: prompt: Text prompt to encode max_length: Maximum token length (default 1024 to match official PyTorch) return_audio_embeddings: If True, returns (video_emb, audio_emb). If False, returns (video_emb, attention_mask). Returns: Tuple of (video_embeddings, audio_embeddings) if return_audio_embeddings=True Tuple of (video_embeddings, attention_mask) otherwise """ if self.processor is None: raise RuntimeError("Model not loaded. Call load() first.") inputs = self.processor( prompt, return_tensors="np", max_length=max_length, truncation=True, padding="max_length", ) input_ids = mx.array(inputs["input_ids"]) attention_mask = mx.array(inputs["attention_mask"]) _, all_hidden_states = self.language_model( inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True, ) if self.has_prompt_adaln: # LTX-2.3: V2 feature extraction (per-token RMSNorm + rescale) video_features = self.feature_extractor_v2( all_hidden_states, attention_mask, mode="video" ) additive_mask = (attention_mask - 1).astype(video_features.dtype) additive_mask = ( additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 ) video_embeddings, _ = self.video_embeddings_connector( video_features, additive_mask ) if return_audio_embeddings: audio_features = self.feature_extractor_v2( all_hidden_states, attention_mask, mode="audio" ) audio_mask = (attention_mask - 1).astype(audio_features.dtype) audio_mask = audio_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 audio_embeddings, _ = self.audio_embeddings_connector( audio_features, audio_mask ) return video_embeddings, audio_embeddings else: return video_embeddings, attention_mask else: # LTX-2: V1 feature extraction (8 * (x - mean) / range) concat_hidden = norm_and_concat_hidden_states( all_hidden_states, attention_mask, padding_side="left" ) video_features = self.feature_extractor(concat_hidden) additive_mask = (attention_mask - 1).astype(video_features.dtype) additive_mask = ( additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 ) video_embeddings, _ = self.video_embeddings_connector( video_features, additive_mask ) if return_audio_embeddings: audio_embeddings, _ = self.audio_embeddings_connector( video_features, additive_mask ) return video_embeddings, audio_embeddings else: return video_embeddings, attention_mask def __call__( self, prompt: str, max_length: int = 1024, return_audio_embeddings: bool = True, ) -> Tuple[mx.array, mx.array]: """Encode text prompt. Args: prompt: Text prompt to encode max_length: Maximum token length (default 1024 to match official PyTorch) return_audio_embeddings: If True, returns (video_emb, audio_emb). If False, returns (video_emb, attention_mask). Returns: Tuple of embeddings based on return_audio_embeddings flag """ return self.encode(prompt, max_length, return_audio_embeddings) @functools.cached_property def default_t2v_system_prompt(self) -> str: """Load the default T2V system prompt.""" return _load_system_prompt("gemma_t2v_system_prompt.txt") @functools.cached_property def default_i2v_system_prompt(self) -> str: """Load the default I2V system prompt.""" return _load_system_prompt("gemma_i2v_system_prompt.txt") def _clean_response(self, response: str) -> str: """Clean up the generated response.""" # Remove leading/trailing whitespace response = response.strip() # Remove any leading punctuation response = re.sub(r"^[^\w\s]+", "", response) return response def _apply_chat_template( self, messages: List[Dict[str, str]], ) -> str: """Apply Gemma chat template to messages.""" # Gemma 3 chat template format formatted = "" for msg in messages: role = msg["role"] content = msg["content"] if role == "system": formatted += f"user\n{content}\n" elif role == "user": if isinstance(content, str): formatted += f"user\n{content}\n" elif isinstance(content, list): # Handle multimodal content (image + text) text_parts = [c["text"] for c in content if c.get("type") == "text"] formatted += ( f"user\n{' '.join(text_parts)}\n" ) elif role == "assistant": formatted += f"model\n{content}\n" # Add generation prompt formatted += "model\n" return formatted def enhance_t2v( self, prompt: str, max_tokens: int = 512, system_prompt: Optional[str] = None, seed: int = 42, verbose: bool = True, **kwargs, ) -> str: """Enhance a text prompt for T2V generation using mlx-lm. Args: prompt: The original user prompt max_new_tokens: Maximum number of tokens to generate system_prompt: Optional custom system prompt seed: Random seed for generation Returns: Enhanced prompt string """ try: from mlx_lm import stream_generate from mlx_lm.sample_utils import make_logits_processors, make_sampler except ImportError: logging.warning( "mlx-lm not available for prompt enhancement. Using original prompt." ) return prompt if self.processor is None: raise RuntimeError("Model not loaded. Call load() first.") system_prompt = system_prompt or self.default_t2v_system_prompt messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": f"user prompt: {prompt}"}, ] # Apply chat template formatted = self._apply_chat_template(messages) # Use mlx-lm generate with temperature sampling mx.random.seed(seed) # Tokenize inputs = self.processor( formatted, return_tensors="np", add_special_tokens=False, ) input_ids = mx.array(inputs["input_ids"]) sampler = make_sampler( kwargs.get("temperature", 0.7), kwargs.get("top_p", 1.0), top_k=kwargs.get("top_k", -1), ) logits_processors = make_logits_processors( kwargs.get("logit_bias", None), kwargs.get("repetition_penalty", 1.3), kwargs.get("repetition_context_size", 20), ) generated_token_count = 0 generated_tokens = [] console = Console() generator = stream_generate( self.language_model, tokenizer=self.processor, prompt=input_ids.squeeze(0), max_tokens=max_tokens, sampler=sampler, logits_processors=logits_processors, ) progress = Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), console=console, disable=not verbose, ) with progress: task = progress.add_task("[cyan]Generating[/]", total=max_tokens) for i, response in enumerate(generator): next_token = mx.array([response.token]) input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1) generated_tokens.append(response.token) generated_token_count += 1 progress.update(task, advance=1) if i % 50 == 0: mx.clear_cache() # Check for EOS if response.token == 1 or response.token == 107: # EOS tokens progress.update(task, completed=max_tokens) break mx.clear_cache() # Decode only the new tokens enhanced_prompt = self.processor.decode( generated_tokens, skip_special_tokens=True ) enhanced_prompt = self._clean_response(enhanced_prompt) logging.info(f"Enhanced prompt: {enhanced_prompt}") return enhanced_prompt def enhance_i2v( self, prompt: str, image: Optional[mx.array] = None, max_new_tokens: int = 512, system_prompt: Optional[str] = None, seed: int = 42, ) -> str: """Enhance a text prompt for I2V generation. Args: prompt: The original user prompt image: Optional image tensor (not currently used) max_new_tokens: Maximum number of tokens to generate system_prompt: Optional custom system prompt seed: Random seed for generation Returns: Enhanced prompt string """ # Use T2V enhancement with I2V system prompt return self.enhance_t2v( prompt, max_new_tokens=max_new_tokens, system_prompt=system_prompt or self.default_i2v_system_prompt, seed=seed, ) def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder: encoder = LTX2TextEncoder() encoder.load(model_path=model_path) return encoder