- Refactor video generation script

- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
Prince Canuma
2026-01-12 14:04:53 +01:00
parent d1ca36a315
commit 7114b023bd
6 changed files with 270 additions and 304 deletions

451
main.py
View File

@@ -1,6 +1,9 @@
import argparse
import time
from pathlib import Path
import mlx.core as mx
import numpy as np
from pathlib import Path
from PIL import Image
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
@@ -10,205 +13,127 @@ from mlx_video.convert import sanitize_transformer_weights
from mlx_video.generate import create_position_grid
from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
# Paths
from huggingface_hub import snapshot_download
from pathlib import Path
import os
LTX2_REPO = "Lightricks/LTX-2"
def get_ltx2_cache_dir():
# Try to get local cache (local_only), will not download files
# Distilled sigma schedules
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
ref_path = snapshot_download(
repo_id=LTX2_REPO,
local_files_only=True,
allow_patterns=["*"],
ignore_patterns=[],
# leave as default revision and cache_dir, only local
)
return ref_path
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
# If not present locally, download from hub
return snapshot_download(
repo_id=LTX2_REPO,
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
ignore_patterns=[]
)
LTX2_PATH = Path(get_ltx2_cache_dir())
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')
# Distilled sigma schedules (from PyTorch)
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0] # Refinement steps
))
def denoise_loop(
def denoise(
latents: mx.array,
positions: mx.array,
text_embeddings: mx.array,
transformer: LTXModel,
sigma_schedule: list,
stage_name: str = "Stage",
negative_embeddings: mx.array = None,
cfg_scale: float = 1.0,
sigmas: list,
) -> mx.array:
"""Run denoising loop for given sigma schedule.
Args:
latents: Noisy latent tensor
positions: Position embeddings
text_embeddings: Positive prompt embeddings
transformer: The transformer model
sigma_schedule: List of sigma values for each step
stage_name: Name for logging
negative_embeddings: Negative prompt embeddings for CFG (optional)
cfg_scale: Classifier-free guidance scale (1.0 = no guidance)
"""
use_cfg = negative_embeddings is not None and cfg_scale > 1.0
for i in range(len(sigma_schedule) - 1):
sigma = sigma_schedule[i]
sigma_next = sigma_schedule[i + 1]
print(f" {stage_name} step {i+1}/{len(sigma_schedule)-1}: sigma={sigma:.4f} -> {sigma_next:.4f}")
"""Run denoising loop."""
for i in range(len(sigmas) - 1):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
latents_flat = mx.reshape(latents, (b, c, -1))
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
timesteps = mx.full((1,), sigma)
# Positive (conditioned) prediction
video_modality = Modality(
latent=latents_flat,
timesteps=timesteps,
timesteps=mx.full((1,), sigma),
positions=positions,
context=text_embeddings,
context_mask=None,
enabled=True,
)
vx_cond, _ = transformer(video=video_modality, audio=None)
mx.eval(vx_cond)
velocity, _ = transformer(video=video_modality, audio=None)
mx.eval(velocity)
if use_cfg:
# Negative (unconditioned) prediction
video_modality_neg = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=negative_embeddings,
context_mask=None,
enabled=True,
)
vx_uncond, _ = transformer(video=video_modality_neg, audio=None)
mx.eval(vx_uncond)
# CFG: output = uncond + cfg_scale * (cond - uncond)
vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
else:
vx = vx_cond
vx_reshaped = mx.transpose(vx, (0, 2, 1))
vx_reshaped = mx.reshape(vx_reshaped, (b, c, f, h, w))
# Debug: Print velocity stats
vx_np = np.array(vx_reshaped)
print(f" Velocity: min={vx_np.min():.4f}, max={vx_np.max():.4f}, mean={vx_np.mean():.4f}")
# Get denoised prediction: x_0 = x_t - sigma * velocity
denoised = to_denoised(latents, vx_reshaped, sigma)
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
mx.eval(denoised)
# Debug: Print denoised stats
denoised_np = np.array(denoised)
print(f" Denoised: min={denoised_np.min():.4f}, max={denoised_np.max():.4f}, mean={denoised_np.mean():.4f}")
# Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
if sigma_next > 0:
velocity = (latents - denoised) / sigma
latents = denoised + sigma_next * velocity
latents = denoised + sigma_next * (latents - denoised) / sigma
else:
latents = denoised
mx.eval(latents)
# Debug: Print latents after step
latents_np = np.array(latents)
print(f" Latents after step: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
return latents
def main():
print("="*60)
print("MLX LTX-2 Video Generation (Two-Stage)")
print("="*60)
def generate_video(
model_repo: str,
prompt: str,
height: int = 512,
width: int = 512,
num_frames: int = 33,
seed: int = 42,
fps: int = 24,
output_path: str = "output.mp4",
save_frames: bool = False,
):
"""Generate video from text prompt.
# Config - same as PyTorch reference
prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic"
negative_prompt = "" # PyTorch script doesn't use negative prompt
cfg_scale = 1.0 # No CFG in the distilled pipeline
height, width, num_frames = 512, 512, 500 # Must be divisible by 64 for two-stage
seed = 123
Args:
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
seed: Random seed for reproducibility
fps: Frames per second for output video
output_path: Path to save the output video
save_frames: Whether to save individual frames as images
"""
start_time = time.time()
# Stage 1: Half resolution
stage1_height = height // 2
stage1_width = width // 2
stage1_latent_height = stage1_height // 32
stage1_latent_width = stage1_width // 32
# Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
print(f"Generating {width}x{height} video with {num_frames} frames")
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
# Get model path
model_path = get_model_path(model_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8
# Stage 2: Full resolution
latent_height = height // 32
latent_width = width // 32
print(f"\nConfig:")
print(f" Prompt: {prompt}")
print(f" Negative prompt: '{negative_prompt}'")
print(f" CFG scale: {cfg_scale}")
print(f" Final resolution: {width}x{height}, {num_frames} frames")
print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
print(f" Seed: {seed}")
mx.random.seed(seed)
# Load text encoder
print("\nLoading text encoder...")
print("Loading text encoder...")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH))
text_encoder.load(str(LTX2_PATH))
text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path))
mx.eval(text_encoder.parameters())
# Encode positive prompt
print("Encoding text...")
text_embeddings, attention_mask = text_encoder(prompt)
text_embeddings, _ = text_encoder(prompt)
mx.eval(text_embeddings)
print(f" Positive embeddings: {text_embeddings.shape}")
# Encode negative prompt for CFG
negative_embeddings, _ = text_encoder(negative_prompt)
mx.eval(negative_embeddings)
print(f" Negative embeddings: {negative_embeddings.shape}")
# Free text encoder memory
del text_encoder
mx.clear_cache()
# Load transformer
print("\nLoading transformer...")
raw_weights = mx.load(MODEL_PATH)
print("Loading transformer...")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig(
@@ -231,164 +156,166 @@ def main():
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
print(" Transformer loaded!")
# ========================================
# Stage 1: Generate at half resolution
# ========================================
print("\n" + "="*60)
print("Stage 1: Generating at half resolution")
print("="*60)
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width))
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents)
print(f" Initial latents: {latents.shape}")
positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
latents = denoise_loop(
latents=latents,
positions=positions,
text_embeddings=text_embeddings,
transformer=transformer,
sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
stage_name="Stage 1",
negative_embeddings=negative_embeddings,
cfg_scale=cfg_scale,
)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
print(f"\nStage 1 latents: {latents.shape}")
latents_np = np.array(latents)
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# ========================================
# Upsample latents 2x
# ========================================
print("\n" + "="*60)
print("Upsampling latents 2x")
print("="*60)
# Load upsampler
print(" Loading spatial upsampler...")
upsampler = load_upsampler(UPSAMPLER_PATH)
# Upsample latents
print("Upsampling latents 2x...")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
# Load latent statistics for normalization
vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True)
# EXPERIMENT: Disable VAE decode noise for sharper output
# vae_decoder.decode_noise_scale = 0.0
# print(f" VAE decode_noise_scale set to {vae_decoder.decode_noise_scale}")
latent_mean = vae_decoder.latents_mean
latent_std = vae_decoder.latents_std
vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'),
timestep_conditioning=True
)
# Upsample
print(" Upsampling...")
latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(latents)
print(f" Upsampled latents: {latents.shape}")
# Free upsampler memory
del upsampler
mx.clear_cache()
# ========================================
# Stage 2: Refine at full resolution
# ========================================
print("\n" + "="*60)
print("Stage 2: Refining at full resolution")
print("="*60)
# Debug: Print upsampled latent stats before adding noise
latents_np = np.array(latents)
print(f" Upsampled latents (before noise): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# Create new position grid for full resolution
positions = create_position_grid(1, latent_frames, latent_height, latent_width)
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
# Add noise at initial sigma for stage 2
# PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale)
# NOT addition: noisy = clean + scale * noise
noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
# Add noise for refinement
noise_scale = STAGE_2_SIGMAS[0]
noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
# Debug: Print latents after adding noise
latents_np = np.array(latents)
print(f" After adding noise (sigma={noise_scale}): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
latents = denoise_loop(
latents=latents,
positions=positions,
text_embeddings=text_embeddings,
transformer=transformer,
sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
stage_name="Stage 2",
)
print(f"\nFinal latents: {latents.shape}")
latents_np = np.array(latents)
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# Save latents for PyTorch comparison
np.save("mlx_final_latents.npy", latents_np)
print(" Saved latents to mlx_final_latents.npy")
# Free transformer memory
del transformer
mx.clear_cache()
# ========================================
# Decode to video
# ========================================
print("\n" + "="*60)
print("Decoding with VAE")
print("="*60)
# Decode latents to video
video = vae_decoder(latents, debug=True)
print("Decoding video...")
video = vae_decoder(latents)
mx.eval(video)
print(f" Video shape: {video.shape}")
mx.clear_cache()
# Convert to frames
# Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
# Debug: check raw RGB values before conversion
video_raw = np.array(video)
print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = (video + 1.0) / 2.0 # [-1, 1] -> [0, 1]
video = mx.clip(video, 0.0, 1.0)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}")
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Save first frame
output_path = Path("mlx_output_frame0_2.png")
Image.fromarray(video_np[0]).save(output_path)
print(f"\nSaved first frame to {output_path}")
# Save video
try:
import imageio
video_path = "mlx_output_video_2.mp4"
imageio.mimwrite(video_path, video_np, fps=24, codec='libx264')
print(f"Saved video to {video_path}")
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
print(f"Saved video to {output_path}")
except Exception as e:
print(f"Could not save video: {e}")
print("\nDone!")
if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"Saved {len(video_np)} frames to {frames_dir}")
elapsed = time.time() - start_time
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
return video_np
def main():
parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python main.py --prompt "A cat walking on grass"
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
"""
)
parser.add_argument(
"--prompt", "-p",
type=str,
required=True,
help="Text description of the video to generate"
)
parser.add_argument(
"--height", "-H",
type=int,
default=512,
help="Output video height (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--width", "-W",
type=int,
default=512,
help="Output video width (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--num-frames", "-n",
type=int,
default=33,
help="Number of frames (default: 33, must be 1 + 8*k)"
)
parser.add_argument(
"--seed", "-s",
type=int,
default=42,
help="Random seed for reproducibility (default: 42)"
)
parser.add_argument(
"--fps",
type=int,
default=24,
help="Frames per second for output video (default: 24)"
)
parser.add_argument(
"--output", "-o",
type=str,
default="output.mp4",
help="Output video path (default: output.mp4)"
)
parser.add_argument(
"--save-frames",
action="store_true",
help="Save individual frames as images"
)
parser.add_argument(
"--model-repo",
type=str,
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
args = parser.parse_args()
generate_video(
model_repo=args.model_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
)
if __name__ == "__main__":
import time
start_time = time.time()
main()
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")

View File

@@ -161,7 +161,7 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for VAE
Dictionary with MLX-compatible naming for VAE decoder
"""
sanitized = {}
@@ -172,17 +172,40 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if "position_ids" in key:
continue
# Only process VAE decoder weights (skip audio_vae, etc.)
if not key.startswith("vae."):
continue
# Handle per-channel statistics key mapping
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
continue
elif key.startswith("vae.decoder."):
# Strip the vae.decoder. prefix for decoder weights
new_key = key.replace("vae.decoder.", "")
else:
# Skip other vae.* keys that are not decoder weights
continue
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value

