- Refactor video generation script
- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
@@ -9,7 +9,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.utils import rms_norm
|
||||
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
|
||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||
|
||||
@dataclass
|
||||
class Gemma3Config:
|
||||
@@ -240,7 +240,7 @@ class ConnectorAttention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
pe: Optional[mx.array] = None,
|
||||
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
@@ -255,25 +255,16 @@ class ConnectorAttention(nn.Module):
|
||||
|
||||
|
||||
if pe is not None:
|
||||
# pe: (1, seq_len, num_heads, head_dim, 2)
|
||||
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
q, k = apply_rotary_emb_1d(q, k, pe)
|
||||
# Reshape back for attention computation
|
||||
q = mx.reshape(q, (batch_size, seq_len, -1))
|
||||
k = mx.reshape(k, (batch_size, seq_len, -1))
|
||||
|
||||
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
|
||||
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
|
||||
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
|
||||
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
|
||||
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
|
||||
if attention_mask is not None:
|
||||
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
|
||||
# No mask needed for connector - after register replacement, all positions are valid
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||
|
||||
return self.to_out[0](out)
|
||||
@@ -365,29 +356,53 @@ class Embeddings1DConnector(nn.Module):
|
||||
if num_learnable_registers > 0:
|
||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
||||
|
||||
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
||||
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
||||
"""
|
||||
import math
|
||||
|
||||
dim = self.num_heads * self.head_dim
|
||||
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
||||
theta = self.positional_embedding_theta
|
||||
n_elem = 2
|
||||
max_pos = [1] # Default for connector
|
||||
n_elem = 2 * len(max_pos) # = 2
|
||||
|
||||
|
||||
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
|
||||
indices = (theta ** linspace_vals) * (math.pi / 2)
|
||||
# Generate frequency indices (matches generate_freq_grid_pytorch)
|
||||
start = 1.0
|
||||
end = theta
|
||||
num_indices = dim // n_elem # 1920
|
||||
|
||||
log_start = math.log(start) / math.log(theta) # = 0
|
||||
log_end = math.log(end) / math.log(theta) # = 1
|
||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||
indices = (theta ** lin_space) * (math.pi / 2)
|
||||
|
||||
# Generate positions and compute freqs (matches generate_freqs)
|
||||
positions = mx.arange(seq_len).astype(mx.float32)
|
||||
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
|
||||
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
||||
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
||||
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
||||
|
||||
cos = mx.cos(freqs) # (seq_len, dim//2)
|
||||
sin = mx.sin(freqs)
|
||||
# freqs = indices * scaled_positions (outer product)
|
||||
# Shape: (seq_len, num_indices)
|
||||
freqs = scaled_positions[:, None] * indices[None, :]
|
||||
|
||||
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
||||
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
||||
cos_full = mx.repeat(cos_freq, 2, axis=-1)
|
||||
sin_full = mx.repeat(sin_freq, 2, axis=-1)
|
||||
|
||||
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
|
||||
return freqs_cis.astype(dtype)
|
||||
# Add batch dimension: (1, seq_len, dim)
|
||||
cos_full = cos_full[None, :, :]
|
||||
sin_full = sin_full[None, :, :]
|
||||
|
||||
return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||
|
||||
def _replace_padded_with_registers(
|
||||
self,
|
||||
@@ -502,9 +517,8 @@ def norm_and_concat_hidden_states(
|
||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||
|
||||
# Compute masked min/max per layer
|
||||
large_val = 1e9
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
|
||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||
range_val = x_max - x_min
|
||||
|
||||
Reference in New Issue
Block a user