- Refactor video generation script

- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
Prince Canuma
2026-01-12 14:04:53 +01:00
parent d1ca36a315
commit 7114b023bd
6 changed files with 270 additions and 304 deletions

View File

@@ -9,7 +9,7 @@ import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import rms_norm
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@dataclass
class Gemma3Config:
@@ -240,7 +240,7 @@ class ConnectorAttention(nn.Module):
self,
x: mx.array,
attention_mask: Optional[mx.array] = None,
pe: Optional[mx.array] = None,
pe: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
batch_size, seq_len, _ = x.shape
@@ -255,25 +255,16 @@ class ConnectorAttention(nn.Module):
if pe is not None:
# pe: (1, seq_len, num_heads, head_dim, 2)
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
q, k = apply_rotary_emb_1d(q, k, pe)
# Reshape back for attention computation
q = mx.reshape(q, (batch_size, seq_len, -1))
k = mx.reshape(k, (batch_size, seq_len, -1))
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
if attention_mask is not None:
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
# No mask needed for connector - after register replacement, all positions are valid
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
return self.to_out[0](out)
@@ -365,29 +356,53 @@ class Embeddings1DConnector(nn.Module):
if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (INTERLEAVED type).
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
"""
import math
dim = self.num_heads * self.head_dim
dim = self.num_heads * self.head_dim # inner_dim = 3840
theta = self.positional_embedding_theta
n_elem = 2
max_pos = [1] # Default for connector
n_elem = 2 * len(max_pos) # = 2
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
indices = (theta ** linspace_vals) * (math.pi / 2)
# Generate frequency indices (matches generate_freq_grid_pytorch)
start = 1.0
end = theta
num_indices = dim // n_elem # 1920
log_start = math.log(start) / math.log(theta) # = 0
log_end = math.log(end) / math.log(theta) # = 1
lin_space = mx.linspace(log_start, log_end, num_indices)
indices = (theta ** lin_space) * (math.pi / 2)
# Generate positions and compute freqs (matches generate_freqs)
positions = mx.arange(seq_len).astype(mx.float32)
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
cos = mx.cos(freqs) # (seq_len, dim//2)
sin = mx.sin(freqs)
# freqs = indices * scaled_positions (outer product)
# Shape: (seq_len, num_indices)
freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
cos_full = mx.repeat(cos_freq, 2, axis=-1)
sin_full = mx.repeat(sin_freq, 2, axis=-1)
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
return freqs_cis.astype(dtype)
# Add batch dimension: (1, seq_len, dim)
cos_full = cos_full[None, :, :]
sin_full = sin_full[None, :, :]
return cos_full.astype(dtype), sin_full.astype(dtype)
def _replace_padded_with_registers(
self,
@@ -502,9 +517,8 @@ def norm_and_concat_hidden_states(
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer
large_val = 1e9
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min