Add frame number validation in video generation and update Gemma3 text encoder to use validated mlx-vlm implementation
This commit is contained in:
@@ -5,6 +5,7 @@ from pathlib import Path
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
# ANSI color codes
|
# ANSI color codes
|
||||||
class Colors:
|
class Colors:
|
||||||
@@ -113,9 +114,10 @@ def denoise(
|
|||||||
text_embeddings: mx.array,
|
text_embeddings: mx.array,
|
||||||
transformer: LTXModel,
|
transformer: LTXModel,
|
||||||
sigmas: list,
|
sigmas: list,
|
||||||
|
verbose: bool = True,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Run denoising loop."""
|
"""Run denoising loop."""
|
||||||
for i in range(len(sigmas) - 1):
|
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
|
||||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||||
|
|
||||||
b, c, f, h, w = latents.shape
|
b, c, f, h, w = latents.shape
|
||||||
@@ -156,6 +158,7 @@ def generate_video(
|
|||||||
fps: int = 24,
|
fps: int = 24,
|
||||||
output_path: str = "output.mp4",
|
output_path: str = "output.mp4",
|
||||||
save_frames: bool = False,
|
save_frames: bool = False,
|
||||||
|
verbose: bool = True,
|
||||||
):
|
):
|
||||||
"""Generate video from text prompt.
|
"""Generate video from text prompt.
|
||||||
|
|
||||||
@@ -175,6 +178,12 @@ def generate_video(
|
|||||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||||
|
|
||||||
|
if num_frames % 8 != 1:
|
||||||
|
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
|
||||||
|
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
|
||||||
|
num_frames = adjusted_num_frames
|
||||||
|
|
||||||
|
|
||||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
@@ -236,7 +245,7 @@ def generate_video(
|
|||||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
|
||||||
|
|
||||||
# Upsample latents
|
# Upsample latents
|
||||||
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
||||||
@@ -265,7 +274,7 @@ def generate_video(
|
|||||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
|
||||||
|
|
||||||
del transformer
|
del transformer
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
@@ -295,7 +304,7 @@ def generate_video(
|
|||||||
for frame in video_np:
|
for frame in video_np:
|
||||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
out.release()
|
out.release()
|
||||||
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}")
|
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||||
|
|
||||||
@@ -308,6 +317,7 @@ def generate_video(
|
|||||||
|
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
||||||
|
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
|
||||||
|
|
||||||
return video_np
|
return video_np
|
||||||
|
|
||||||
@@ -377,6 +387,11 @@ Examples:
|
|||||||
default="Lightricks/LTX-2",
|
default="Lightricks/LTX-2",
|
||||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
action="store_true",
|
||||||
|
help="Verbose output"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
generate_video(
|
generate_video(
|
||||||
@@ -389,6 +404,7 @@ Examples:
|
|||||||
fps=args.fps,
|
fps=args.fps,
|
||||||
output_path=args.output,
|
output_path=args.output,
|
||||||
save_frames=args.save_frames,
|
save_frames=args.save_frames,
|
||||||
|
verbose=args.verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline."""
|
"""Gemma 3 Text Encoder for LTX-2 - Full Pipeline.
|
||||||
|
|
||||||
|
Uses mlx-vlm's Gemma3 implementation which has been validated to match PyTorch
|
||||||
|
with 0.999 correlation.
|
||||||
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -11,179 +15,61 @@ import mlx.nn as nn
|
|||||||
from mlx_video.utils import rms_norm
|
from mlx_video.utils import rms_norm
|
||||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||||
|
|
||||||
@dataclass
|
from mlx_vlm.models.gemma3.language import Gemma3Model
|
||||||
class Gemma3Config:
|
from mlx_vlm.models.gemma3.config import TextConfig
|
||||||
"""Configuration for Gemma 3 text model."""
|
|
||||||
hidden_size: int = 3840
|
|
||||||
num_attention_heads: int = 16
|
|
||||||
num_key_value_heads: int = 8
|
|
||||||
head_dim: int = 256
|
|
||||||
intermediate_size: int = 15360
|
|
||||||
num_hidden_layers: int = 48
|
|
||||||
rms_norm_eps: float = 1e-6
|
|
||||||
rope_theta: float = 1000000.0
|
|
||||||
vocab_size: int = 262208
|
|
||||||
max_position_embeddings: int = 131072
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class LanguageModel(nn.Module):
|
||||||
"""RMS Normalization (Gemma style with 1+weight scaling)."""
|
|
||||||
|
|
||||||
def __init__(self, dims: int, eps: float = 1e-6):
|
|
||||||
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.eps = eps
|
# Create config matching LTX-2 text encoder requirements
|
||||||
# Gemma initializes to ones, but uses (1+weight) scaling
|
self.config = TextConfig(
|
||||||
# After loading weights, weight will have the actual learned values
|
model_type="gemma3_text",
|
||||||
self.weight = mx.ones((dims,))
|
hidden_size=3840,
|
||||||
|
num_hidden_layers=48,
|
||||||
|
intermediate_size=15360,
|
||||||
|
num_attention_heads=16,
|
||||||
|
num_key_value_heads=8,
|
||||||
|
head_dim=256,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
vocab_size=262208,
|
||||||
|
query_pre_attn_scalar=256,
|
||||||
|
rope_global_base_freq=1000000.0,
|
||||||
|
rope_local_base_freq=10000.0,
|
||||||
|
rope_traditional=False,
|
||||||
|
sliding_window=1024,
|
||||||
|
sliding_window_pattern=6,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
# Create the Gemma3Model from mlx-vlm
|
||||||
# Gemma-style RMSNorm uses (1 + weight) as the scale factor
|
self.model = Gemma3Model(self.config)
|
||||||
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
|
|
||||||
|
|
||||||
|
def _create_causal_mask_with_padding(
|
||||||
def apply_rotary_emb(
|
|
||||||
q: mx.array,
|
|
||||||
k: mx.array,
|
|
||||||
positions: mx.array,
|
|
||||||
head_dim: int,
|
|
||||||
rope_theta: float = 1000000.0,
|
|
||||||
) -> Tuple[mx.array, mx.array]:
|
|
||||||
"""Apply rotary position embeddings to Q and K."""
|
|
||||||
inv_freq = 1.0 / (rope_theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
|
|
||||||
freqs = positions[:, :, None].astype(mx.float32) * inv_freq[None, None, :]
|
|
||||||
cos = mx.cos(freqs)
|
|
||||||
sin = mx.sin(freqs)
|
|
||||||
cos = cos[:, :, None, :]
|
|
||||||
sin = sin[:, :, None, :]
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return mx.concatenate([-x2, x1], axis=-1)
|
|
||||||
|
|
||||||
cos_full = mx.concatenate([cos, cos], axis=-1)
|
|
||||||
sin_full = mx.concatenate([sin, sin], axis=-1)
|
|
||||||
q_embed = q * cos_full + rotate_half(q) * sin_full
|
|
||||||
k_embed = k * cos_full + rotate_half(k) * sin_full
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3MLP(nn.Module):
|
|
||||||
"""Gemma 3 MLP with gated activation."""
|
|
||||||
|
|
||||||
def __init__(self, config: Gemma3Config):
|
|
||||||
super().__init__()
|
|
||||||
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
|
||||||
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
|
||||||
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
|
||||||
gate = nn.gelu_approx(self.gate_proj(x))
|
|
||||||
up = self.up_proj(x)
|
|
||||||
return self.down_proj(gate * up)
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3Attention(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, config: Gemma3Config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.num_heads = config.num_attention_heads
|
|
||||||
self.num_kv_heads = config.num_key_value_heads
|
|
||||||
self.head_dim = config.head_dim
|
|
||||||
self.scale = 1.0 / math.sqrt(config.head_dim)
|
|
||||||
|
|
||||||
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
|
||||||
self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
|
||||||
self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
|
|
||||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
|
||||||
|
|
||||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
||||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
self,
|
||||||
hidden_states: mx.array,
|
seq_len: int,
|
||||||
positions: mx.array,
|
attention_mask: Optional[mx.array],
|
||||||
attention_mask: Optional[mx.array] = None,
|
dtype: mx.Dtype,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
batch_size, seq_len, _ = hidden_states.shape
|
|
||||||
|
|
||||||
q = self.q_proj(hidden_states)
|
causal_mask = mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
|
||||||
k = self.k_proj(hidden_states)
|
|
||||||
v = self.v_proj(hidden_states)
|
|
||||||
|
|
||||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
|
||||||
k = mx.reshape(k, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
|
||||||
v = mx.reshape(v, (batch_size, seq_len, self.num_kv_heads, self.head_dim))
|
|
||||||
|
|
||||||
q = self.q_norm(q)
|
|
||||||
k = self.k_norm(k)
|
|
||||||
|
|
||||||
q, k = apply_rotary_emb(q, k, positions, self.head_dim, self.config.rope_theta)
|
|
||||||
|
|
||||||
q = mx.transpose(q, (0, 2, 1, 3))
|
|
||||||
k = mx.transpose(k, (0, 2, 1, 3))
|
|
||||||
v = mx.transpose(v, (0, 2, 1, 3))
|
|
||||||
|
|
||||||
# Create causal mask (lower triangular)
|
|
||||||
causal_mask = mx.triu(mx.full((seq_len, seq_len), -1e9, dtype=k.dtype), k=1)
|
|
||||||
causal_mask = causal_mask[None, None, :, :] # (1, 1, seq, seq
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask + (1.0 - attention_mask[:, None, None, :].astype(k.dtype)) * -1e9
|
batch_size = attention_mask.shape[0]
|
||||||
|
|
||||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=causal_mask)
|
padding_mask = attention_mask.astype(mx.bool_) # (batch, seq_len)
|
||||||
out = mx.transpose(out, (0, 2, 1, 3))
|
combined = causal_mask[None, :, :] & padding_mask[:, None, :]
|
||||||
out = mx.reshape(out, (batch_size, seq_len, -1))
|
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||||
|
mask = mx.where(combined, mx.zeros(combined.shape, dtype=dtype),
|
||||||
return self.o_proj(out)
|
mx.full(combined.shape, min_val, dtype=dtype))
|
||||||
|
return mask[:, None, :, :]
|
||||||
|
else:
|
||||||
class Gemma3DecoderLayer(nn.Module):
|
# No padding mask, just causal
|
||||||
|
min_val = mx.finfo(dtype).min if dtype in (mx.float16, mx.bfloat16) else -1e9
|
||||||
def __init__(self, config: Gemma3Config):
|
mask = mx.where(causal_mask, mx.zeros((seq_len, seq_len), dtype=dtype),
|
||||||
super().__init__()
|
mx.full((seq_len, seq_len), min_val, dtype=dtype))
|
||||||
self.self_attn = Gemma3Attention(config)
|
return mask[None, None, :, :] # (1, 1, seq, seq)
|
||||||
self.mlp = Gemma3MLP(config)
|
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
hidden_states: mx.array,
|
|
||||||
positions: mx.array,
|
|
||||||
attention_mask: Optional[mx.array] = None,
|
|
||||||
) -> mx.array:
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
hidden_states = self.self_attn(hidden_states, positions, attention_mask)
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.pre_feedforward_layernorm(hidden_states)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = self.post_feedforward_layernorm(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3TextModel(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, config: Gemma3Config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
|
||||||
self.layers = [Gemma3DecoderLayer(config) for _ in range(config.num_hidden_layers)]
|
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
||||||
# Gemma scales embeddings by sqrt(hidden_size)
|
|
||||||
self.embed_scale = config.hidden_size ** 0.5
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -191,26 +77,69 @@ class Gemma3TextModel(nn.Module):
|
|||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
output_hidden_states: bool = True,
|
output_hidden_states: bool = True,
|
||||||
) -> Tuple[mx.array, List[mx.array]]:
|
) -> Tuple[mx.array, List[mx.array]]:
|
||||||
|
"""Forward pass returning hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: Input token IDs of shape (batch, seq_len)
|
||||||
|
attention_mask: Optional attention mask (1 for valid, 0 for padding)
|
||||||
|
output_hidden_states: Whether to return all hidden states
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (final_hidden_states, list_of_all_hidden_states)
|
||||||
|
"""
|
||||||
batch_size, seq_len = input_ids.shape
|
batch_size, seq_len = input_ids.shape
|
||||||
|
|
||||||
# Gemma scales embeddings by sqrt(hidden_size)
|
# Get embeddings
|
||||||
hidden_states = self.embed_tokens(input_ids) * self.embed_scale
|
h = self.model.embed_tokens(input_ids)
|
||||||
|
|
||||||
all_hidden_states = [hidden_states] if output_hidden_states else []
|
# Apply Gemma scaling
|
||||||
|
h *= mx.array(self.config.hidden_size**0.5, mx.bfloat16).astype(h.dtype)
|
||||||
|
mx.eval(h)
|
||||||
|
|
||||||
positions = mx.arange(seq_len)[None, :].astype(mx.int32)
|
all_hidden_states = [h] if output_hidden_states else []
|
||||||
positions = mx.broadcast_to(positions, (batch_size, seq_len))
|
|
||||||
|
|
||||||
for layer in self.layers:
|
# Set up cache (all None for non-cached inference)
|
||||||
hidden_states = layer(hidden_states, positions, attention_mask)
|
cache = [None] * len(self.model.layers)
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states.append(hidden_states)
|
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
full_causal_mask = self._create_causal_mask_with_padding(seq_len, attention_mask, h.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
sliding_mask = full_causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
num_layers = len(self.model.layers)
|
||||||
|
for i, layer in enumerate(self.model.layers):
|
||||||
|
is_global = (
|
||||||
|
i % self.config.sliding_window_pattern
|
||||||
|
== self.config.sliding_window_pattern - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select appropriate mask for this layer
|
||||||
|
if is_global:
|
||||||
|
local_mask = full_causal_mask
|
||||||
|
else:
|
||||||
|
local_mask = sliding_mask
|
||||||
|
|
||||||
|
h = layer(h, local_mask, None)
|
||||||
|
mx.eval(h)
|
||||||
|
|
||||||
|
if output_hidden_states and i < num_layers - 1:
|
||||||
|
all_hidden_states.append(h)
|
||||||
|
|
||||||
|
# Apply final norm
|
||||||
|
hidden_states = self.model.norm(h)
|
||||||
|
mx.eval(hidden_states)
|
||||||
|
|
||||||
|
# Append the final normalized output as the last hidden state
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states.append(hidden_states)
|
||||||
|
|
||||||
return hidden_states, all_hidden_states
|
return hidden_states, all_hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: List[Tuple[str, mx.array]], strict: bool = True):
|
||||||
|
"""Load weights into the model."""
|
||||||
|
self.model.load_weights(weights, strict=strict)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectorAttention(nn.Module):
|
class ConnectorAttention(nn.Module):
|
||||||
@@ -582,10 +511,7 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
self._model_path = model_path
|
self._model_path = model_path
|
||||||
self.hidden_dim = hidden_dim
|
self.hidden_dim = hidden_dim
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
self.model = LanguageModel()
|
||||||
# Gemma 3 model
|
|
||||||
self.config = Gemma3Config()
|
|
||||||
self.model = Gemma3TextModel(self.config)
|
|
||||||
|
|
||||||
# Feature extractor: 3840*49 -> 3840
|
# Feature extractor: 3840*49 -> 3840
|
||||||
self.feature_extractor = GemmaFeaturesExtractor(
|
self.feature_extractor = GemmaFeaturesExtractor(
|
||||||
@@ -691,7 +617,6 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
if self.processor is None:
|
if self.processor is None:
|
||||||
raise RuntimeError("Model not loaded. Call load() first.")
|
raise RuntimeError("Model not loaded. Call load() first.")
|
||||||
|
|
||||||
# Tokenize with left padding (as in PyTorch version)
|
|
||||||
inputs = self.processor(
|
inputs = self.processor(
|
||||||
prompt,
|
prompt,
|
||||||
return_tensors="np",
|
return_tensors="np",
|
||||||
@@ -702,28 +627,19 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
input_ids = mx.array(inputs["input_ids"])
|
input_ids = mx.array(inputs["input_ids"])
|
||||||
attention_mask = mx.array(inputs["attention_mask"])
|
attention_mask = mx.array(inputs["attention_mask"])
|
||||||
|
|
||||||
# Get all hidden states from Gemma
|
|
||||||
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
|
_, all_hidden_states = self.model(input_ids, attention_mask, output_hidden_states=True)
|
||||||
|
|
||||||
# Normalize and concatenate all hidden states
|
|
||||||
concat_hidden = norm_and_concat_hidden_states(
|
concat_hidden = norm_and_concat_hidden_states(
|
||||||
all_hidden_states, attention_mask, padding_side="left"
|
all_hidden_states, attention_mask, padding_side="left"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Project through feature extractor
|
|
||||||
features = self.feature_extractor(concat_hidden)
|
features = self.feature_extractor(concat_hidden)
|
||||||
|
|
||||||
# Convert attention mask to additive format for connector
|
|
||||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||||
|
|
||||||
# Process through connector
|
|
||||||
# Note: connector replaces padding with learnable registers and resets mask to zeros
|
|
||||||
# This means all positions now have valid embeddings (no need for final masking)
|
|
||||||
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
embeddings, _ = self.video_embeddings_connector(features, additive_mask)
|
||||||
|
|
||||||
# Return embeddings without zeroing - the connector's register replacement
|
|
||||||
# means all positions have meaningful values now
|
|
||||||
return embeddings, attention_mask
|
return embeddings, attention_mask
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ dependencies = [
|
|||||||
"transformers[tokenizers]",
|
"transformers[tokenizers]",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"opencv-python>=4.12.0.88",
|
"opencv-python>=4.12.0.88",
|
||||||
"Pillow>=10.3.0"
|
"Pillow>=10.3.0",
|
||||||
|
"mlx-vlm"
|
||||||
]
|
]
|
||||||
license = {text="MIT"}
|
license = {text="MIT"}
|
||||||
authors = [
|
authors = [
|
||||||
|
|||||||
Reference in New Issue
Block a user