Merge pull request #9 from Blaizzy/pc/fix-text-encoder
Fix text encoder
This commit is contained in:
@@ -5,6 +5,7 @@ from pathlib import Path
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
# ANSI color codes
|
# ANSI color codes
|
||||||
class Colors:
|
class Colors:
|
||||||
@@ -113,9 +114,10 @@ def denoise(
|
|||||||
text_embeddings: mx.array,
|
text_embeddings: mx.array,
|
||||||
transformer: LTXModel,
|
transformer: LTXModel,
|
||||||
sigmas: list,
|
sigmas: list,
|
||||||
|
verbose: bool = True,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Run denoising loop."""
|
"""Run denoising loop."""
|
||||||
for i in range(len(sigmas) - 1):
|
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
|
||||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||||
|
|
||||||
b, c, f, h, w = latents.shape
|
b, c, f, h, w = latents.shape
|
||||||
@@ -148,6 +150,7 @@ def denoise(
|
|||||||
|
|
||||||
def generate_video(
|
def generate_video(
|
||||||
model_repo: str,
|
model_repo: str,
|
||||||
|
text_encoder_repo: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
height: int = 512,
|
height: int = 512,
|
||||||
width: int = 512,
|
width: int = 512,
|
||||||
@@ -156,6 +159,10 @@ def generate_video(
|
|||||||
fps: int = 24,
|
fps: int = 24,
|
||||||
output_path: str = "output.mp4",
|
output_path: str = "output.mp4",
|
||||||
save_frames: bool = False,
|
save_frames: bool = False,
|
||||||
|
verbose: bool = True,
|
||||||
|
enhance_prompt: bool = False,
|
||||||
|
max_tokens: int = 512,
|
||||||
|
temperature: float = 0.7,
|
||||||
):
|
):
|
||||||
"""Generate video from text prompt.
|
"""Generate video from text prompt.
|
||||||
|
|
||||||
@@ -175,11 +182,18 @@ def generate_video(
|
|||||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||||
|
|
||||||
|
if num_frames % 8 != 1:
|
||||||
|
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
|
||||||
|
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
|
||||||
|
num_frames = adjusted_num_frames
|
||||||
|
|
||||||
|
|
||||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
# Get model path
|
# Get model path
|
||||||
model_path = get_model_path(model_repo)
|
model_path = get_model_path(model_repo)
|
||||||
|
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||||
|
|
||||||
# Calculate latent dimensions
|
# Calculate latent dimensions
|
||||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||||
@@ -191,10 +205,16 @@ def generate_video(
|
|||||||
# Load text encoder
|
# Load text encoder
|
||||||
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
|
print(f"{Colors.BLUE}📝 Loading text encoder...{Colors.RESET}")
|
||||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||||
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
text_encoder = LTX2TextEncoder()
|
||||||
text_encoder.load(str(model_path))
|
text_encoder.load(model_path=model_path, text_encoder_path=text_encoder_path)
|
||||||
mx.eval(text_encoder.parameters())
|
mx.eval(text_encoder.parameters())
|
||||||
|
|
||||||
|
# Optionally enhance the prompt
|
||||||
|
if enhance_prompt:
|
||||||
|
print(f"{Colors.MAGENTA}✨ Enhancing prompt...{Colors.RESET}")
|
||||||
|
prompt = text_encoder.enhance_t2v(prompt, max_tokens=max_tokens, temperature=temperature, seed=seed, verbose=verbose)
|
||||||
|
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
text_embeddings, _ = text_encoder(prompt)
|
text_embeddings, _ = text_encoder(prompt)
|
||||||
mx.eval(text_embeddings)
|
mx.eval(text_embeddings)
|
||||||
|
|
||||||
@@ -236,7 +256,7 @@ def generate_video(
|
|||||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
|
||||||
|
|
||||||
# Upsample latents
|
# Upsample latents
|
||||||
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
||||||
@@ -265,7 +285,7 @@ def generate_video(
|
|||||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
|
||||||
|
|
||||||
del transformer
|
del transformer
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
@@ -295,7 +315,7 @@ def generate_video(
|
|||||||
for frame in video_np:
|
for frame in video_np:
|
||||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
out.release()
|
out.release()
|
||||||
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}")
|
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||||
|
|
||||||
@@ -308,6 +328,7 @@ def generate_video(
|
|||||||
|
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
||||||
|
print(f"{Colors.BOLD}{Colors.GREEN}✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
|
||||||
|
|
||||||
return video_np
|
return video_np
|
||||||
|
|
||||||
@@ -361,7 +382,7 @@ Examples:
|
|||||||
help="Frames per second for output video (default: 24)"
|
help="Frames per second for output video (default: 24)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output", "-o",
|
"--output-path",
|
||||||
type=str,
|
type=str,
|
||||||
default="output.mp4",
|
default="output.mp4",
|
||||||
help="Output video path (default: output.mp4)"
|
help="Output video path (default: output.mp4)"
|
||||||
@@ -377,18 +398,38 @@ Examples:
|
|||||||
default="Lightricks/LTX-2",
|
default="Lightricks/LTX-2",
|
||||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-encoder-repo",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Text encoder repository to use (default: None)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
help="Verbose output"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enhance-prompt",
|
||||||
|
action="store_true",
|
||||||
|
help="Enhance the prompt using Gemma before generation"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Maximum number of tokens to generate (default: 512)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.7,
|
||||||
|
help="Temperature for prompt enhancement (default: 0.7)"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
generate_video(
|
generate_video(
|
||||||
model_repo=args.model_repo,
|
**vars(args)
|
||||||
prompt=args.prompt,
|
|
||||||
height=args.height,
|
|
||||||
width=args.width,
|
|
||||||
num_frames=args.num_frames,
|
|
||||||
seed=args.seed,
|
|
||||||
fps=args.fps,
|
|
||||||
output_path=args.output,
|
|
||||||
save_frames=args.save_frames,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -384,8 +384,9 @@ class LTXModel(nn.Module):
|
|||||||
video_config = config.get_video_config()
|
video_config = config.get_video_config()
|
||||||
audio_config = config.get_audio_config()
|
audio_config = config.get_audio_config()
|
||||||
|
|
||||||
self.transformer_blocks = [
|
|
||||||
BasicAVTransformerBlock(
|
self.transformer_blocks = {
|
||||||
|
idx: BasicAVTransformerBlock(
|
||||||
idx=idx,
|
idx=idx,
|
||||||
video=video_config,
|
video=video_config,
|
||||||
audio=audio_config,
|
audio=audio_config,
|
||||||
@@ -393,7 +394,7 @@ class LTXModel(nn.Module):
|
|||||||
norm_eps=config.norm_eps,
|
norm_eps=config.norm_eps,
|
||||||
)
|
)
|
||||||
for idx in range(config.num_layers)
|
for idx in range(config.num_layers)
|
||||||
]
|
}
|
||||||
|
|
||||||
def _process_transformer_blocks(
|
def _process_transformer_blocks(
|
||||||
self,
|
self,
|
||||||
@@ -401,7 +402,7 @@ class LTXModel(nn.Module):
|
|||||||
audio: Optional[TransformerArgs],
|
audio: Optional[TransformerArgs],
|
||||||
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
) -> Tuple[Optional[TransformerArgs], Optional[TransformerArgs]]:
|
||||||
"""Process through all transformer blocks."""
|
"""Process through all transformer blocks."""
|
||||||
for block in self.transformer_blocks:
|
for block in self.transformer_blocks.values():
|
||||||
video, audio = block(video=video, audio=audio)
|
video, audio = block(video=video, audio=audio)
|
||||||
return video, audio
|
return video, audio
|
||||||
|
|
||||||
|
|||||||
30
mlx_video/models/ltx/prompts/gemma_i2v_system_prompt.txt
Normal file
30
mlx_video/models/ltx/prompts/gemma_i2v_system_prompt.txt
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
|
||||||
|
|
||||||
|
#### Guidelines:
|
||||||
|
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
|
||||||
|
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
|
||||||
|
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
|
||||||
|
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||||
|
- Chronological flow: Use temporal connectors ("as," "then," "while").
|
||||||
|
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
|
||||||
|
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
|
||||||
|
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
|
||||||
|
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
|
||||||
|
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
|
||||||
|
|
||||||
|
#### Important notes:
|
||||||
|
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
|
||||||
|
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
|
||||||
|
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||||
|
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
|
||||||
|
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
|
||||||
|
- Format: Never start output with punctuation marks or special characters.
|
||||||
|
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||||
|
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
|
||||||
|
|
||||||
|
#### Output Format (Strict):
|
||||||
|
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
|
||||||
|
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||||
|
|
||||||
|
#### Example output:
|
||||||
|
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
|
||||||
40
mlx_video/models/ltx/prompts/gemma_t2v_system_prompt.txt
Normal file
40
mlx_video/models/ltx/prompts/gemma_t2v_system_prompt.txt
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
|
||||||
|
|
||||||
|
#### Guidelines
|
||||||
|
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
|
||||||
|
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
|
||||||
|
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
|
||||||
|
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||||
|
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
|
||||||
|
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
|
||||||
|
- Speech (only when requested):
|
||||||
|
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
|
||||||
|
- Specify language if not English and accent if relevant.
|
||||||
|
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
|
||||||
|
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
|
||||||
|
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
|
||||||
|
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
|
||||||
|
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
|
||||||
|
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
|
||||||
|
|
||||||
|
#### Important notes:
|
||||||
|
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
|
||||||
|
- Camera motion: DO NOT invent camera motion unless requested by the user.
|
||||||
|
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
|
||||||
|
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||||
|
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
|
||||||
|
- Format: DO NOT start your response with special characters.
|
||||||
|
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||||
|
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
|
||||||
|
|
||||||
|
#### Output Format (Strict):
|
||||||
|
- Single continuous paragraph in natural language (English).
|
||||||
|
- NO titles, headings, prefaces, code fences, or Markdown.
|
||||||
|
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||||
|
|
||||||
|
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
|
||||||
|
|
||||||
|
#### Example
|
||||||
|
Input: "A woman at a coffee shop talking on the phone"
|
||||||
|
Output:
|
||||||
|
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.
|
||||||
@@ -49,6 +49,12 @@ def apply_interleaved_rotary_emb(
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor with interleaved rotary embeddings applied
|
Tensor with interleaved rotary embeddings applied
|
||||||
"""
|
"""
|
||||||
|
# Compute in float32 for better precision
|
||||||
|
input_dtype = input_tensor.dtype
|
||||||
|
input_tensor = input_tensor.astype(mx.float32)
|
||||||
|
cos_freqs = cos_freqs.astype(mx.float32)
|
||||||
|
sin_freqs = sin_freqs.astype(mx.float32)
|
||||||
|
|
||||||
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
|
# Reshape to pair adjacent dimensions: (..., dim) -> (..., dim/2, 2)
|
||||||
shape = input_tensor.shape
|
shape = input_tensor.shape
|
||||||
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
|
input_tensor = mx.reshape(input_tensor, shape[:-1] + (shape[-1] // 2, 2))
|
||||||
@@ -67,7 +73,7 @@ def apply_interleaved_rotary_emb(
|
|||||||
# Apply rotary embeddings
|
# Apply rotary embeddings
|
||||||
out = input_tensor * cos_freqs + t_rot * sin_freqs
|
out = input_tensor * cos_freqs + t_rot * sin_freqs
|
||||||
|
|
||||||
return out
|
return out.astype(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
def rotate_half_interleaved(x: mx.array) -> mx.array:
|
def rotate_half_interleaved(x: mx.array) -> mx.array:
|
||||||
|
|||||||
@@ -1,216 +1,202 @@
|
|||||||
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
|
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from mlx_video.utils import rms_norm
|
from mlx_video.utils import rms_norm, apply_quantization
|
||||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||||
|
|
||||||
@dataclass
|
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||||
class Gemma3Config:
|
from mlx_vlm.models.gemma3.config import TextConfig
|
||||||
"""Configuration for Gemma 3 text model."""
|
|
||||||
hidden_size: int = 3840
|
|
||||||
num_attention_heads: int = 16
|
|
||||||
num_key_value_heads: int = 8
|
|
||||||
head_dim: int = 256
|
|
||||||
intermediate_size: int = 15360
|
|
||||||
num_hidden_layers: int = 48
|
|
||||||
rms_norm_eps: float = 1e-6
|
|
||||||
rope_theta: float = 1000000.0
|
|
||||||
vocab_size: int = 262208
|
|
||||||
max_position_embeddings: int = 131072
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
# Path to system prompts
|
||||||
"""RMS Normalization (Gemma style with 1+weight scaling)."""
|
PROMPTS_DIR = Path(__file__).parent / "prompts"
|
||||||
|
|
||||||
def __init__(self, dims: int, eps: float = 1e-6):
|
|
||||||
super().__init__()
|
def _load_system_prompt(prompt_name: str) -> str:
|
||||||
self.eps = eps
|
"""Load a system prompt from the prompts directory."""
|
||||||
# Gemma initializes to ones, but uses (1+weight) scaling
|
prompt_path = PROMPTS_DIR / prompt_name
|
||||||
# After loading weights, weight will have the actual learned values
|
if prompt_path.exists():
|
||||||
self.weight = mx.ones((dims,))
|
with open(prompt_path, "r") as f:
|
||||||
|
return f.read()
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
raise FileNotFoundError(f"System prompt not found: {prompt_path}")
|
||||||
# Gemma-style RMSNorm uses (1 + weight) as the scale factor
|
|
||||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
|
||||||
|
class LanguageModel(nn.Module):
|
||||||
|
|
||||||
def apply_rotary_emb(
|
|
||||||
q: mx.array,
|
def __init__(self, config: TextConfig):
|
||||||
k: mx.array,
|
|
||||||
positions: mx.array,
|
|
||||||
head_dim: int,
|
|
||||||
rope_theta: float = 1000000.0,
|
|
||||||
) -> Tuple[mx.array, mx.array]:
|
|
||||||
"""Apply rotary position embeddings to Q and K."""
|
|
||||||
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
|
|
||||||
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
|
|
||||||
cos = mx.cos(freqs)
|
|
||||||
sin = mx.sin(freqs)
|
|
||||||
cos = cos[:, :, None, :]
|
|
||||||
sin = sin[:, :, None, :]
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return mx.concatenate([-x2, x1], axis=-1)
|
|
||||||
|
|
||||||
cos_full = mx.concatenate([cos, cos], axis=-1)
|
|
||||||
sin_full = mx.concatenate([sin, sin], axis=-1)
|
|
||||||
q_embed = q * cos_full + rotate_half(q) * sin_full
|
|
||||||
k_embed = k * cos_full + rotate_half(k) * sin_full
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3MLP(nn.Module):
|
|
||||||
"""Gemma 3 MLP with gated activation."""
|
|
||||||
|
|
||||||
def __init__(self, config: Gemma3Config):
|
|
||||||
super().__init__()
|
|
||||||
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
|
||||||
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
|
||||||
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
|
||||||
gate = nn.gelu_approx(self.gate_proj(x))
|
|
||||||
up = self.up_proj(x)
|
|
||||||
return self.down_proj(gate * up)
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3Attention(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, config: Gemma3Config):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
# Create config matching LTX-2 text encoder requirements
|
||||||
self.config = config
|
self.config = config
|
||||||
self.num_heads = config.num_attention_heads
|
|
||||||
self.num_kv_heads = config.num_key_value_heads
|
|
||||||
self.head_dim = config.head_dim
|
|
||||||
self.scale = 1.0 / math.sqrt(config.head_dim)
|
|
||||||
|
|
||||||
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
# Create the Gemma3Model from mlx-vlm
|
||||||
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
self.model = Gemma3Model(self.config)
|
||||||
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
|
||||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
|
||||||
|
|
||||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
def _create_causal_mask_with_padding(
|
||||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
self,
|
||||||
hidden_states: mx.array,
|
seq_len: int,
|
||||||
positions: mx.array,
|
attention_mask: Optional[mx.array],
|
||||||
attention_mask: Optional[mx.array] = None,
|
dtype: mx.Dtype,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
batch_size, seq_len, _ = hidden_states.shape
|
|
||||||
|
|
||||||
q = self.q_proj(hidden_states)
|
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
|
||||||
k = self.k_proj(hidden_states)
|
|
||||||
v = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
|
||||||
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
|
||||||
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
|
||||||
|
|
||||||
q = self.q_norm(q)
|
|
||||||
k = self.k_norm(k)
|
|
||||||
|
|
||||||
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
|
|
||||||
|
|
||||||
q = mx.transpose(q, (0, 2, 1, 3))
|
|
||||||
k = mx.transpose(k, (0, 2, 1, 3))
|
|
||||||
v = mx.transpose(v, (0, 2, 1, 3))
|
|
||||||
|
|
||||||
# Create causal mask (lower triangular)
|
|
||||||
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
|
|
||||||
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9
|
batch_size = attention_mask.shape[0]
|
||||||
|
|
||||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask)
|
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
|
||||||
out = mx.transpose(out, (0, 2, 1, 3))
|
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
|
||||||
out = mx.reshape(out, (batch_size, seq_len, -1))
|
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||||
|
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
|
||||||
return self.o_proj(out)
|
mx.full(combined.shape, min_val, dtype=dtype))
|
||||||
|
return mask[:, None, :, :]
|
||||||
|
else:
|
||||||
class Gemma3DecoderLayer(nn.Module):
|
# No padding mask, just causal
|
||||||
|
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||||
def __init__(self, config: Gemma3Config):
|
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||||
super().__init__()
|
mx.full((seq_len, seq_len), min_val, dtype=dtype))
|
||||||
self.self_attn = Gemma3Attention(config)
|
return mask[None, None, :, :] # (1, 1, seq, seq)
|
||||||
self.mlp = Gemma3MLP(config)
|
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
hidden_states: mx.array,
|
inputs: mx.array,
|
||||||
positions: mx.array,
|
input_embeddings: Optional[mx.array] = None,
|
||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
output_hidden_states: bool = False,
|
||||||
residual = hidden_states
|
cache: Optional[List[mx.array]] = None,
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3TextModel(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, config: Gemma3Config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
|
||||||
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
# Gemma scales embeddings by sqrt(hidden_size)
|
|
||||||
self.embed_scale = config.hidden_size ** 0.5
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
input_ids: mx.array,
|
|
||||||
attention_mask: Optional[mx.array] = None,
|
|
||||||
output_hidden_states: bool = True,
|
|
||||||
) -> Tuple[mx.array, List[mx.array]]:
|
) -> Tuple[mx.array, List[mx.array]]:
|
||||||
|
"""Forward pass returning hidden states.
|
||||||
|
|
||||||
batch_size, seq_len = input_ids.shape
|
Args:
|
||||||
|
input_ids: Input token IDs of shape (batch, seq_len)
|
||||||
|
attention_mask: Optional attention mask (1 for valid, 0 for padding)
|
||||||
|
output_hidden_states: Whether to return all hidden states
|
||||||
|
|
||||||
# Gemma scales embeddings by sqrt(hidden_size)
|
Returns:
|
||||||
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
Tuple of (final_hidden_states, list_of_all_hidden_states)
|
||||||
|
"""
|
||||||
|
batch_size, seq_len = inputs.shape
|
||||||
|
|
||||||
all_hidden_states = [hidden_states] if output_hidden_states else []
|
# Get embeddings
|
||||||
|
h = input_embeddings if input_embeddings is not None else self.model.embed_tokens(inputs)
|
||||||
|
|
||||||
positions = mx.arange(seq_len)[None, :].astype(mx.int32)
|
# Apply Gemma scaling
|
||||||
positions = mx.broadcast_to(positions, (batch_size, seq_len))
|
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||||
|
mx.eval(h)
|
||||||
|
|
||||||
for layer in self.layers:
|
all_hidden_states = [h] if output_hidden_states else []
|
||||||
hidden_states = layer(hidden_states, positions, attention_mask)
|
|
||||||
|
# Set up cache (all None for non-cached inference)
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.model.layers)
|
||||||
|
|
||||||
|
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
||||||
|
|
||||||
|
sliding_mask = full_causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
num_layers = len(self.model.layers)
|
||||||
|
for i, layer in enumerate(self.model.layers):
|
||||||
|
is_global = (
|
||||||
|
i % self.config.sliding_window_pattern
|
||||||
|
== self.config.sliding_window_pattern - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select appropriate mask for this layer
|
||||||
|
if is_global:
|
||||||
|
local_mask = full_causal_mask
|
||||||
|
else:
|
||||||
|
local_mask = sliding_mask
|
||||||
|
|
||||||
|
h = layer(h, local_mask, cache[i])
|
||||||
|
mx.eval(h)
|
||||||
|
|
||||||
|
if output_hidden_states and i < num_layers - 1:
|
||||||
|
all_hidden_states.append(h)
|
||||||
|
|
||||||
|
# Apply final norm
|
||||||
|
hidden_states = self.model.norm(h)
|
||||||
|
mx.eval(hidden_states)
|
||||||
|
|
||||||
|
# Append the final normalized output as the last hidden state
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states.append(hidden_states)
|
all_hidden_states.append(hidden_states)
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
|
|
||||||
return hidden_states, all_hidden_states
|
return hidden_states, all_hidden_states
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Return logits
|
||||||
|
return self.model.embed_tokens.as_linear(hidden_states)
|
||||||
|
|
||||||
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
prefix = "language_model."
|
||||||
|
sanitized = {}
|
||||||
|
for key, value in weights.items():
|
||||||
|
if key.startswith(prefix):
|
||||||
|
if hasattr(value, "dtype") and value.dtype == mx.float32:
|
||||||
|
sanitized[key[len(prefix):]] = value.astype(mx.bfloat16)
|
||||||
|
else:
|
||||||
|
sanitized[key[len(prefix):]] = value
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self) -> List[nn.Module]:
|
||||||
|
return self.model.layers
|
||||||
|
|
||||||
|
def make_cache(self):
|
||||||
|
from mlx_vlm.models.cache import KVCache, RotatingKVCache
|
||||||
|
caches = []
|
||||||
|
for i in range(len(self.layers)):
|
||||||
|
if (
|
||||||
|
i % self.config.sliding_window_pattern
|
||||||
|
== self.config.sliding_window_pattern - 1
|
||||||
|
):
|
||||||
|
caches.append(KVCache())
|
||||||
|
else:
|
||||||
|
caches.append(RotatingKVCache(max_size=self.config.sliding_window))
|
||||||
|
return caches
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, model_path: str):
|
||||||
|
import json
|
||||||
|
weight_files = sorted(Path(model_path).glob("*.safetensors"))
|
||||||
|
config_file = Path(model_path) / "config.json"
|
||||||
|
config_dict = {}
|
||||||
|
if config_file.exists():
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
|
||||||
|
language_model = cls(config=TextConfig.from_dict(config_dict["text_config"]))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Config file not found at {model_path}")
|
||||||
|
|
||||||
|
quantization = config_dict.get("quantization", None)
|
||||||
|
weights = {}
|
||||||
|
for i, wf in enumerate(weight_files):
|
||||||
|
weights.update(mx.load(str(wf)))
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(language_model, "sanitize"):
|
||||||
|
weights = language_model.sanitize(weights=weights)
|
||||||
|
|
||||||
|
|
||||||
|
apply_quantization(model=language_model, weights=weights, quantization=quantization)
|
||||||
|
|
||||||
|
language_model.load_weights(list(weights.items()), strict=False)
|
||||||
|
|
||||||
|
return language_model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectorAttention(nn.Module):
|
class ConnectorAttention(nn.Module):
|
||||||
@@ -278,7 +264,7 @@ class GEGLU(nn.Module):
|
|||||||
self.proj = nn.Linear(in_dim, out_dim, bias=True)
|
self.proj = nn.Linear(in_dim, out_dim, bias=True)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
return nn.gelu_approx(self.proj(x))
|
return nn.gelu(self.proj(x))
|
||||||
|
|
||||||
|
|
||||||
class ConnectorFeedForward(nn.Module):
|
class ConnectorFeedForward(nn.Module):
|
||||||
@@ -359,28 +345,28 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||||
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
||||||
|
|
||||||
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
|
||||||
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
||||||
"""
|
"""
|
||||||
import math
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
||||||
theta = self.positional_embedding_theta
|
theta = self.positional_embedding_theta
|
||||||
max_pos = [1] # Default for connector
|
max_pos = [1] # Default for connector
|
||||||
n_elem = 2 * len(max_pos) # = 2
|
n_elem = 2 * len(max_pos) # = 2
|
||||||
|
|
||||||
# Generate frequency indices (matches generate_freq_grid_pytorch)
|
|
||||||
start = 1.0
|
start = 1.0
|
||||||
end = theta
|
end = theta
|
||||||
num_indices = dim // n_elem # 1920
|
num_indices = dim // n_elem # 1920
|
||||||
|
|
||||||
log_start = math.log(start) / math.log(theta) # = 0
|
# Use numpy float64 for precision
|
||||||
log_end = math.log(end) / math.log(theta) # = 1
|
log_start = np.log(start) / np.log(theta) # = 0
|
||||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
log_end = np.log(end) / np.log(theta) # = 1
|
||||||
indices = (theta ** lin_space) * (math.pi / 2)
|
lin_space = np.linspace(log_start, log_end, num_indices, dtype=np.float64)
|
||||||
|
indices = (np.power(theta, lin_space) * (np.pi / 2)).astype(np.float32)
|
||||||
|
|
||||||
# Generate positions and compute freqs (matches generate_freqs)
|
# Generate positions and compute freqs (matches generate_freqs)
|
||||||
positions = mx.arange(seq_len).astype(mx.float32)
|
positions = np.arange(seq_len, dtype=np.float64)
|
||||||
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
||||||
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
||||||
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
||||||
@@ -390,17 +376,17 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
freqs = scaled_positions[:, None] * indices[None, :]
|
freqs = scaled_positions[:, None] * indices[None, :]
|
||||||
|
|
||||||
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
||||||
cos_freq = mx.cos(freqs)
|
cos_freq = np.cos(freqs)
|
||||||
sin_freq = mx.sin(freqs)
|
sin_freq = np.sin(freqs)
|
||||||
|
|
||||||
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
||||||
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
||||||
cos_full = mx.repeat(cos_freq, 2, axis=-1)
|
cos_full = np.repeat(cos_freq, 2, axis=-1)
|
||||||
sin_full = mx.repeat(sin_freq, 2, axis=-1)
|
sin_full = np.repeat(sin_freq, 2, axis=-1)
|
||||||
|
|
||||||
# Add batch dimension: (1, seq_len, dim)
|
# Add batch dimension and convert to MLX: (1, seq_len, dim)
|
||||||
cos_full = cos_full[None, :, :]
|
cos_full = mx.array(cos_full[None, :, :].astype(np.float32))
|
||||||
sin_full = sin_full[None, :, :]
|
sin_full = mx.array(sin_full[None, :, :].astype(np.float32))
|
||||||
|
|
||||||
return cos_full.astype(dtype), sin_full.astype(dtype)
|
return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||||
|
|
||||||
@@ -547,45 +533,19 @@ class GemmaFeaturesExtractor(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_gemma3_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|
||||||
sanitized = {}
|
|
||||||
|
|
||||||
for key, value in weights.items():
|
|
||||||
new_key = None
|
|
||||||
|
|
||||||
if key.startswith("base_text_encoder.language_model."):
|
|
||||||
new_key = key.replace("base_text_encoder.language_model.", "")
|
|
||||||
elif key.startswith("language_model.model."):
|
|
||||||
new_key = key.replace("language_model.model.", "")
|
|
||||||
elif key.startswith("language_model."):
|
|
||||||
new_key = key.replace("language_model.", "")
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if new_key is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
sanitized[new_key] = value
|
|
||||||
|
|
||||||
return sanitized
|
|
||||||
|
|
||||||
|
|
||||||
class LTX2TextEncoder(nn.Module):
|
class LTX2TextEncoder(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: str = "Lightricks/LTX-2",
|
|
||||||
hidden_dim: int = 3840,
|
hidden_dim: int = 3840,
|
||||||
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
num_layers: int = 49, # 48 transformer layers + 1 embedding
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._model_path = model_path
|
|
||||||
self.hidden_dim = hidden_dim
|
self.hidden_dim = hidden_dim
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
self.language_model = None
|
||||||
# Gemma 3 model
|
|
||||||
self.config = Gemma3Config()
|
|
||||||
self.model = Gemma3TextModel(self.config)
|
|
||||||
|
|
||||||
# Feature extractor: 3840*49 -> 3840
|
# Feature extractor: 3840*49 -> 3840
|
||||||
self.feature_extractor = GemmaFeaturesExtractor(
|
self.feature_extractor = GemmaFeaturesExtractor(
|
||||||
@@ -604,37 +564,16 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
|
|
||||||
self.processor = None
|
self.processor = None
|
||||||
|
|
||||||
def load(self, model_path: Optional[str] = None):
|
def load(self, model_path: Optional[str] = None, text_encoder_path: Optional[str] = "google/gemma-3-12b-it"):
|
||||||
path = model_path or self._model_path
|
|
||||||
|
|
||||||
# Load Gemma weights from text_encoder subdirectory
|
if Path(text_encoder_path / "text_encoder").is_dir():
|
||||||
if Path(path).is_dir():
|
text_encoder_path = str(text_encoder_path / "text_encoder")
|
||||||
text_encoder_path = Path(path) / "text_encoder"
|
|
||||||
if text_encoder_path.exists():
|
|
||||||
gemma_path = str(text_encoder_path)
|
|
||||||
else:
|
|
||||||
gemma_path = path
|
|
||||||
else:
|
|
||||||
gemma_path = path
|
|
||||||
|
|
||||||
print(f"Loading Gemma 3 text encoder from {gemma_path}...")
|
self.language_model = LanguageModel.from_pretrained(text_encoder_path)
|
||||||
weight_files = sorted(Path(gemma_path).glob("*.safetensors"))
|
|
||||||
all_weights = {}
|
|
||||||
for i, wf in enumerate(weight_files):
|
|
||||||
print(f" Loading weight file {i+1}/{len(weight_files)}...")
|
|
||||||
weights = mx.load(str(wf))
|
|
||||||
all_weights.update(weights)
|
|
||||||
|
|
||||||
# Sanitize and load Gemma weights
|
|
||||||
sanitized = sanitize_gemma3_weights(all_weights)
|
|
||||||
print(f" Sanitized Gemma weights: {len(sanitized)}")
|
|
||||||
self.model.load_weights(list(sanitized.items()), strict=False)
|
|
||||||
|
|
||||||
# Load transformer weights for feature extractor and connector
|
# Load transformer weights for feature extractor and connector
|
||||||
transformer_path = Path(model_path or self._model_path)
|
transformer_files = list(model_path.glob("ltx-2-19*.safetensors"))
|
||||||
transformer_files = list(transformer_path.glob("ltx-2*.safetensors"))
|
|
||||||
if transformer_files:
|
if transformer_files:
|
||||||
print(f"Loading transformer weights for text pipeline...")
|
|
||||||
transformer_weights = mx.load(str(transformer_files[0]))
|
transformer_weights = mx.load(str(transformer_files[0]))
|
||||||
|
|
||||||
# Load feature extractor (aggregate_embed)
|
# Load feature extractor (aggregate_embed)
|
||||||
@@ -642,7 +581,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
self.feature_extractor.aggregate_embed.weight = transformer_weights[
|
||||||
"text_embedding_projection.aggregate_embed.weight"
|
"text_embedding_projection.aggregate_embed.weight"
|
||||||
]
|
]
|
||||||
print(" Loaded aggregate_embed weights")
|
|
||||||
|
|
||||||
# Load video_embeddings_connector weights
|
# Load video_embeddings_connector weights
|
||||||
connector_weights = {}
|
connector_weights = {}
|
||||||
@@ -663,20 +602,18 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
self.video_embeddings_connector.load_weights(
|
self.video_embeddings_connector.load_weights(
|
||||||
list(mapped_weights.items()), strict=False
|
list(mapped_weights.items()), strict=False
|
||||||
)
|
)
|
||||||
print(f" Loaded {len(connector_weights)} connector weights")
|
|
||||||
|
|
||||||
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
# Manually load learnable_registers (it's a plain mx.array, not a parameter)
|
||||||
if "learnable_registers" in connector_weights:
|
if "learnable_registers" in connector_weights:
|
||||||
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
self.video_embeddings_connector.learnable_registers = connector_weights["learnable_registers"]
|
||||||
print(f" Loaded learnable_registers: {connector_weights['learnable_registers'].shape}")
|
|
||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
tokenizer_path = Path(model_path or self._model_path) / "tokenizer"
|
tokenizer_path = model_path / "tokenizer"
|
||||||
if tokenizer_path.exists():
|
if tokenizer_path.exists():
|
||||||
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
self.processor = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
|
||||||
else:
|
else:
|
||||||
self.processor = AutoTokenizer.from_pretrained(gemma_path, trust_remote_code=True)
|
self.processor = AutoTokenizer.from_pretrained(text_encoder_path, trust_remote_code=True)
|
||||||
# Set left padding to match official LTX-2 text encoder
|
# Set left padding to match official LTX-2 text encoder
|
||||||
self.processor.padding_side = "left"
|
self.processor.padding_side = "left"
|
||||||
|
|
||||||
@@ -691,7 +628,6 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
if self.processor is None:
|
if self.processor is None:
|
||||||
raise RuntimeError("Model not loaded. Call load() first.")
|
raise RuntimeError("Model not loaded. Call load() first.")
|
||||||
|
|
||||||
# Tokenize with left padding (as in PyTorch version)
|
|
||||||
inputs = self.processor(
|
inputs = self.processor(
|
||||||
prompt,
|
prompt,
|
||||||
return_tensors="np",
|
return_tensors="np",
|
||||||
@@ -702,28 +638,19 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
input_ids = mx.array(inputs["input_ids"])
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
attention_mask = mx.array(inputs["attention_mask"])
|
attention_mask = mx.array(inputs["attention_mask"])
|
||||||
|
|
||||||
# Get all hidden states from Gemma
|
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||||
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
|
|
||||||
|
|
||||||
# Normalize and concatenate all hidden states
|
|
||||||
concat_hidden = norm_and_concat_hidden_states(
|
concat_hidden = norm_and_concat_hidden_states(
|
||||||
all_hidden_states, attention_mask, padding_side="left"
|
all_hidden_states, attention_mask, padding_side="left"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Project through feature extractor
|
|
||||||
features = self.feature_extractor(concat_hidden)
|
features = self.feature_extractor(concat_hidden)
|
||||||
|
|
||||||
# Convert attention mask to additive format for connector
|
|
||||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||||
|
|
||||||
# Process through connector
|
|
||||||
# Note: connector replaces padding with learnable registers and resets mask to zeros
|
|
||||||
# This means all positions now have valid embeddings (no need for final masking)
|
|
||||||
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||||
|
|
||||||
# Return embeddings without zeroing - the connector's register replacement
|
|
||||||
# means all positions have meaningful values now
|
|
||||||
return embeddings, attention_mask
|
return embeddings, attention_mask
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
@@ -733,6 +660,177 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
return self.encode(prompt, max_length)
|
return self.encode(prompt, max_length)
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def default_t2v_system_prompt(self) -> str:
|
||||||
|
"""Load the default T2V system prompt."""
|
||||||
|
return _load_system_prompt("gemma_t2v_system_prompt.txt")
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def default_i2v_system_prompt(self) -> str:
|
||||||
|
"""Load the default I2V system prompt."""
|
||||||
|
return _load_system_prompt("gemma_i2v_system_prompt.txt")
|
||||||
|
|
||||||
|
def _clean_response(self, response: str) -> str:
|
||||||
|
"""Clean up the generated response."""
|
||||||
|
# Remove leading/trailing whitespace
|
||||||
|
response = response.strip()
|
||||||
|
# Remove any leading punctuation
|
||||||
|
response = re.sub(r'^[^\w\s]+', '', response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _apply_chat_template(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
) -> str:
|
||||||
|
"""Apply Gemma chat template to messages."""
|
||||||
|
# Gemma 3 chat template format
|
||||||
|
formatted = ""
|
||||||
|
for msg in messages:
|
||||||
|
role = msg["role"]
|
||||||
|
content = msg["content"]
|
||||||
|
if role == "system":
|
||||||
|
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
|
||||||
|
elif role == "user":
|
||||||
|
if isinstance(content, str):
|
||||||
|
formatted += f"<start_of_turn>user\n{content}<end_of_turn>\n"
|
||||||
|
elif isinstance(content, list):
|
||||||
|
# Handle multimodal content (image + text)
|
||||||
|
text_parts = [c["text"] for c in content if c.get("type") == "text"]
|
||||||
|
formatted += f"<start_of_turn>user\n{' '.join(text_parts)}<end_of_turn>\n"
|
||||||
|
elif role == "assistant":
|
||||||
|
formatted += f"<start_of_turn>model\n{content}<end_of_turn>\n"
|
||||||
|
# Add generation prompt
|
||||||
|
formatted += "<start_of_turn>model\n"
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
def enhance_t2v(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_tokens: int = 512,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
seed: int = 42,
|
||||||
|
verbose: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
"""Enhance a text prompt for T2V generation using mlx-lm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The original user prompt
|
||||||
|
max_new_tokens: Maximum number of tokens to generate
|
||||||
|
system_prompt: Optional custom system prompt
|
||||||
|
seed: Random seed for generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enhanced prompt string
|
||||||
|
"""
|
||||||
|
from tqdm import tqdm
|
||||||
|
try:
|
||||||
|
from mlx_lm import stream_generate
|
||||||
|
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||||
|
except ImportError:
|
||||||
|
logging.warning("mlx-lm not available for prompt enhancement. Using original prompt.")
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
if self.processor is None:
|
||||||
|
raise RuntimeError("Model not loaded. Call load() first.")
|
||||||
|
|
||||||
|
system_prompt = system_prompt or self.default_t2v_system_prompt
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": f"user prompt: {prompt}"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply chat template
|
||||||
|
formatted = self._apply_chat_template(messages)
|
||||||
|
|
||||||
|
# Use mlx-lm generate with temperature sampling
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
|
# Tokenize
|
||||||
|
inputs = self.processor(
|
||||||
|
formatted,
|
||||||
|
return_tensors="np",
|
||||||
|
add_special_tokens=False,
|
||||||
|
)
|
||||||
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
|
|
||||||
|
sampler = make_sampler(kwargs.get("temperature", 0.7), kwargs.get("top_p", 0.95), top_k=kwargs.get("top_k", -1))
|
||||||
|
logits_processors = make_logits_processors(
|
||||||
|
kwargs.get("logit_bias", None),
|
||||||
|
kwargs.get("repetition_penalty", 1.3),
|
||||||
|
kwargs.get("repetition_context_size", 20),
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_token_count = 0
|
||||||
|
generated_tokens = []
|
||||||
|
for i, response in enumerate(
|
||||||
|
tqdm(
|
||||||
|
stream_generate(
|
||||||
|
self.language_model,
|
||||||
|
tokenizer=self.processor,
|
||||||
|
prompt=input_ids.squeeze(0),
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
sampler=sampler,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
),
|
||||||
|
total=max_tokens,
|
||||||
|
disable=not verbose,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
next_token = mx.array([response.token])
|
||||||
|
input_ids = mx.concatenate([input_ids, next_token[None, :]], axis=1)
|
||||||
|
generated_tokens.append(next_token.squeeze())
|
||||||
|
generated_token_count += 1
|
||||||
|
|
||||||
|
if i % 50 == 0:
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
|
# Check for EOS
|
||||||
|
if response.token == 1 or response.token == 107: # EOS tokens
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Decode only the new tokens
|
||||||
|
|
||||||
|
enhanced_prompt = self.processor.decode(generated_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
enhanced_prompt = self._clean_response(enhanced_prompt)
|
||||||
|
logging.info(f"Enhanced prompt: {enhanced_prompt}")
|
||||||
|
|
||||||
|
return enhanced_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def enhance_i2v(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image: Optional[mx.array] = None,
|
||||||
|
max_new_tokens: int = 512,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
seed: int = 42,
|
||||||
|
) -> str:
|
||||||
|
"""Enhance a text prompt for I2V generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The original user prompt
|
||||||
|
image: Optional image tensor (not currently used)
|
||||||
|
max_new_tokens: Maximum number of tokens to generate
|
||||||
|
system_prompt: Optional custom system prompt
|
||||||
|
seed: Random seed for generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enhanced prompt string
|
||||||
|
"""
|
||||||
|
# Use T2V enhancement with I2V system prompt
|
||||||
|
return self.enhance_t2v(
|
||||||
|
prompt,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
system_prompt=system_prompt or self.default_i2v_system_prompt,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
def load_text_encoder(model_path: str = "/tmp/ltx2") -> LTX2TextEncoder:
|
||||||
encoder = LTX2TextEncoder(model_path=model_path)
|
encoder = LTX2TextEncoder(model_path=model_path)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple, Union
|
from typing import Tuple, Union
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
@@ -85,6 +85,10 @@ class GroupNorm3d(nn.Module):
|
|||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
# x: (N, D, H, W, C)
|
# x: (N, D, H, W, C)
|
||||||
n, d, h, w, c = x.shape
|
n, d, h, w, c = x.shape
|
||||||
|
input_dtype = x.dtype
|
||||||
|
|
||||||
|
|
||||||
|
x = x.astype(mx.float32)
|
||||||
|
|
||||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||||
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
|
x = mx.reshape(x, (n, d * h * w, self.num_groups, c // self.num_groups))
|
||||||
@@ -100,7 +104,12 @@ class GroupNorm3d(nn.Module):
|
|||||||
x = mx.reshape(x, (n, d, h, w, c))
|
x = mx.reshape(x, (n, d, h, w, c))
|
||||||
|
|
||||||
# Apply weight and bias
|
# Apply weight and bias
|
||||||
x = x * self.weight + self.bias
|
weight = self.weight.astype(mx.float32)
|
||||||
|
bias = self.bias.astype(mx.float32)
|
||||||
|
x = x * weight + bias
|
||||||
|
|
||||||
|
# Convert back to input dtype
|
||||||
|
x = x.astype(input_dtype)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
@@ -201,14 +201,15 @@ class ResBlockGroup(nn.Module):
|
|||||||
embedding_dim=channels * 4
|
embedding_dim=channels * 4
|
||||||
)
|
)
|
||||||
|
|
||||||
self.res_blocks = [
|
# Use dict with int keys for MLX to track parameters properly
|
||||||
ResnetBlock3DSimple(
|
self.res_blocks = {
|
||||||
|
i: ResnetBlock3DSimple(
|
||||||
channels,
|
channels,
|
||||||
spatial_padding_mode,
|
spatial_padding_mode,
|
||||||
timestep_conditioning=timestep_conditioning
|
timestep_conditioning=timestep_conditioning
|
||||||
)
|
)
|
||||||
for _ in range(num_layers)
|
for i in range(num_layers)
|
||||||
]
|
}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -227,7 +228,7 @@ class ResBlockGroup(nn.Module):
|
|||||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||||
|
|
||||||
for res_block in self.res_blocks:
|
for res_block in self.res_blocks.values():
|
||||||
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -287,37 +288,37 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
self.conv_in = ConvInWrapper()
|
self.conv_in = ConvInWrapper()
|
||||||
|
|
||||||
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
||||||
|
# Use dict with int keys for MLX to track parameters properly
|
||||||
self.up_blocks = [
|
self.up_blocks = {
|
||||||
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
DepthToSpaceUpsample(
|
1: DepthToSpaceUpsample(
|
||||||
dims=3,
|
dims=3,
|
||||||
in_channels=1024,
|
in_channels=1024,
|
||||||
stride=(2, 2, 2),
|
stride=(2, 2, 2),
|
||||||
residual=True, # CRITICAL: Must match PyTorch config!
|
residual=True,
|
||||||
out_channels_reduction_factor=2,
|
out_channels_reduction_factor=2,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
),
|
),
|
||||||
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
DepthToSpaceUpsample(
|
3: DepthToSpaceUpsample(
|
||||||
dims=3,
|
dims=3,
|
||||||
in_channels=512,
|
in_channels=512,
|
||||||
stride=(2, 2, 2),
|
stride=(2, 2, 2),
|
||||||
residual=True, # CRITICAL: Must match PyTorch config!
|
residual=True,
|
||||||
out_channels_reduction_factor=2,
|
out_channels_reduction_factor=2,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
),
|
),
|
||||||
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
DepthToSpaceUpsample(
|
5: DepthToSpaceUpsample(
|
||||||
dims=3,
|
dims=3,
|
||||||
in_channels=256,
|
in_channels=256,
|
||||||
stride=(2, 2, 2),
|
stride=(2, 2, 2),
|
||||||
residual=True, # CRITICAL: Must match PyTorch config!
|
residual=True,
|
||||||
out_channels_reduction_factor=2,
|
out_channels_reduction_factor=2,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
),
|
),
|
||||||
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
]
|
}
|
||||||
|
|
||||||
final_out_channels = out_channels * patch_size * patch_size
|
final_out_channels = out_channels * patch_size * patch_size
|
||||||
class ConvOutWrapper(nn.Module):
|
class ConvOutWrapper(nn.Module):
|
||||||
@@ -396,7 +397,7 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
if debug:
|
if debug:
|
||||||
debug_stats("After conv_in", x)
|
debug_stats("After conv_in", x)
|
||||||
|
|
||||||
for i, block in enumerate(self.up_blocks):
|
for i, block in self.up_blocks.items():
|
||||||
if isinstance(block, ResBlockGroup):
|
if isinstance(block, ResBlockGroup):
|
||||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||||
else:
|
else:
|
||||||
@@ -443,10 +444,10 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
|
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import json
|
||||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
from safetensors import safe_open
|
||||||
|
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
|
|
||||||
@@ -461,6 +462,25 @@ def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX
|
|||||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||||
|
|
||||||
print(f"Loading VAE decoder from {weights_path}...")
|
print(f"Loading VAE decoder from {weights_path}...")
|
||||||
|
|
||||||
|
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||||
|
if timestep_conditioning is None:
|
||||||
|
try:
|
||||||
|
with safe_open(str(weights_path), framework="numpy") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
if metadata and "config" in metadata:
|
||||||
|
configs = json.loads(metadata["config"])
|
||||||
|
vae_config = configs.get("vae", {})
|
||||||
|
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||||
|
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||||
|
else:
|
||||||
|
timestep_conditioning = False
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||||
|
timestep_conditioning = False
|
||||||
|
|
||||||
|
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||||
|
|
||||||
weights = mx.load(str(weights_path))
|
weights = mx.load(str(weights_path))
|
||||||
|
|
||||||
# Determine prefix based on weight keys
|
# Determine prefix based on weight keys
|
||||||
|
|||||||
@@ -1,7 +1,4 @@
|
|||||||
"""Utility functions for MLX Video."""
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@@ -22,6 +19,28 @@ def get_model_path(model_repo: str):
|
|||||||
allow_patterns=["*.safetensors", "*.json"],
|
allow_patterns=["*.safetensors", "*.json"],
|
||||||
))
|
))
|
||||||
|
|
||||||
|
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||||
|
if quantization is not None:
|
||||||
|
def get_class_predicate(p, m):
|
||||||
|
# Handle custom per layer quantizations
|
||||||
|
if p in quantization:
|
||||||
|
return quantization[p]
|
||||||
|
if not hasattr(m, "to_quantized"):
|
||||||
|
return False
|
||||||
|
# Skip layers not divisible by 64
|
||||||
|
if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0:
|
||||||
|
return False
|
||||||
|
# Handle legacy models which may not have everything quantized
|
||||||
|
return f"{p}.scales" in weights
|
||||||
|
|
||||||
|
nn.quantize(
|
||||||
|
model,
|
||||||
|
group_size=quantization["group_size"],
|
||||||
|
bits=quantization["bits"],
|
||||||
|
mode=quantization.get("mode", "affine"),
|
||||||
|
class_predicate=get_class_predicate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ dependencies = [
|
|||||||
"transformers[tokenizers]",
|
"transformers[tokenizers]",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"opencv-python>=4.12.0.88",
|
"opencv-python>=4.12.0.88",
|
||||||
"Pillow>=10.3.0"
|
"Pillow>=10.3.0",
|
||||||
|
"mlx-vlm"
|
||||||
]
|
]
|
||||||
license = {text="MIT"}
|
license = {text="MIT"}
|
||||||
authors = [
|
authors = [
|
||||||
|
|||||||
Reference in New Issue
Block a user