Merge pull request #9 from Blaizzy/pc/fix-text-encoder

Fix text encoder
This commit is contained in:
Prince Canuma
2026-01-17 01:10:36 +01:00
committed by GitHub
11 changed files with 2000 additions and 319 deletions

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm
# ANSI color codes # ANSI color codes
class Colors: class Colors:
@@ -113,9 +114,10 @@ def denoise(
text_embeddings: mx.array, text_embeddings: mx.array,
transformer: LTXModel, transformer: LTXModel,
sigmas: list, sigmas: list,
verbose: bool = True,
) -> mx.array: ) -> mx.array:
"""Run denoising loop.""" """Run denoising loop."""
for i in range(len(sigmas) - 1): for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1] sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape b, c, f, h, w = latents.shape
@@ -148,6 +150,7 @@ def denoise(
def generate_video( def generate_video(
model_repo: str, model_repo: str,
text_encoder_repo: str,
prompt: str, prompt: str,
height: int = 512, height: int = 512,
width: int = 512, width: int = 512,
@@ -156,6 +159,10 @@ def generate_video(
fps: int = 24, fps: int = 24,
output_path: str = "output.mp4", output_path: str = "output.mp4",
save_frames: bool = False, save_frames: bool = False,
verbose: bool = True,
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
): ):
"""Generate video from text prompt. """Generate video from text prompt.
@@ -174,12 +181,19 @@ def generate_video(
# Validate dimensions # Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}" assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}" assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
num_frames = adjusted_num_frames
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
# Get model path # Get model path
model_path = get_model_path(model_repo) model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
# Calculate latent dimensions # Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
@@ -191,10 +205,16 @@ def generate_video(
# Load text encoder # Load text encoder
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}") print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path)) text_encoder = LTX2TextEncoder()
text_encoder.load(str(model_path)) text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
mx.eval(text_encoder.parameters()) mx.eval(text_encoder.parameters())
# Optionally enhance the prompt
if enhance_prompt:
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
text_embeddings, _ = text_encoder(prompt) text_embeddings, _ = text_encoder(prompt)
mx.eval(text_embeddings) mx.eval(text_embeddings)
@@ -236,7 +256,7 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions) mx.eval(positions)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
# Upsample latents # Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
@@ -265,7 +285,7 @@ def generate_video(
latents = noise * noise_scale + latents * (1 - noise_scale) latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents) mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
del transformer del transformer
mx.clear_cache() mx.clear_cache()
@@ -295,7 +315,7 @@ def generate_video(
for frame in video_np: for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release() out.release()
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}") print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
except Exception as e: except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
@@ -308,6 +328,7 @@ def generate_video(
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np return video_np
@@ -361,7 +382,7 @@ Examples:
help="Frames per second for output video (default: 24)" help="Frames per second for output video (default: 24)"
) )
parser.add_argument( parser.add_argument(
"--output", "-o", "--output-path",
type=str, type=str,
default="output.mp4", default="output.mp4",
help="Output video path (default: output.mp4)" help="Output video path (default: output.mp4)"
@@ -377,18 +398,38 @@ Examples:
default="Lightricks/LTX-2", default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)" help="Model repository to use (default: Lightricks/LTX-2)"
) )
parser.add_argument(
"--text-encoder-repo",
type=str,
default=None,
help="Text encoder repository to use (default: None)"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Verbose output"
)
parser.add_argument(
"--enhance-prompt",
action="store_true",
help="Enhance the prompt using Gemma before generation"
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="Maximum number of tokens to generate (default: 512)"
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Temperature for prompt enhancement (default: 0.7)"
)
args = parser.parse_args() args = parser.parse_args()
generate_video( generate_video(
model_repo=args.model_repo, **vars(args)
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
) )

View File

@@ -384,8 +384,9 @@ class LTXModel(nn.Module):
video_config = config.get_video_config() video_config = config.get_video_config()
audio_config = config.get_audio_config() audio_config = config.get_audio_config()
self.transformer_blocks = [
BasicAVTransformerBlock( self.transformer_blocks = {
idx: BasicAVTransformerBlock(
idx=idx, idx=idx,
video=video_config, video=video_config,
audio=audio_config, audio=audio_config,
@@ -393,7 +394,7 @@ class LTXModel(nn.Module):
norm_eps=config.norm_eps, norm_eps=config.norm_eps,
) )
for idx in range(config.num_layers) for idx in range(config.num_layers)
] }
def _process_transformer_blocks( def _process_transformer_blocks(
self, self,
@@ -401,7 +402,7 @@ class LTXModel(nn.Module):
audio: Optional[TransformerArgs], audio: Optional[TransformerArgs],
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]: ) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
"""Process through all transformer blocks.""" """Process through all transformer blocks."""
for block in self.transformer_blocks: for block in self.transformer_blocks.values():
video, audio = block(video=video, audio=audio) video, audio = block(video=video, audio=audio)
return video, audio return video, audio

View File

@@ -0,0 +1,30 @@
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
#### Example output:
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.

View File

@@ -0,0 +1,40 @@
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
#### Guidelines
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
- Speech (only when requested):
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
- Specify language if not English and accent if relevant.
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
#### Important notes:
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
- Camera motion: DO NOT invent camera motion unless requested by the user.
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
- Format: DO NOT start your response with special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
#### Output Format (Strict):
- Single continuous paragraph in natural language (English).
- NO titles, headings, prefaces, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
#### Example
Input: "A woman at a coffee shop talking on the phone"
Output:
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.

View File

@@ -49,6 +49,12 @@ def apply_interleaved_rotary_emb(
Returns: Returns:
Tensor with interleaved rotary embeddings applied Tensor with interleaved rotary embeddings applied
""" """
# Compute in float32 for better precision
input_dtype = input_tensor.dtype
input_tensor = input_tensor.astype(mx.float32)
cos_freqs = cos_freqs.astype(mx.float32)
sin_freqs = sin_freqs.astype(mx.float32)
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2) # Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
shape = input_tensor.shape shape = input_tensor.shape
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2)) input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
@@ -67,7 +73,7 @@ def apply_interleaved_rotary_emb(
# Apply rotary embeddings # Apply rotary embeddings
out = input_tensor * cos_freqs + t_rot * sin_freqs out = input_tensor * cos_freqs + t_rot * sin_freqs
return out return out.astype(input_dtype)
def rotate_half_interleaved(x: mx.array) -> mx.array: def rotate_half_interleaved(x: mx.array) -> mx.array:

View File

@@ -1,215 +1,201 @@
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
import functools
import logging
import math import math
import re
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np
from mlx_video.utils import rms_norm from mlx_video.utils import rms_norm, apply_quantization
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@dataclass from mlx_vlm.models.gemma3.language import Gemma3Model
class Gemma3Config: from mlx_vlm.models.gemma3.config import TextConfig
"""Configuration for Gemma 3 text model."""
hidden_size: int = 3840
num_attention_heads: int = 16
num_key_value_heads: int = 8
head_dim: int = 256
intermediate_size: int = 15360
num_hidden_layers: int = 48
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
vocab_size: int = 262208
max_position_embeddings: int = 131072
class RMSNorm(nn.Module): # Path to system prompts
"""RMS Normalization (Gemma style with 1+weight scaling).""" PROMPTS_DIR = Path(__file__).parent / "prompts"
def __init__(self, dims: int, eps: float = 1e-6):
def _load_system_prompt(prompt_name: str) -> str:
"""Load a system prompt from the prompts directory."""
prompt_path = PROMPTS_DIR / prompt_name
if prompt_path.exists():
with open(prompt_path, "r") as f:
return f.read()
raise FileNotFoundError(f"System prompt not found: {prompt_path}")
class LanguageModel(nn.Module):
def __init__(self, config: TextConfig):
super().__init__() super().__init__()
self.eps = eps # Create config matching LTX-2 text encoder requirements
# Gemma initializes to ones, but uses (1+weight) scaling self.config = config
# After loading weights, weight will have the actual learned values
self.weight = mx.ones((dims,))
def __call__(self, x: mx.array) -> mx.array: # Create the Gemma3Model from mlx-vlm
# Gemma-style RMSNorm uses (1 + weight) as the scale factor self.model = Gemma3Model(self.config)
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
def _create_causal_mask_with_padding(
def apply_rotary_emb(
q: mx.array,
k: mx.array,
positions: mx.array,
head_dim: int,
rope_theta: float = 1000000.0,
) -> Tuple[mx.array, mx.array]:
"""Apply rotary position embeddings to Q and K."""
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
cos = mx.cos(freqs)
sin = mx.sin(freqs)
cos = cos[:, :, None, :]
sin = sin[:, :, None, :]
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return mx.concatenate([-x2, x1], axis=-1)
cos_full = mx.concatenate([cos, cos], axis=-1)
sin_full = mx.concatenate([sin, sin], axis=-1)
q_embed = q * cos_full + rotate_half(q) * sin_full
k_embed = k * cos_full + rotate_half(k) * sin_full
return q_embed, k_embed
class Gemma3MLP(nn.Module):
"""Gemma 3 MLP with gated activation."""
def __init__(self, config: Gemma3Config):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
gate = nn.gelu_approx(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
class Gemma3Attention(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.scale = 1.0 / math.sqrt(config.head_dim)
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
def __call__(
self, self,
hidden_states: mx.array, seq_len: int,
positions: mx.array, attention_mask: Optional[mx.array],
attention_mask: Optional[mx.array] = None, dtype: mx.Dtype,
) -> mx.array: ) -> mx.array:
batch_size, seq_len, _ = hidden_states.shape
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# Create causal mask (lower triangular)
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
if attention_mask is not None: if attention_mask is not None:
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9 batch_size = attention_mask.shape[0]
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask) padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
out = mx.transpose(out, (0, 2, 1, 3)) combined = causal_mask[None, :, :] & padding_mask[:, None, :]
out = mx.reshape(out, (batch_size, seq_len, -1)) min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
return self.o_proj(out) mx.full(combined.shape, min_val, dtype=dtype))
return mask[:, None, :, :]
else:
class Gemma3DecoderLayer(nn.Module): # No padding mask, just causal
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
def __init__(self, config: Gemma3Config): mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
super().__init__() mx.full((seq_len, seq_len), min_val, dtype=dtype))
self.self_attn = Gemma3Attention(config) return mask[None, None, :, :] # (1, 1, seq, seq)
self.mlp = Gemma3MLP(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__( def __call__(
self, self,
hidden_states: mx.array, inputs: mx.array,
positions: mx.array, input_embeddings: Optional[mx.array] = None,
attention_mask: Optional[mx.array] = None, attention_mask: Optional[mx.array] = None,
) -> mx.array: output_hidden_states: bool = False,
residual = hidden_states cache: Optional[List[mx.array]] = None,
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Gemma3TextModel(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Gemma scales embeddings by sqrt(hidden_size)
self.embed_scale = config.hidden_size ** 0.5
def __call__(
self,
input_ids: mx.array,
attention_mask: Optional[mx.array] = None,
output_hidden_states: bool = True,
) -> Tuple[mx.array, List[mx.array]]: ) -> Tuple[mx.array, List[mx.array]]:
"""Forward pass returning hidden states.
batch_size, seq_len = input_ids.shape
# Gemma scales embeddings by sqrt(hidden_size) Args:
hidden_states = self.embed_tokens(input_ids) * self.embed_scale input_ids: Input token IDs of shape (batch, seq_len)
attention_mask: Optional attention mask (1 for valid, 0 for padding)
output_hidden_states: Whether to return all hidden states
all_hidden_states = [hidden_states] if output_hidden_states else [] Returns:
Tuple of (final_hidden_states, list_of_all_hidden_states)
"""
batch_size, seq_len = inputs.shape
positions = mx.arange(seq_len)[None, :].astype(mx.int32) # Get embeddings
positions = mx.broadcast_to(positions, (batch_size, seq_len)) h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
for layer in self.layers: # Apply Gemma scaling
hidden_states = layer(hidden_states, positions, attention_mask) h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
if output_hidden_states: mx.eval(h)
all_hidden_states.append(hidden_states)
hidden_states = self.norm(hidden_states) all_hidden_states = [h] if output_hidden_states else []
return hidden_states, all_hidden_states # Set up cache (all None for non-cached inference)
if cache is None:
cache = [None] * len(self.model.layers)
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
sliding_mask = full_causal_mask
num_layers = len(self.model.layers)
for i, layer in enumerate(self.model.layers):
is_global = (
i % self.config.sliding_window_pattern
== self.config.sliding_window_pattern - 1
)
# Select appropriate mask for this layer
if is_global:
local_mask = full_causal_mask
else:
local_mask = sliding_mask
h = layer(h, local_mask, cache[i])
mx.eval(h)
if output_hidden_states and i < num_layers - 1:
all_hidden_states.append(h)
# Apply final norm
hidden_states = self.model.norm(h)
mx.eval(hidden_states)
# Append the final normalized output as the last hidden state
if output_hidden_states:
all_hidden_states.append(hidden_states)
return hidden_states, all_hidden_states
else:
# Return logits
return self.model.embed_tokens.as_linear(hidden_states)
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
prefix = "language_model."
sanitized = {}
for key, value in weights.items():
if key.startswith(prefix):
if hasattr(value, "dtype") and value.dtype == mx.float32:
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
else:
sanitized[key[len(prefix):]] = value
return sanitized
@property
def layers(self) -> List[nn.Module]:
return self.model.layers
def make_cache(self):
from mlx_vlm.models.cache import KVCache, RotatingKVCache
caches = []
for i in range(len(self.layers)):
if (
i % self.config.sliding_window_pattern
== self.config.sliding_window_pattern - 1
):
caches.append(KVCache())
else:
caches.append(RotatingKVCache(max_size=self.config.sliding_window))
return caches
@classmethod
def from_pretrained(cls, model_path: str):
import json
weight_files = sorted(Path(model_path).glob("*.safetensors"))
config_file = Path(model_path) / "config.json"
config_dict = {}
if config_file.exists():
with open(config_file, "r") as f:
config_dict = json.load(f)
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
else:
raise ValueError(f"Config file not found at {model_path}")
quantization = config_dict.get("quantization", None)
weights = {}
for i, wf in enumerate(weight_files):
weights.update(mx.load(str(wf)))
if hasattr(language_model, "sanitize"):
weights = language_model.sanitize(weights=weights)
apply_quantization(model=language_model, weights=weights, quantization=quantization)
language_model.load_weights(list(weights.items()), strict=False)
return language_model
@@ -278,7 +264,7 @@ class GEGLU(nn.Module):
self.proj = nn.Linear(in_dim, out_dim, bias=True) self.proj = nn.Linear(in_dim, out_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
return nn.gelu_approx(self.proj(x)) return nn.gelu(self.proj(x))
class ConnectorFeedForward(nn.Module): class ConnectorFeedForward(nn.Module):
@@ -359,28 +345,28 @@ class Embeddings1DConnector(nn.Module):
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]: def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (INTERLEAVED type). """Compute RoPE frequencies for connector (INTERLEAVED type).
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim). Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
""" """
import math
import numpy as np
dim = self.num_heads * self.head_dim # inner_dim = 3840 dim = self.num_heads * self.head_dim # inner_dim = 3840
theta = self.positional_embedding_theta theta = self.positional_embedding_theta
max_pos = [1] # Default for connector max_pos = [1] # Default for connector
n_elem = 2 * len(max_pos) # = 2 n_elem = 2 * len(max_pos) # = 2
# Generate frequency indices (matches generate_freq_grid_pytorch)
start = 1.0 start = 1.0
end = theta end = theta
num_indices = dim // n_elem # 1920 num_indices = dim // n_elem # 1920
log_start = math.log(start) / math.log(theta) # = 0 # Use numpy float64 for precision
log_end = math.log(end) / math.log(theta) # = 1 log_start = np.log(start) / np.log(theta) # = 0
lin_space = mx.linspace(log_start, log_end, num_indices) log_end = np.log(end) / np.log(theta) # = 1
indices = (theta ** lin_space) * (math.pi / 2) lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
# Generate positions and compute freqs (matches generate_freqs) # Generate positions and compute freqs (matches generate_freqs)
positions = mx.arange(seq_len).astype(mx.float32) positions = np.arange(seq_len, dtype=np.float64)
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1) # fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1 # scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
scaled_positions = positions * 2 - 1 # Shape: (seq_len,) scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
@@ -390,17 +376,17 @@ class Embeddings1DConnector(nn.Module):
freqs = scaled_positions[:, None] * indices[None, :] freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis) # Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
cos_freq = mx.cos(freqs) cos_freq = np.cos(freqs)
sin_freq = mx.sin(freqs) sin_freq = np.sin(freqs)
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim) # repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
# Pattern: [c0, c0, c1, c1, c2, c2, ...] # Pattern: [c0, c0, c1, c1, c2, c2, ...]
cos_full = mx.repeat(cos_freq, 2, axis=-1) cos_full = np.repeat(cos_freq, 2, axis=-1)
sin_full = mx.repeat(sin_freq, 2, axis=-1) sin_full = np.repeat(sin_freq, 2, axis=-1)
# Add batch dimension: (1, seq_len, dim) # Add batch dimension and convert to MLX: (1, seq_len, dim)
cos_full = cos_full[None, :, :] cos_full = mx.array(cos_full[None, :, :].astype(np.float32))
sin_full = sin_full[None, :, :] sin_full = mx.array(sin_full[None, :, :].astype(np.float32))
return cos_full.astype(dtype), sin_full.astype(dtype) return cos_full.astype(dtype), sin_full.astype(dtype)
@@ -547,45 +533,19 @@ class GemmaFeaturesExtractor(nn.Module):
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
sanitized = {}
for key, value in weights.items():
new_key = None
if key.startswith("base_text_encoder.language_model."):
new_key = key.replace("base_text_encoder.language_model.", "")
elif key.startswith("language_model.model."):
new_key = key.replace("language_model.model.", "")
elif key.startswith("language_model."):
new_key = key.replace("language_model.", "")
else:
continue
if new_key is None:
continue
sanitized[new_key] = value
return sanitized
class LTX2TextEncoder(nn.Module): class LTX2TextEncoder(nn.Module):
def __init__( def __init__(
self, self,
model_path: str = "Lightricks/LTX-2",
hidden_dim: int = 3840, hidden_dim: int = 3840,
num_layers: int = 49, # 48 transformer layers + 1 embedding num_layers: int = 49, # 48 transformer layers + 1 embedding
): ):
super().__init__() super().__init__()
self._model_path = model_path
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.num_layers = num_layers self.num_layers = num_layers
self.language_model = None
# Gemma 3 model
self.config = Gemma3Config()
self.model = Gemma3TextModel(self.config)
# Feature extractor: 3840*49 -> 3840 # Feature extractor: 3840*49 -> 3840
self.feature_extractor = GemmaFeaturesExtractor( self.feature_extractor = GemmaFeaturesExtractor(
@@ -604,37 +564,16 @@ class LTX2TextEncoder(nn.Module):
self.processor = None self.processor = None
def load(self, model_path: Optional[str] = None): def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
path = model_path or self._model_path
# Load Gemma weights from text_encoder subdirectory if Path(text_encoder_path / "text_encoder").is_dir():
if Path(path).is_dir(): text_encoder_path = str(text_encoder_path / "text_encoder")
text_encoder_path = Path(path) / "text_encoder"
if text_encoder_path.exists(): self.language_model = LanguageModel.from_pretrained(text_encoder_path)
gemma_path = str(text_encoder_path)
else:
gemma_path = path
else:
gemma_path = path
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
all_weights = {}
for i, wf in enumerate(weight_files):
print(f" Loading weight file {i+1}/{len(weight_files)}...")
weights = mx.load(str(wf))
all_weights.update(weights)
# Sanitize and load Gemma weights
sanitized = sanitize_gemma3_weights(all_weights)
print(f" Sanitized Gemma weights: {len(sanitized)}")
self.model.load_weights(list(sanitized.items()), strict=False)
# Load transformer weights for feature extractor and connector # Load transformer weights for feature extractor and connector
transformer_path = Path(model_path or self._model_path) transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
if transformer_files: if transformer_files:
print(f"Loading transformer weights for text pipeline...")
transformer_weights = mx.load(str(transformer_files[0])) transformer_weights = mx.load(str(transformer_files[0]))
# Load feature extractor (aggregate_embed) # Load feature extractor (aggregate_embed)
@@ -642,7 +581,7 @@ class LTX2TextEncoder(nn.Module):
self.feature_extractor.aggregate_embed.weight = transformer_weights[ self.feature_extractor.aggregate_embed.weight = transformer_weights[
"text_embedding_projection.aggregate_embed.weight" "text_embedding_projection.aggregate_embed.weight"
] ]
print(" Loaded aggregate_embed weights")
# Load video_embeddings_connector weights # Load video_embeddings_connector weights
connector_weights = {} connector_weights = {}
@@ -663,20 +602,18 @@ class LTX2TextEncoder(nn.Module):
self.video_embeddings_connector.load_weights( self.video_embeddings_connector.load_weights(
list(mapped_weights.items()), strict=False list(mapped_weights.items()), strict=False
) )
print(f" Loaded {len(connector_weights)} connector weights")
# Manually load learnable_registers (it's a plain mx.array, not a parameter) # Manually load learnable_registers (it's a plain mx.array, not a parameter)
if "learnable_registers" in connector_weights: if "learnable_registers" in connector_weights:
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"] self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
# Load tokenizer # Load tokenizer
from transformers import AutoTokenizer from transformers import AutoTokenizer
tokenizer_path = Path(model_path or self._model_path) / "tokenizer" tokenizer_path = model_path / "tokenizer"
if tokenizer_path.exists(): if tokenizer_path.exists():
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True) self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
else: else:
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True) self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
# Set left padding to match official LTX-2 text encoder # Set left padding to match official LTX-2 text encoder
self.processor.padding_side = "left" self.processor.padding_side = "left"
@@ -691,7 +628,6 @@ class LTX2TextEncoder(nn.Module):
if self.processor is None: if self.processor is None:
raise RuntimeError("Model not loaded. Call load() first.") raise RuntimeError("Model not loaded. Call load() first.")
# Tokenize with left padding (as in PyTorch version)
inputs = self.processor( inputs = self.processor(
prompt, prompt,
return_tensors="np", return_tensors="np",
@@ -702,28 +638,19 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"]) input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"]) attention_mask = mx.array(inputs["attention_mask"])
# Get all hidden states from Gemma _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
# Normalize and concatenate all hidden states
concat_hidden = norm_and_concat_hidden_states( concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left" all_hidden_states, attention_mask, padding_side="left"
) )
# Project through feature extractor
features = self.feature_extractor(concat_hidden) features = self.feature_extractor(concat_hidden)
# Convert attention mask to additive format for connector
additive_mask = (attention_mask - 1).astype(features.dtype) additive_mask = (attention_mask - 1).astype(features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
# Process through connector
# Note: connector replaces padding with learnable registers and resets mask to zeros
# This means all positions now have valid embeddings (no need for final masking)
embeddings, _ = self.video_embeddings_connector(features, additive_mask) embeddings, _ = self.video_embeddings_connector(features, additive_mask)
# Return embeddings without zeroing - the connector's register replacement
# means all positions have meaningful values now
return embeddings, attention_mask return embeddings, attention_mask
def __call__( def __call__(
@@ -733,6 +660,177 @@ class LTX2TextEncoder(nn.Module):
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
return self.encode(prompt, max_length) return self.encode(prompt, max_length)
@functools.cached_property
def default_t2v_system_prompt(self) -> str:
"""Load the default T2V system prompt."""
return _load_system_prompt("gemma_t2v_system_prompt.txt")
@functools.cached_property
def default_i2v_system_prompt(self) -> str:
"""Load the default I2V system prompt."""
return _load_system_prompt("gemma_i2v_system_prompt.txt")
def _clean_response(self, response: str) -> str:
"""Clean up the generated response."""
# Remove leading/trailing whitespace
response = response.strip()
# Remove any leading punctuation
response = re.sub(r'^[^\w\s]+', '', response)
return response
def _apply_chat_template(
self,
messages: List[Dict[str, str]],
) -> str:
"""Apply Gemma chat template to messages."""
# Gemma 3 chat template format
formatted = ""
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
elif role == "user":
if isinstance(content, str):
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
elif isinstance(content, list):
# Handle multimodal content (image + text)
text_parts = [c["text"] for c in content if c.get("type") == "text"]
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
elif role == "assistant":
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
# Add generation prompt
formatted += "<start_of_turn>model\n"
return formatted
def enhance_t2v(
self,
prompt: str,
max_tokens: int = 512,
system_prompt: Optional[str] = None,
seed: int = 42,
verbose: bool = True,
**kwargs,
) -> str:
"""Enhance a text prompt for T2V generation using mlx-lm.
Args:
prompt: The original user prompt
max_new_tokens: Maximum number of tokens to generate
system_prompt: Optional custom system prompt
seed: Random seed for generation
Returns:
Enhanced prompt string
"""
from tqdm import tqdm
try:
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler
except ImportError:
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
return prompt
if self.processor is None:
raise RuntimeError("Model not loaded. Call load() first.")
system_prompt = system_prompt or self.default_t2v_system_prompt
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"user prompt: {prompt}"},
]
# Apply chat template
formatted = self._apply_chat_template(messages)
# Use mlx-lm generate with temperature sampling
mx.random.seed(seed)
# Tokenize
inputs = self.processor(
formatted,
return_tensors="np",
add_special_tokens=False,
)
input_ids = mx.array(inputs["input_ids"])
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 0.95), top_k=kwargs.get("top_k", -1))
logits_processors = make_logits_processors(
kwargs.get("logit_bias", None),
kwargs.get("repetition_penalty", 1.3),
kwargs.get("repetition_context_size", 20),
)
generated_token_count = 0
generated_tokens = []
for i, response in enumerate(
tqdm(
stream_generate(
self.language_model,
tokenizer=self.processor,
prompt=input_ids.squeeze(0),
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
),
total=max_tokens,
disable=not verbose,
)
):
next_token = mx.array([response.token])
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
generated_tokens.append(next_token.squeeze())
generated_token_count += 1
if i % 50 == 0:
mx.clear_cache()
# Check for EOS
if response.token == 1 or response.token == 107: # EOS tokens
break
# Decode only the new tokens
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
enhanced_prompt = self._clean_response(enhanced_prompt)
logging.info(f"Enhanced prompt: {enhanced_prompt}")
return enhanced_prompt
def enhance_i2v(
self,
prompt: str,
image: Optional[mx.array] = None,
max_new_tokens: int = 512,
system_prompt: Optional[str] = None,
seed: int = 42,
) -> str:
"""Enhance a text prompt for I2V generation.
Args:
prompt: The original user prompt
image: Optional image tensor (not currently used)
max_new_tokens: Maximum number of tokens to generate
system_prompt: Optional custom system prompt
seed: Random seed for generation
Returns:
Enhanced prompt string
"""
# Use T2V enhancement with I2V system prompt
return self.enhance_t2v(
prompt,
max_new_tokens=max_new_tokens,
system_prompt=system_prompt or self.default_i2v_system_prompt,
seed=seed,
)
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder: def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
encoder = LTX2TextEncoder(model_path=model_path) encoder = LTX2TextEncoder(model_path=model_path)

View File

@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union from typing import Tuple, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -85,6 +85,10 @@ class GroupNorm3d(nn.Module):
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:
# x: (N, D, H, W, C) # x: (N, D, H, W, C)
n, d, h, w, c = x.shape n, d, h, w, c = x.shape
input_dtype = x.dtype
x = x.astype(mx.float32)
# Reshape to (N, D*H*W, num_groups, C//num_groups) # Reshape to (N, D*H*W, num_groups, C//num_groups)
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups)) x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
@@ -100,7 +104,12 @@ class GroupNorm3d(nn.Module):
x = mx.reshape(x, (n, d, h, w, c)) x = mx.reshape(x, (n, d, h, w, c))
# Apply weight and bias # Apply weight and bias
x = x * self.weight + self.bias weight = self.weight.astype(mx.float32)
bias = self.bias.astype(mx.float32)
x = x * weight + bias
# Convert back to input dtype
x = x.astype(input_dtype)
return x return x

View File

@@ -201,14 +201,15 @@ class ResBlockGroup(nn.Module):
embedding_dim=channels * 4 embedding_dim=channels * 4
) )
self.res_blocks = [ # Use dict with int keys for MLX to track parameters properly
ResnetBlock3DSimple( self.res_blocks = {
i: ResnetBlock3DSimple(
channels, channels,
spatial_padding_mode, spatial_padding_mode,
timestep_conditioning=timestep_conditioning timestep_conditioning=timestep_conditioning
) )
for _ in range(num_layers) for i in range(num_layers)
] }
def __call__( def __call__(
self, self,
@@ -227,7 +228,7 @@ class ResBlockGroup(nn.Module):
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting # Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1) timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
for res_block in self.res_blocks: for res_block in self.res_blocks.values():
x = res_block(x, causal=causal, timestep_embed=timestep_embed) x = res_block(x, causal=causal, timestep_embed=timestep_embed)
return x return x
@@ -287,37 +288,37 @@ class LTX2VideoDecoder(nn.Module):
self.conv_in = ConvInWrapper() self.conv_in = ConvInWrapper()
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample # Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
# Use dict with int keys for MLX to track parameters properly
self.up_blocks = [ self.up_blocks = {
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning), 0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample( 1: DepthToSpaceUpsample(
dims=3, dims=3,
in_channels=1024, in_channels=1024,
stride=(2, 2, 2), stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config! residual=True,
out_channels_reduction_factor=2, out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
), ),
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning), 2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample( 3: DepthToSpaceUpsample(
dims=3, dims=3,
in_channels=512, in_channels=512,
stride=(2, 2, 2), stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config! residual=True,
out_channels_reduction_factor=2, out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
), ),
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning), 4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample( 5: DepthToSpaceUpsample(
dims=3, dims=3,
in_channels=256, in_channels=256,
stride=(2, 2, 2), stride=(2, 2, 2),
residual=True, # CRITICAL: Must match PyTorch config! residual=True,
out_channels_reduction_factor=2, out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
), ),
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning), 6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
] }
final_out_channels = out_channels * patch_size * patch_size final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module): class ConvOutWrapper(nn.Module):
@@ -396,10 +397,10 @@ class LTX2VideoDecoder(nn.Module):
if debug: if debug:
debug_stats("After conv_in", x) debug_stats("After conv_in", x)
for i, block in enumerate(self.up_blocks): for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup): if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep) x = block(x, causal=causal, timestep=scaled_timestep)
else: else:
x = block(x, causal=causal) x = block(x, causal=causal)
if debug: if debug:
block_type = type(block).__name__ block_type = type(block).__name__
@@ -443,10 +444,10 @@ class LTX2VideoDecoder(nn.Module):
return x return x
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder: def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path from pathlib import Path
import json
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning) from safetensors import safe_open
model_path = Path(model_path) model_path = Path(model_path)
@@ -461,6 +462,25 @@ def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX
raise FileNotFoundError(f"VAE weights not found at {model_path}") raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...") print(f"Loading VAE decoder from {weights_path}...")
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
weights = mx.load(str(weights_path)) weights = mx.load(str(weights_path))
# Determine prefix based on weight keys # Determine prefix based on weight keys

View File

@@ -1,7 +1,4 @@
"""Utility functions for MLX Video."""
import math import math
from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -22,6 +19,28 @@ def get_model_path(model_repo: str):
allow_patterns=["*.safetensors", "*.json"], allow_patterns=["*.safetensors", "*.json"],
)) ))
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
if quantization is not None:
def get_class_predicate(p, m):
# Handle custom per layer quantizations
if p in quantization:
return quantization[p]
if not hasattr(m, "to_quantized"):
return False
# Skip layers not divisible by 64
if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0:
return False
# Handle legacy models which may not have everything quantized
return f"{p}.scales" in weights
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
mode=quantization.get("mode", "affine"),
class_predicate=get_class_predicate,
)
@partial(mx.compile, shapeless=True) @partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:

View File

@@ -18,7 +18,8 @@ dependencies = [
"transformers[tokenizers]", "transformers[tokenizers]",
"tqdm", "tqdm",
"opencv-python>=4.12.0.88", "opencv-python>=4.12.0.88",
"Pillow>=10.3.0" "Pillow>=10.3.0",
"mlx-vlm"
] ]
license = {text="MIT"} license = {text="MIT"}
authors = [ authors = [

1424
uv.lock generated

File diff suppressed because it is too large Load Diff