Add frame number validation in video generation and update Gemma3 text encoder to use validated mlx-vlm implementation

This commit is contained in:
Prince Canuma
2026-01-13 17:12:11 +01:00
parent 61b003ff2c
commit 01d895bc77
4 changed files with 1546 additions and 197 deletions

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import mlx.core as mx
import numpy as np
from PIL import Image
from tqdm import tqdm
# ANSI color codes
class Colors:
@@ -113,9 +114,10 @@ def denoise(
text_embeddings: mx.array,
transformer: LTXModel,
sigmas: list,
verbose: bool = True,
) -> mx.array:
"""Run denoising loop."""
for i in range(len(sigmas) - 1):
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
@@ -156,6 +158,7 @@ def generate_video(
fps: int = 24,
output_path: str = "output.mp4",
save_frames: bool = False,
verbose: bool = True,
):
"""Generate video from text prompt.
@@ -175,6 +178,12 @@ def generate_video(
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
# Valid frame counts are 1 + 8*k with k >= 0. The explicit `< 1` test is
# required because Python's % takes the sign of the divisor: -7 % 8 == 1, so a
# negative count would otherwise be accepted as already valid.
if num_frames < 1 or num_frames % 8 != 1:
    # Snap to the nearest 8-frame step; clamp k at 0 so the adjusted value is
    # never smaller than a single frame (e.g. -4 used to become -7).
    adjusted_num_frames = max(round((num_frames - 1) / 8), 0) * 8 + 1
    print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
    num_frames = adjusted_num_frames
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
@@ -236,7 +245,7 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
# Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
@@ -265,7 +274,7 @@ def generate_video(
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
del transformer
mx.clear_cache()
@@ -295,7 +304,7 @@ def generate_video(
for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}")
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
@@ -308,6 +317,7 @@ def generate_video(
elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np
@@ -377,6 +387,11 @@ Examples:
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Verbose output"
)
args = parser.parse_args()
generate_video(
@@ -389,6 +404,7 @@ Examples:
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
verbose=args.verbose,
)

View File

@@ -1,4 +1,8 @@
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline.
Uses mlx-vlm's Gemma3 implementation which has been validated to match PyTorch
with 0.999 correlation.
"""
import math
from dataclasses import dataclass
@@ -11,179 +15,61 @@ import mlx.nn as nn
from mlx_video.utils import rms_norm
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@dataclass
class Gemma3Config:
"""Configuration for Gemma 3 text model."""
hidden_size: int = 3840
num_attention_heads: int = 16
num_key_value_heads: int = 8
head_dim: int = 256
intermediate_size: int = 15360
num_hidden_layers: int = 48
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
vocab_size: int = 262208
max_position_embeddings: int = 131072
from mlx_vlm.models.gemma3.language import Gemma3Model
from mlx_vlm.models.gemma3.config import TextConfig
class RMSNorm(nn.Module):
"""RMS Normalization (Gemma style with 1+weight scaling)."""
class LanguageModel(nn.Module):
def __init__(self, dims: int, eps: float = 1e-6):
def __init__(self):
super().__init__()
self.eps = eps
# Gemma initializes to ones, but uses (1+weight) scaling
# After loading weights, weight will have the actual learned values
self.weight = mx.ones((dims,))
# Create config matching LTX-2 text encoder requirements
self.config = TextConfig(
model_type="gemma3_text",
hidden_size=3840,
num_hidden_layers=48,
intermediate_size=15360,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=256,
rms_norm_eps=1e-6,
vocab_size=262208,
query_pre_attn_scalar=256,
rope_global_base_freq=1000000.0,
rope_local_base_freq=10000.0,
rope_traditional=False,
sliding_window=1024,
sliding_window_pattern=6,
)
def __call__(self, x: mx.array) -> mx.array:
# Gemma-style RMSNorm uses (1 + weight) as the scale factor
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
# Create the Gemma3Model from mlx-vlm
self.model = Gemma3Model(self.config)
def apply_rotary_emb(
q: mx.array,
k: mx.array,
positions: mx.array,
head_dim: int,
rope_theta: float = 1000000.0,
) -> Tuple[mx.array, mx.array]:
"""Apply rotary position embeddings to Q and K."""
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
cos = mx.cos(freqs)
sin = mx.sin(freqs)
cos = cos[:, :, None, :]
sin = sin[:, :, None, :]
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return mx.concatenate([-x2, x1], axis=-1)
cos_full = mx.concatenate([cos, cos], axis=-1)
sin_full = mx.concatenate([sin, sin], axis=-1)
q_embed = q * cos_full + rotate_half(q) * sin_full
k_embed = k * cos_full + rotate_half(k) * sin_full
return q_embed, k_embed
class Gemma3MLP(nn.Module):
"""Gemma 3 MLP with gated activation."""
def __init__(self, config: Gemma3Config):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
gate = nn.gelu_approx(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
class Gemma3Attention(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.scale = 1.0 / math.sqrt(config.head_dim)
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
def __call__(
def _create_causal_mask_with_padding(
self,
hidden_states: mx.array,
positions: mx.array,
attention_mask: Optional[mx.array] = None,
seq_len: int,
attention_mask: Optional[mx.array],
dtype: mx.Dtype,
) -> mx.array:
batch_size, seq_len, _ = hidden_states.shape
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
q = mx.transpose(q, (0, 2, 1, 3))
k = mx.transpose(k, (0, 2, 1, 3))
v = mx.transpose(v, (0, 2, 1, 3))
# Create causal mask (lower triangular)
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
if attention_mask is not None:
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9
batch_size = attention_mask.shape[0]
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask)
out = mx.transpose(out, (0, 2, 1, 3))
out = mx.reshape(out, (batch_size, seq_len, -1))
return self.o_proj(out)
class Gemma3DecoderLayer(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.self_attn = Gemma3Attention(config)
self.mlp = Gemma3MLP(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
hidden_states: mx.array,
positions: mx.array,
attention_mask: Optional[mx.array] = None,
) -> mx.array:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Gemma3TextModel(nn.Module):
def __init__(self, config: Gemma3Config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Gemma scales embeddings by sqrt(hidden_size)
self.embed_scale = config.hidden_size ** 0.5
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
mx.full(combined.shape, min_val, dtype=dtype))
return mask[:, None, :, :]
else:
# No padding mask, just causal
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
mx.full((seq_len, seq_len), min_val, dtype=dtype))
return mask[None, None, :, :] # (1, 1, seq, seq)
def __call__(
self,
@@ -191,26 +77,69 @@ class Gemma3TextModel(nn.Module):
attention_mask: Optional[mx.array] = None,
output_hidden_states: bool = True,
) -> Tuple[mx.array, List[mx.array]]:
"""Forward pass returning hidden states.
Args:
input_ids: Input token IDs of shape (batch, seq_len)
attention_mask: Optional attention mask (1 for valid, 0 for padding)
output_hidden_states: Whether to return all hidden states
Returns:
Tuple of (final_hidden_states, list_of_all_hidden_states)
"""
batch_size, seq_len = input_ids.shape
# Gemma scales embeddings by sqrt(hidden_size)
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
# Get embeddings
h = self.model.embed_tokens(input_ids)
all_hidden_states = [hidden_states] if output_hidden_states else []
# Apply Gemma scaling
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
mx.eval(h)
positions = mx.arange(seq_len)[None, :].astype(mx.int32)
positions = mx.broadcast_to(positions, (batch_size, seq_len))
all_hidden_states = [h] if output_hidden_states else []
for layer in self.layers:
hidden_states = layer(hidden_states, positions, attention_mask)
if output_hidden_states:
all_hidden_states.append(hidden_states)
# Set up cache (all None for non-cached inference)
cache = [None] * len(self.model.layers)
hidden_states = self.norm(hidden_states)
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
sliding_mask = full_causal_mask
num_layers = len(self.model.layers)
for i, layer in enumerate(self.model.layers):
is_global = (
i % self.config.sliding_window_pattern
== self.config.sliding_window_pattern - 1
)
# Select appropriate mask for this layer
if is_global:
local_mask = full_causal_mask
else:
local_mask = sliding_mask
h = layer(h, local_mask, None)
mx.eval(h)
if output_hidden_states and i < num_layers - 1:
all_hidden_states.append(h)
# Apply final norm
hidden_states = self.model.norm(h)
mx.eval(hidden_states)
# Append the final normalized output as the last hidden state
if output_hidden_states:
all_hidden_states.append(hidden_states)
return hidden_states, all_hidden_states
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True):
    """Delegate weight loading to the wrapped ``self.model``.

    Args:
        weights: ``(name, array)`` pairs, handed through unchanged.
        strict: Forwarded as the ``strict`` keyword of the wrapped model's
            ``load_weights``.
    """
    inner_model = self.model
    inner_model.load_weights(weights, strict=strict)
class ConnectorAttention(nn.Module):
@@ -582,10 +511,7 @@ class LTX2TextEncoder(nn.Module):
self._model_path = model_path
self.hidden_dim = hidden_dim
self.num_layers = num_layers
# Gemma 3 model
self.config = Gemma3Config()
self.model = Gemma3TextModel(self.config)
self.model = LanguageModel()
# Feature extractor: 3840*49 -> 3840
self.feature_extractor = GemmaFeaturesExtractor(
@@ -691,7 +617,6 @@ class LTX2TextEncoder(nn.Module):
if self.processor is None:
raise RuntimeError("Model not loaded. Call load() first.")
# Tokenize with left padding (as in PyTorch version)
inputs = self.processor(
prompt,
return_tensors="np",
@@ -702,28 +627,19 @@ class LTX2TextEncoder(nn.Module):
input_ids = mx.array(inputs["input_ids"])
attention_mask = mx.array(inputs["attention_mask"])
# Get all hidden states from Gemma
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
# Normalize and concatenate all hidden states
concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left"
)
# Project through feature extractor
features = self.feature_extractor(concat_hidden)
# Convert attention mask to additive format for connector
additive_mask = (attention_mask - 1).astype(features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
# Process through connector
# Note: connector replaces padding with learnable registers and resets mask to zeros
# This means all positions now have valid embeddings (no need for final masking)
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
# Return embeddings without zeroing - the connector's register replacement
# means all positions have meaningful values now
return embeddings, attention_mask
def __call__(

View File

@@ -18,7 +18,8 @@ dependencies = [
"transformers[tokenizers]",
"tqdm",
"opencv-python>=4.12.0.88",
"Pillow>=10.3.0"
"Pillow>=10.3.0",
"mlx-vlm"
]
license = {text="MIT"}
authors = [

1424
uv.lock generated

File diff suppressed because it is too large Load Diff