- Refactor video generation script
- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions. - Updated VAE weight sanitization for compatibility and improved activation function handling in text projection. - Added support for saving individual frames and refined output video generation process.
This commit is contained in:
@@ -161,7 +161,7 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for VAE
|
||||
Dictionary with MLX-compatible naming for VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
@@ -172,17 +172,40 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Only process VAE decoder weights (skip audio_vae, etc.)
|
||||
if not key.startswith("vae."):
|
||||
continue
|
||||
|
||||
# Handle per-channel statistics key mapping
|
||||
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
|
||||
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
|
||||
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
|
||||
if "vae.per_channel_statistics" in key:
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
|
||||
continue
|
||||
elif key.startswith("vae.decoder."):
|
||||
# Strip the vae.decoder. prefix for decoder weights
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
else:
|
||||
# Skip other vae.* keys that are not decoder weights
|
||||
continue
|
||||
|
||||
# Handle Conv3d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||
# MLX: (out_channels, D, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
@@ -381,8 +381,7 @@ def precompute_freqs_cis(
|
||||
if max_pos is None:
|
||||
max_pos = [20, 2048, 2048]
|
||||
|
||||
# For double precision, compute in numpy (float64) then convert back to MLX
|
||||
# MLX GPU doesn't support float64, so we use numpy for high precision computation
|
||||
|
||||
if double_precision:
|
||||
return _precompute_freqs_cis_double_precision(
|
||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||
@@ -418,10 +417,7 @@ def _precompute_freqs_cis_double_precision(
|
||||
num_attention_heads: int,
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies in double precision using numpy.
|
||||
|
||||
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
|
||||
"""
|
||||
# Convert to numpy float64
|
||||
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.utils import rms_norm
|
||||
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
|
||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||
|
||||
@dataclass
|
||||
class Gemma3Config:
|
||||
@@ -240,7 +240,7 @@ class ConnectorAttention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
pe: Optional[mx.array] = None,
|
||||
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
@@ -255,25 +255,16 @@ class ConnectorAttention(nn.Module):
|
||||
|
||||
|
||||
if pe is not None:
|
||||
# pe: (1, seq_len, num_heads, head_dim, 2)
|
||||
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
q, k = apply_rotary_emb_1d(q, k, pe)
|
||||
# Reshape back for attention computation
|
||||
q = mx.reshape(q, (batch_size, seq_len, -1))
|
||||
k = mx.reshape(k, (batch_size, seq_len, -1))
|
||||
|
||||
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
|
||||
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
|
||||
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
|
||||
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
|
||||
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
|
||||
if attention_mask is not None:
|
||||
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
|
||||
# No mask needed for connector - after register replacement, all positions are valid
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||
|
||||
return self.to_out[0](out)
|
||||
@@ -365,29 +356,53 @@ class Embeddings1DConnector(nn.Module):
|
||||
if num_learnable_registers > 0:
|
||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
||||
|
||||
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
||||
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
||||
"""
|
||||
import math
|
||||
|
||||
dim = self.num_heads * self.head_dim
|
||||
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
||||
theta = self.positional_embedding_theta
|
||||
n_elem = 2
|
||||
max_pos = [1] # Default for connector
|
||||
n_elem = 2 * len(max_pos) # = 2
|
||||
|
||||
|
||||
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
|
||||
indices = (theta ** linspace_vals) * (math.pi / 2)
|
||||
# Generate frequency indices (matches generate_freq_grid_pytorch)
|
||||
start = 1.0
|
||||
end = theta
|
||||
num_indices = dim // n_elem # 1920
|
||||
|
||||
log_start = math.log(start) / math.log(theta) # = 0
|
||||
log_end = math.log(end) / math.log(theta) # = 1
|
||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||
indices = (theta ** lin_space) * (math.pi / 2)
|
||||
|
||||
# Generate positions and compute freqs (matches generate_freqs)
|
||||
positions = mx.arange(seq_len).astype(mx.float32)
|
||||
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
|
||||
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
||||
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
||||
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
||||
|
||||
cos = mx.cos(freqs) # (seq_len, dim//2)
|
||||
sin = mx.sin(freqs)
|
||||
# freqs = indices * scaled_positions (outer product)
|
||||
# Shape: (seq_len, num_indices)
|
||||
freqs = scaled_positions[:, None] * indices[None, :]
|
||||
|
||||
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
||||
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
||||
cos_full = mx.repeat(cos_freq, 2, axis=-1)
|
||||
sin_full = mx.repeat(sin_freq, 2, axis=-1)
|
||||
|
||||
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
|
||||
return freqs_cis.astype(dtype)
|
||||
# Add batch dimension: (1, seq_len, dim)
|
||||
cos_full = cos_full[None, :, :]
|
||||
sin_full = sin_full[None, :, :]
|
||||
|
||||
return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||
|
||||
def _replace_padded_with_registers(
|
||||
self,
|
||||
@@ -502,9 +517,8 @@ def norm_and_concat_hidden_states(
|
||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||
|
||||
# Compute masked min/max per layer
|
||||
large_val = 1e9
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
|
||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||
range_val = x_max - x_min
|
||||
|
||||
@@ -16,7 +16,7 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
self.act = nn.GELU(approx="tanh") # Must match PyTorch's approximate="tanh"
|
||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
|
||||
@@ -10,13 +10,19 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
hidden_size: int,
|
||||
out_features: int | None = None,
|
||||
bias: bool = True,
|
||||
act_fn: str = "gelu_tanh",
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act = nn.GELU(approx="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user