- Refactor video generation script

- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
Prince Canuma
2026-01-12 14:04:53 +01:00
parent d1ca36a315
commit 7114b023bd
6 changed files with 270 additions and 304 deletions

View File

@@ -161,7 +161,7 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for VAE
Dictionary with MLX-compatible naming for VAE decoder
"""
sanitized = {}
@@ -172,17 +172,40 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if "position_ids" in key:
continue
# Only process VAE decoder weights (skip audio_vae, etc.)
if not key.startswith("vae."):
continue
# Handle per-channel statistics key mapping
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
continue
elif key.startswith("vae.decoder."):
# Strip the vae.decoder. prefix for decoder weights
new_key = key.replace("vae.decoder.", "")
else:
# Skip other vae.* keys that are not decoder weights
continue
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value

View File

@@ -381,8 +381,7 @@ def precompute_freqs_cis(
if max_pos is None:
max_pos = [20, 2048, 2048]
# For double precision, compute in numpy (float64) then convert back to MLX
# MLX GPU doesn't support float64, so we use numpy for high precision computation
if double_precision:
return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
@@ -418,10 +417,7 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies in double precision using numpy.
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
"""
# Convert to numpy float64
indices_grid_np = np.array(indices_grid).astype(np.float64)

View File

@@ -9,7 +9,7 @@ import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import rms_norm
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@dataclass
class Gemma3Config:
@@ -240,7 +240,7 @@ class ConnectorAttention(nn.Module):
self,
x: mx.array,
attention_mask: Optional[mx.array] = None,
pe: Optional[mx.array] = None,
pe: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
batch_size, seq_len, _ = x.shape
@@ -255,25 +255,16 @@ class ConnectorAttention(nn.Module):
if pe is not None:
# pe: (1, seq_len, num_heads, head_dim, 2)
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
q, k = apply_rotary_emb_1d(q, k, pe)
# Reshape back for attention computation
q = mx.reshape(q, (batch_size, seq_len, -1))
k = mx.reshape(k, (batch_size, seq_len, -1))
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
if attention_mask is not None:
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
# No mask needed for connector - after register replacement, all positions are valid
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
return self.to_out[0](out)
@@ -365,29 +356,53 @@ class Embeddings1DConnector(nn.Module):
if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (INTERLEAVED type).
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
"""
import math
dim = self.num_heads * self.head_dim
dim = self.num_heads * self.head_dim # inner_dim = 3840
theta = self.positional_embedding_theta
n_elem = 2
max_pos = [1] # Default for connector
n_elem = 2 * len(max_pos) # = 2
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
indices = (theta ** linspace_vals) * (math.pi / 2)
# Generate frequency indices (matches generate_freq_grid_pytorch)
start = 1.0
end = theta
num_indices = dim // n_elem # 1920
log_start = math.log(start) / math.log(theta) # = 0
log_end = math.log(end) / math.log(theta) # = 1
lin_space = mx.linspace(log_start, log_end, num_indices)
indices = (theta ** lin_space) * (math.pi / 2)
# Generate positions and compute freqs (matches generate_freqs)
positions = mx.arange(seq_len).astype(mx.float32)
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
cos = mx.cos(freqs) # (seq_len, dim//2)
sin = mx.sin(freqs)
# freqs = indices * scaled_positions (outer product)
# Shape: (seq_len, num_indices)
freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
cos_full = mx.repeat(cos_freq, 2, axis=-1)
sin_full = mx.repeat(sin_freq, 2, axis=-1)
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
return freqs_cis.astype(dtype)
# Add batch dimension: (1, seq_len, dim)
cos_full = cos_full[None, :, :]
sin_full = sin_full[None, :, :]
return cos_full.astype(dtype), sin_full.astype(dtype)
def _replace_padded_with_registers(
self,
@@ -502,9 +517,8 @@ def norm_and_concat_hidden_states(
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer
large_val = 1e9
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min

View File

@@ -16,7 +16,7 @@ class PixArtAlphaTextProjection(nn.Module):
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise")
self.act = nn.GELU(approx="tanh") # Must match PyTorch's approximate="tanh"
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array:

View File

@@ -10,13 +10,19 @@ class PixArtAlphaTextProjection(nn.Module):
hidden_size: int,
out_features: int | None = None,
bias: bool = True,
act_fn: str = "gelu_tanh",
):
super().__init__()
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise")
if act_fn == "gelu_tanh":
self.act = nn.GELU(approx="tanh")
elif act_fn == "silu":
self.act = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array: