Add frame number validation in video generation and update Gemma3 text encoder to use validated mlx-vlm implementation

This commit is contained in:
Prince Canuma
2026-01-13 17:12:11 +01:00
parent 61b003ff2c
commit 01d895bc77
4 changed files with 1546 additions and 197 deletions

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm
# ANSI color codes # ANSI color codes
class Colors: class Colors:
@@ -113,9 +114,10 @@ def denoise(
text_embeddings: mx.array, text_embeddings: mx.array,
transformer: LTXModel, transformer: LTXModel,
sigmas: list, sigmas: list,
verbose: bool = True,
) -> mx.array: ) -> mx.array:
"""Run denoising loop.""" """Run denoising loop."""
for i in range(len(sigmas) - 1): for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1] sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape b, c, f, h, w = latents.shape
@@ -156,6 +158,7 @@ def generate_video(
fps: int = 24, fps: int = 24,
output_path: str = "output.mp4", output_path: str = "output.mp4",
save_frames: bool = False, save_frames: bool = False,
verbose: bool = True,
): ):
"""Generate video from text prompt. """Generate video from text prompt.
@@ -175,6 +178,12 @@ def generate_video(
assert height % 64 == 0, f"Height must be divisible by 64, got {height}" assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}" assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
num_frames = adjusted_num_frames
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
@@ -236,7 +245,7 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions) mx.eval(positions)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
# Upsample latents # Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
@@ -265,7 +274,7 @@ def generate_video(
latents = noise * noise_scale + latents * (1 - noise_scale) latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents) mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
del transformer del transformer
mx.clear_cache() mx.clear_cache()
@@ -295,7 +304,7 @@ def generate_video(
for frame in video_np: for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release() out.release()
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}") print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
except Exception as e: except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}") print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
@@ -308,6 +317,7 @@ def generate_video(
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}") print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np return video_np
@@ -377,6 +387,11 @@ Examples:
default="Lightricks/LTX-2", default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)" help="Model repository to use (default: Lightricks/LTX-2)"
) )
parser.add_argument(
"--verbose",
action="store_true",
help="Verbose output"
)
args = parser.parse_args() args = parser.parse_args()
generate_video( generate_video(
@@ -389,6 +404,7 @@ Examples:
fps=args.fps, fps=args.fps,
output_path=args.output, output_path=args.output,
save_frames=args.save_frames, save_frames=args.save_frames,
verbose=args.verbose,
) )

View File

@@ -1,4 +1,8 @@
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline.""" """Gemma 3 Text Encoder for LTX-2 - Full Pipeline.
Uses mlx-vlm's Gemma3 implementation which has been validated to match PyTorch
with 0.999 correlation.
"""
import math import math
from dataclasses import dataclass from dataclasses import dataclass
@@ -11,179 +15,61 @@ import mlx.nn as nn
from mlx_video.utils import rms_norm from mlx_video.utils import rms_norm
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@dataclass from mlx_vlm.models.gemma3.language import Gemma3Model
class Gemma3Config: from mlx_vlm.models.gemma3.config import TextConfig
"""Configuration for Gemma 3 text model."""
hidden_size: int = 3840
num_attention_heads: int = 16
num_key_value_heads: int = 8
head_dim: int = 256
intermediate_size: int = 15360
num_hidden_layers: int = 48
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
vocab_size: int = 262208
max_position_embeddings: int = 131072
class RMSNorm(nn.Module): class LanguageModel(nn.Module):
"""RMS Normalization (Gemma style with 1+weight scaling)."""
def __init__(self, dims: int, eps: float = 1e-6):
def __init__(self):
super().__init__() super().__init__()
self.eps = eps # Create config matching LTX-2 text encoder requirements
# Gemma initializes to ones, but uses (1+weight) scaling self.config = TextConfig(
# After loading weights, weight will have the actual learned values model_type="gemma3_text",
self.weight = mx.ones((dims,)) hidden_size=3840,
num_hidden_layers=48,
intermediate_size=15360,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=256,
rms_norm_eps=1e-6,
vocab_size=262208,
query_pre_attn_scalar=256,
rope_global_base_freq=1000000.0,
rope_local_base_freq=10000.0,
rope_traditional=False,
sliding_window=1024,
sliding_window_pattern=6,
)
def __call__(self, x: mx.array) -> mx.array: # Create the Gemma3Model from mlx-vlm
# Gemma-style RMSNorm uses (1 + weight) as the scale factor self.model = Gemma3Model(self.config)
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
def _create_causal_mask_with_padding(
def apply_rotary_emb(
q: mx.array,
k: mx.array,
positions: mx.array,
head_dim: int,
rope_theta: float = 1000000.0,
) -> Tuple[mx.array, mx.array]:
"""Apply rotary position embeddings to Q and K."""
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
cos = mx.cos(freqs)
sin = mx.sin(freqs)
cos = cos[:, :, None, :]
sin = sin[:, :, None, :]
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return mx.concatenate([-x2, x1], axis=-1)
cos_full = mx.concatenate([cos, cos], axis=-1)
sin_full = mx.concatenate([sin, sin], axis=-1)
q_embed = q * cos_full + rotate_half(q) * sin_full
k_embed = k * cos_full + rotate_half(k) * sin_full
return q_embed, k_embed
class Gemma3MLP(nn.Module):
"""Gemma 3 MLP with gated activation."""
def __init__(self, config: Gemma3Config):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
gate = nn.gelu_approx(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
class Gemma3Attention(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.scale = 1.0 / math.sqrt(config.head_dim)
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
def __call__(
self, self,
hidden_states: mx.array, seq_len: int,
positions: mx.array, attention_mask: Optional[mx.array],
attention_mask: Optional[mx.array] = None, dtype: mx.Dtype,
) -> mx.array: ) -> mx.array:
batch_size, seq_len, _ = hidden_states.shape
q = self.q_proj(hidden_states) causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# Create causal mask (lower triangular)
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
if attention_mask is not None: if attention_mask is not None:
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9 batch_size = attention_mask.shape[0]
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask) padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
out = mx.transpose(out, (0, 2, 1, 3)) combined = causal_mask[None, :, :] & padding_mask[:, None, :]
out = mx.reshape(out, (batch_size, seq_len, -1)) min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
return self.o_proj(out) mx.full(combined.shape, min_val, dtype=dtype))
return mask[:, None, :, :]
else:
class Gemma3DecoderLayer(nn.Module): # No padding mask, just causal
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
def __init__(self, config: Gemma3Config): mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
super().__init__() mx.full((seq_len, seq_len), min_val, dtype=dtype))
self.self_attn = Gemma3Attention(config) return mask[None, None, :, :] # (1, 1, seq, seq)
self.mlp = Gemma3MLP(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
hidden_states: mx.array,
positions: mx.array,
attention_mask: Optional[mx.array] = None,
) -> mx.array:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Gemma3TextModel(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Gemma scales embeddings by sqrt(hidden_size)
self.embed_scale = config.hidden_size ** 0.5
def __call__( def __call__(
self, self,
@@ -191,26 +77,69 @@ class Gemma3TextModel(nn.Module):
attention_mask: Optional[mx.array] = None, attention_mask: Optional[mx.array] = None,
output_hidden_states: bool = True, output_hidden_states: bool = True,
) -> Tuple[mx.array, List[mx.array]]: ) -> Tuple[mx.array, List[mx.array]]:
"""Forward pass returning hidden states.
Args:
input_ids: Input token IDs of shape (batch, seq_len)
attention_mask: Optional attention mask (1 for valid, 0 for padding)
output_hidden_states: Whether to return all hidden states
Returns:
Tuple of (final_hidden_states, list_of_all_hidden_states)
"""
batch_size, seq_len = input_ids.shape batch_size, seq_len = input_ids.shape
# Gemma scales embeddings by sqrt(hidden_size) # Get embeddings
hidden_states = self.embed_tokens(input_ids) * self.embed_scale h = self.model.embed_tokens(input_ids)
all_hidden_states = [hidden_states] if output_hidden_states else [] # Apply Gemma scaling
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
mx.eval(h)
positions = mx.arange(seq_len)[None, :].astype(mx.int32) all_hidden_states = [h] if output_hidden_states else []
positions = mx.broadcast_to(positions, (batch_size, seq_len))
for layer in self.layers: # Set up cache (all None for non-cached inference)
hidden_states = layer(hidden_states, positions, attention_mask) cache = [None] * len(self.model.layers)
if output_hidden_states:
all_hidden_states.append(hidden_states)
hidden_states = self.norm(hidden_states) full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
sliding_mask = full_causal_mask
num_layers = len(self.model.layers)
for i, layer in enumerate(self.model.layers):
is_global = (
i % self.config.sliding_window_pattern
== self.config.sliding_window_pattern - 1
)
# Select appropriate mask for this layer
if is_global:
local_mask = full_causal_mask
else:
local_mask = sliding_mask
h = layer(h, local_mask, None)
mx.eval(h)
if output_hidden_states and i < num_layers - 1:
all_hidden_states.append(h)
# Apply final norm
hidden_states = self.model.norm(h)
mx.eval(hidden_states)
# Append the final normalized output as the last hidden state
if output_hidden_states:
all_hidden_states.append(hidden_states)
return hidden_states, all_hidden_states return hidden_states, all_hidden_states
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True):
"""Load weights into the model."""
self.model.load_weights(weights, strict=strict)
class ConnectorAttention(nn.Module): class ConnectorAttention(nn.Module):
@@ -582,10 +511,7 @@ class LTX2TextEncoder(nn.Module):
self._model_path = model_path self._model_path = model_path
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.num_layers = num_layers self.num_layers = num_layers
self.model = LanguageModel()
# Gemma 3 model
self.config = Gemma3Config()
self.model = Gemma3TextModel(self.config)
# Feature extractor: 3840*49 -> 3840 # Feature extractor: 3840*49 -> 3840
self.feature_extractor = GemmaFeaturesExtractor( self.feature_extractor = GemmaFeaturesExtractor(
@@ -691,7 +617,6 @@ class LTX2TextEncoder(nn.Module):
if self.processor is None: if self.processor is None:
raise RuntimeError("Model not loaded. Call load() first.") raise RuntimeError("Model not loaded. Call load() first.")
# Tokenize with left padding (as in PyTorch version)
inputs = self.processor( inputs = self.processor(
prompt, prompt,
return_tensors="np", return_tensors="np",
@@ -702,28 +627,19 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"]) input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"]) attention_mask = mx.array(inputs["attention_mask"])
# Get all hidden states from Gemma
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True) _, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
# Normalize and concatenate all hidden states
concat_hidden = norm_and_concat_hidden_states( concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left" all_hidden_states, attention_mask, padding_side="left"
) )
# Project through feature extractor
features = self.feature_extractor(concat_hidden) features = self.feature_extractor(concat_hidden)
# Convert attention mask to additive format for connector
additive_mask = (attention_mask - 1).astype(features.dtype) additive_mask = (attention_mask - 1).astype(features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
# Process through connector
# Note: connector replaces padding with learnable registers and resets mask to zeros
# This means all positions now have valid embeddings (no need for final masking)
embeddings, _ = self.video_embeddings_connector(features, additive_mask) embeddings, _ = self.video_embeddings_connector(features, additive_mask)
# Return embeddings without zeroing - the connector's register replacement
# means all positions have meaningful values now
return embeddings, attention_mask return embeddings, attention_mask
def __call__( def __call__(

View File

@@ -18,7 +18,8 @@ dependencies = [
"transformers[tokenizers]", "transformers[tokenizers]",
"tqdm", "tqdm",
"opencv-python>=4.12.0.88", "opencv-python>=4.12.0.88",
"Pillow>=10.3.0" "Pillow>=10.3.0",
"mlx-vlm"
] ]
license = {text="MIT"} license = {text="MIT"}
authors = [ authors = [

1424
uv.lock generated

File diff suppressed because it is too large Load Diff