View File

@@ -381,8 +381,7 @@ def precompute_freqs_cis(
if max_pos is None:
max_pos = [20, 2048, 2048]
# For double precision, compute in numpy (float64) then convert back to MLX
# MLX GPU doesn't support float64, so we use numpy for high precision computation
if double_precision:
return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
@@ -418,10 +417,7 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies in double precision using numpy.
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
"""
# Convert to numpy float64
indices_grid_np = np.array(indices_grid).astype(np.float64)

View File

@@ -9,7 +9,7 @@ import mlx.core as mx
import mlx.nn as nn
from mlx_video.utils import rms_norm
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@dataclass
class Gemma3Config:
@@ -240,7 +240,7 @@ class ConnectorAttention(nn.Module):
self,
x: mx.array,
attention_mask: Optional[mx.array] = None,
pe: Optional[mx.array] = None,
pe: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
batch_size, seq_len, _ = x.shape
@@ -255,25 +255,16 @@ class ConnectorAttention(nn.Module):
if pe is not None:
# pe: (1, seq_len, num_heads, head_dim, 2)
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
q, k = apply_rotary_emb_1d(q, k, pe)
# Reshape back for attention computation
q = mx.reshape(q, (batch_size, seq_len, -1))
k = mx.reshape(k, (batch_size, seq_len, -1))
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
if attention_mask is not None:
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
# No mask needed for connector - after register replacement, all positions are valid
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
return self.to_out[0](out)
@@ -365,29 +356,53 @@ class Embeddings1DConnector(nn.Module):
if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (INTERLEAVED type).
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
"""
import math
dim = self.num_heads * self.head_dim
dim = self.num_heads * self.head_dim # inner_dim = 3840
theta = self.positional_embedding_theta
n_elem = 2
max_pos = [1] # Default for connector
n_elem = 2 * len(max_pos) # = 2
# Generate frequency indices (matches generate_freq_grid_pytorch)
start = 1.0
end = theta
num_indices = dim // n_elem # 1920
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
indices = (theta ** linspace_vals) * (math.pi / 2)
log_start = math.log(start) / math.log(theta) # = 0
log_end = math.log(end) / math.log(theta) # = 1
lin_space = mx.linspace(log_start, log_end, num_indices)
indices = (theta ** lin_space) * (math.pi / 2)
# Generate positions and compute freqs (matches generate_freqs)
positions = mx.arange(seq_len).astype(mx.float32)
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
cos = mx.cos(freqs) # (seq_len, dim//2)
sin = mx.sin(freqs)
# freqs = indices * scaled_positions (outer product)
# Shape: (seq_len, num_indices)
freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
cos_full = mx.repeat(cos_freq, 2, axis=-1)
sin_full = mx.repeat(sin_freq, 2, axis=-1)
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
return freqs_cis.astype(dtype)
# Add batch dimension: (1, seq_len, dim)
cos_full = cos_full[None, :, :]
sin_full = sin_full[None, :, :]
return cos_full.astype(dtype), sin_full.astype(dtype)
def _replace_padded_with_registers(
self,
@@ -502,9 +517,8 @@ def norm_and_concat_hidden_states(
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer
large_val = 1e9
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min

View File

@@ -16,7 +16,7 @@ class PixArtAlphaTextProjection(nn.Module):
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise")
self.act = nn.GELU(approx="tanh") # Must match PyTorch's approximate="tanh"
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array:

View File

@@ -10,13 +10,19 @@ class PixArtAlphaTextProjection(nn.Module):
hidden_size: int,
out_features: int | None = None,
bias: bool = True,
act_fn: str = "gelu_tanh",
):
super().__init__()
out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise")
if act_fn == "gelu_tanh":
self.act = nn.GELU(approx="tanh")
elif act_fn == "silu":
self.act = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array: