Add frame number validation in video generation and update Gemma3 text encoder to use validated mlx-vlm implementation
This commit is contained in:
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# ANSI color codes
|
||||
class Colors:
|
||||
@@ -113,9 +114,10 @@ def denoise(
|
||||
text_embeddings: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigmas: list,
|
||||
verbose: bool = True,
|
||||
) -> mx.array:
|
||||
"""Run denoising loop."""
|
||||
for i in range(len(sigmas) - 1):
|
||||
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
|
||||
b, c, f, h, w = latents.shape
|
||||
@@ -156,6 +158,7 @@ def generate_video(
|
||||
fps: int = 24,
|
||||
output_path: str = "output.mp4",
|
||||
save_frames: bool = False,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""Generate video from text prompt.
|
||||
|
||||
@@ -174,6 +177,12 @@ def generate_video(
|
||||
# Validate dimensions
|
||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||
|
||||
if num_frames % 8 != 1:
|
||||
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
|
||||
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
|
||||
num_frames = adjusted_num_frames
|
||||
|
||||
|
||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||
@@ -236,7 +245,7 @@ def generate_video(
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
|
||||
|
||||
# Upsample latents
|
||||
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
||||
@@ -265,7 +274,7 @@ def generate_video(
|
||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||
mx.eval(latents)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
@@ -295,7 +304,7 @@ def generate_video(
|
||||
for frame in video_np:
|
||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
out.release()
|
||||
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}")
|
||||
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||
|
||||
@@ -308,6 +317,7 @@ def generate_video(
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
||||
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
|
||||
|
||||
return video_np
|
||||
|
||||
@@ -377,6 +387,11 @@ Examples:
|
||||
default="Lightricks/LTX-2",
|
||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Verbose output"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video(
|
||||
@@ -389,6 +404,7 @@ Examples:
|
||||
fps=args.fps,
|
||||
output_path=args.output,
|
||||
save_frames=args.save_frames,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
|
||||
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline.
|
||||
|
||||
Uses mlx-vlm's Gemma3 implementation which has been validated to match PyTorch
|
||||
with 0.999 correlation.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
@@ -11,179 +15,61 @@ import mlx.nn as nn
|
||||
from mlx_video.utils import rms_norm
|
||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||
|
||||
@dataclass
|
||||
class Gemma3Config:
|
||||
"""Configuration for Gemma 3 text model."""
|
||||
hidden_size: int = 3840
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
head_dim: int = 256
|
||||
intermediate_size: int = 15360
|
||||
num_hidden_layers: int = 48
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
vocab_size: int = 262208
|
||||
max_position_embeddings: int = 131072
|
||||
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||
from mlx_vlm.models.gemma3.config import TextConfig
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""RMS Normalization (Gemma style with 1+weight scaling)."""
|
||||
class LanguageModel(nn.Module):
|
||||
|
||||
def __init__(self, dims: int, eps: float = 1e-6):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
# Gemma initializes to ones, but uses (1+weight) scaling
|
||||
# After loading weights, weight will have the actual learned values
|
||||
self.weight = mx.ones((dims,))
|
||||
# Create config matching LTX-2 text encoder requirements
|
||||
self.config = TextConfig(
|
||||
model_type="gemma3_text",
|
||||
hidden_size=3840,
|
||||
num_hidden_layers=48,
|
||||
intermediate_size=15360,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=8,
|
||||
head_dim=256,
|
||||
rms_norm_eps=1e-6,
|
||||
vocab_size=262208,
|
||||
query_pre_attn_scalar=256,
|
||||
rope_global_base_freq=1000000.0,
|
||||
rope_local_base_freq=10000.0,
|
||||
rope_traditional=False,
|
||||
sliding_window=1024,
|
||||
sliding_window_pattern=6,
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# Gemma-style RMSNorm uses (1 + weight) as the scale factor
|
||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
||||
# Create the Gemma3Model from mlx-vlm
|
||||
self.model = Gemma3Model(self.config)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
q: mx.array,
|
||||
k: mx.array,
|
||||
positions: mx.array,
|
||||
head_dim: int,
|
||||
rope_theta: float = 1000000.0,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Apply rotary position embeddings to Q and K."""
|
||||
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
|
||||
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
|
||||
cos = mx.cos(freqs)
|
||||
sin = mx.sin(freqs)
|
||||
cos = cos[:, :, None, :]
|
||||
sin = sin[:, :, None, :]
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return mx.concatenate([-x2, x1], axis=-1)
|
||||
|
||||
cos_full = mx.concatenate([cos, cos], axis=-1)
|
||||
sin_full = mx.concatenate([sin, sin], axis=-1)
|
||||
q_embed = q * cos_full + rotate_half(q) * sin_full
|
||||
k_embed = k * cos_full + rotate_half(k) * sin_full
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
||||
|
||||
class Gemma3MLP(nn.Module):
|
||||
"""Gemma 3 MLP with gated activation."""
|
||||
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
gate = nn.gelu_approx(self.gate_proj(x))
|
||||
up = self.up_proj(x)
|
||||
return self.down_proj(gate * up)
|
||||
|
||||
|
||||
class Gemma3Attention(nn.Module):
|
||||
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.scale = 1.0 / math.sqrt(config.head_dim)
|
||||
|
||||
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
def _create_causal_mask_with_padding(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
positions: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
seq_len: int,
|
||||
attention_mask: Optional[mx.array],
|
||||
dtype: mx.Dtype,
|
||||
) -> mx.array:
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
||||
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
|
||||
|
||||
q = mx.transpose(q, (0, 2, 1, 3))
|
||||
k = mx.transpose(k, (0, 2, 1, 3))
|
||||
v = mx.transpose(v, (0, 2, 1, 3))
|
||||
|
||||
# Create causal mask (lower triangular)
|
||||
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
|
||||
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
|
||||
|
||||
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
|
||||
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9
|
||||
batch_size = attention_mask.shape[0]
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask)
|
||||
out = mx.transpose(out, (0, 2, 1, 3))
|
||||
out = mx.reshape(out, (batch_size, seq_len, -1))
|
||||
|
||||
return self.o_proj(out)
|
||||
|
||||
|
||||
class Gemma3DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
self.self_attn = Gemma3Attention(config)
|
||||
self.mlp = Gemma3MLP(config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
positions: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Gemma3TextModel(nn.Module):
|
||||
|
||||
def __init__(self, config: Gemma3Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
# Gemma scales embeddings by sqrt(hidden_size)
|
||||
self.embed_scale = config.hidden_size ** 0.5
|
||||
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
|
||||
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
|
||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
|
||||
mx.full(combined.shape, min_val, dtype=dtype))
|
||||
return mask[:, None, :, :]
|
||||
else:
|
||||
# No padding mask, just causal
|
||||
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||
mx.full((seq_len, seq_len), min_val, dtype=dtype))
|
||||
return mask[None, None, :, :] # (1, 1, seq, seq)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -191,26 +77,69 @@ class Gemma3TextModel(nn.Module):
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
output_hidden_states: bool = True,
|
||||
) -> Tuple[mx.array, List[mx.array]]:
|
||||
|
||||
"""Forward pass returning hidden states.
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs of shape (batch, seq_len)
|
||||
attention_mask: Optional attention mask (1 for valid, 0 for padding)
|
||||
output_hidden_states: Whether to return all hidden states
|
||||
|
||||
Returns:
|
||||
Tuple of (final_hidden_states, list_of_all_hidden_states)
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
# Gemma scales embeddings by sqrt(hidden_size)
|
||||
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
||||
# Get embeddings
|
||||
h = self.model.embed_tokens(input_ids)
|
||||
|
||||
all_hidden_states = [hidden_states] if output_hidden_states else []
|
||||
# Apply Gemma scaling
|
||||
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||
mx.eval(h)
|
||||
|
||||
positions = mx.arange(seq_len)[None, :].astype(mx.int32)
|
||||
positions = mx.broadcast_to(positions, (batch_size, seq_len))
|
||||
all_hidden_states = [h] if output_hidden_states else []
|
||||
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, positions, attention_mask)
|
||||
if output_hidden_states:
|
||||
all_hidden_states.append(hidden_states)
|
||||
# Set up cache (all None for non-cached inference)
|
||||
cache = [None] * len(self.model.layers)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
||||
|
||||
|
||||
sliding_mask = full_causal_mask
|
||||
|
||||
|
||||
num_layers = len(self.model.layers)
|
||||
for i, layer in enumerate(self.model.layers):
|
||||
is_global = (
|
||||
i % self.config.sliding_window_pattern
|
||||
== self.config.sliding_window_pattern - 1
|
||||
)
|
||||
|
||||
# Select appropriate mask for this layer
|
||||
if is_global:
|
||||
local_mask = full_causal_mask
|
||||
else:
|
||||
local_mask = sliding_mask
|
||||
|
||||
h = layer(h, local_mask, None)
|
||||
mx.eval(h)
|
||||
|
||||
if output_hidden_states and i < num_layers - 1:
|
||||
all_hidden_states.append(h)
|
||||
|
||||
# Apply final norm
|
||||
hidden_states = self.model.norm(h)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
# Append the final normalized output as the last hidden state
|
||||
if output_hidden_states:
|
||||
all_hidden_states.append(hidden_states)
|
||||
|
||||
return hidden_states, all_hidden_states
|
||||
|
||||
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True):
|
||||
"""Load weights into the model."""
|
||||
self.model.load_weights(weights, strict=strict)
|
||||
|
||||
|
||||
|
||||
class ConnectorAttention(nn.Module):
|
||||
@@ -582,10 +511,7 @@ class LTX2TextEncoder(nn.Module):
|
||||
self._model_path = model_path
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layers = num_layers
|
||||
|
||||
# Gemma 3 model
|
||||
self.config = Gemma3Config()
|
||||
self.model = Gemma3TextModel(self.config)
|
||||
self.model = LanguageModel()
|
||||
|
||||
# Feature extractor: 3840*49 -> 3840
|
||||
self.feature_extractor = GemmaFeaturesExtractor(
|
||||
@@ -691,7 +617,6 @@ class LTX2TextEncoder(nn.Module):
|
||||
if self.processor is None:
|
||||
raise RuntimeError("Model not loaded. Call load() first.")
|
||||
|
||||
# Tokenize with left padding (as in PyTorch version)
|
||||
inputs = self.processor(
|
||||
prompt,
|
||||
return_tensors="np",
|
||||
@@ -702,28 +627,19 @@ class LTX2TextEncoder(nn.Module):
|
||||
input_ids = mx.array(inputs["input_ids"])
|
||||
attention_mask = mx.array(inputs["attention_mask"])
|
||||
|
||||
# Get all hidden states from Gemma
|
||||
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
|
||||
|
||||
# Normalize and concatenate all hidden states
|
||||
concat_hidden = norm_and_concat_hidden_states(
|
||||
all_hidden_states, attention_mask, padding_side="left"
|
||||
)
|
||||
|
||||
# Project through feature extractor
|
||||
features = self.feature_extractor(concat_hidden)
|
||||
|
||||
# Convert attention mask to additive format for connector
|
||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||
|
||||
# Process through connector
|
||||
# Note: connector replaces padding with learnable registers and resets mask to zeros
|
||||
# This means all positions now have valid embeddings (no need for final masking)
|
||||
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||
|
||||
# Return embeddings without zeroing - the connector's register replacement
|
||||
# means all positions have meaningful values now
|
||||
return embeddings, attention_mask
|
||||
|
||||
def __call__(
|
||||
|
||||
Reference in New Issue
Block a user