- Refactor video generation script

- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
Prince Canuma
2026-01-12 14:04:53 +01:00
parent d1ca36a315
commit 7114b023bd
6 changed files with 270 additions and 304 deletions

451
main.py
View File

@@ -1,6 +1,9 @@
import argparse
import time
from pathlib import Path
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from pathlib import Path
from PIL import Image from PIL import Image
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
@@ -10,205 +13,127 @@ from mlx_video.convert import sanitize_transformer_weights
from mlx_video.generate import create_position_grid from mlx_video.generate import create_position_grid
from mlx_video.utils import to_denoised from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
# Paths
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from pathlib import Path
import os
LTX2_REPO = "Lightricks/LTX-2"
def get_ltx2_cache_dir(): # Distilled sigma schedules
# Try to get local cache (local_only), will not download files STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try: try:
ref_path = snapshot_download( return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
repo_id=LTX2_REPO,
local_files_only=True,
allow_patterns=["*"],
ignore_patterns=[],
# leave as default revision and cache_dir, only local
)
return ref_path
except Exception: except Exception:
# If not present locally, download from hub print("Downloading LTX-2 model weights...")
return snapshot_download( return Path(snapshot_download(
repo_id=LTX2_REPO, repo_id=model_repo,
local_files_only=False, local_files_only=False,
resume_download=True, resume_download=True,
allow_patterns=["*.safetensors", "*.json"], allow_patterns=["*.safetensors", "*.json"],
ignore_patterns=[] ))
)
LTX2_PATH = Path(get_ltx2_cache_dir())
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')
# Distilled sigma schedules (from PyTorch)
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0] # Refinement steps
def denoise_loop( def denoise(
latents: mx.array, latents: mx.array,
positions: mx.array, positions: mx.array,
text_embeddings: mx.array, text_embeddings: mx.array,
transformer: LTXModel, transformer: LTXModel,
sigma_schedule: list, sigmas: list,
stage_name: str = "Stage",
negative_embeddings: mx.array = None,
cfg_scale: float = 1.0,
) -> mx.array: ) -> mx.array:
"""Run denoising loop for given sigma schedule. """Run denoising loop."""
for i in range(len(sigmas) - 1):
Args: sigma, sigma_next = sigmas[i], sigmas[i + 1]
latents: Noisy latent tensor
positions: Position embeddings
text_embeddings: Positive prompt embeddings
transformer: The transformer model
sigma_schedule: List of sigma values for each step
stage_name: Name for logging
negative_embeddings: Negative prompt embeddings for CFG (optional)
cfg_scale: Classifier-free guidance scale (1.0 = no guidance)
"""
use_cfg = negative_embeddings is not None and cfg_scale > 1.0
for i in range(len(sigma_schedule) - 1):
sigma = sigma_schedule[i]
sigma_next = sigma_schedule[i + 1]
print(f" {stage_name} step {i+1}/{len(sigma_schedule)-1}: sigma={sigma:.4f} -> {sigma_next:.4f}")
b, c, f, h, w = latents.shape b, c, f, h, w = latents.shape
latents_flat = mx.reshape(latents, (b, c, -1)) latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
timesteps = mx.full((1,), sigma)
# Positive (conditioned) prediction
video_modality = Modality( video_modality = Modality(
latent=latents_flat, latent=latents_flat,
timesteps=timesteps, timesteps=mx.full((1,), sigma),
positions=positions, positions=positions,
context=text_embeddings, context=text_embeddings,
context_mask=None, context_mask=None,
enabled=True, enabled=True,
) )
vx_cond, _ = transformer(video=video_modality, audio=None) velocity, _ = transformer(video=video_modality, audio=None)
mx.eval(vx_cond) mx.eval(velocity)
if use_cfg: velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
# Negative (unconditioned) prediction denoised = to_denoised(latents, velocity, sigma)
video_modality_neg = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=negative_embeddings,
context_mask=None,
enabled=True,
)
vx_uncond, _ = transformer(video=video_modality_neg, audio=None)
mx.eval(vx_uncond)
# CFG: output = uncond + cfg_scale * (cond - uncond)
vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
else:
vx = vx_cond
vx_reshaped = mx.transpose(vx, (0, 2, 1))
vx_reshaped = mx.reshape(vx_reshaped, (b, c, f, h, w))
# Debug: Print velocity stats
vx_np = np.array(vx_reshaped)
print(f" Velocity: min={vx_np.min():.4f}, max={vx_np.max():.4f}, mean={vx_np.mean():.4f}")
# Get denoised prediction: x_0 = x_t - sigma * velocity
denoised = to_denoised(latents, vx_reshaped, sigma)
mx.eval(denoised) mx.eval(denoised)
# Debug: Print denoised stats
denoised_np = np.array(denoised)
print(f" Denoised: min={denoised_np.min():.4f}, max={denoised_np.max():.4f}, mean={denoised_np.mean():.4f}")
# Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
if sigma_next > 0: if sigma_next > 0:
velocity = (latents - denoised) / sigma latents = denoised + sigma_next * (latents - denoised) / sigma
latents = denoised + sigma_next * velocity
else: else:
latents = denoised latents = denoised
mx.eval(latents) mx.eval(latents)
# Debug: Print latents after step
latents_np = np.array(latents)
print(f" Latents after step: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
return latents return latents
def main(): def generate_video(
print("="*60) model_repo: str,
print("MLX LTX-2 Video Generation (Two-Stage)") prompt: str,
print("="*60) height: int = 512,
width: int = 512,
num_frames: int = 33,
seed: int = 42,
fps: int = 24,
output_path: str = "output.mp4",
save_frames: bool = False,
):
"""Generate video from text prompt.
# Config - same as PyTorch reference Args:
prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic" prompt: Text description of the video to generate
negative_prompt = "" # PyTorch script doesn't use negative prompt height: Output video height (must be divisible by 64)
cfg_scale = 1.0 # No CFG in the distilled pipeline width: Output video width (must be divisible by 64)
height, width, num_frames = 512, 512, 500 # Must be divisible by 64 for two-stage num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
seed = 123 seed: Random seed for reproducibility
fps: Frames per second for output video
output_path: Path to save the output video
save_frames: Whether to save individual frames as images
"""
start_time = time.time()
# Stage 1: Half resolution # Validate dimensions
stage1_height = height // 2 assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
stage1_width = width // 2 assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
stage1_latent_height = stage1_height // 32
stage1_latent_width = stage1_width // 32 print(f"Generating {width}x{height} video with {num_frames} frames")
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
# Get model path
model_path = get_model_path(model_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8 latent_frames = 1 + (num_frames - 1) // 8
# Stage 2: Full resolution
latent_height = height // 32
latent_width = width // 32
print(f"\nConfig:")
print(f" Prompt: {prompt}")
print(f" Negative prompt: '{negative_prompt}'")
print(f" CFG scale: {cfg_scale}")
print(f" Final resolution: {width}x{height}, {num_frames} frames")
print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
print(f" Seed: {seed}")
mx.random.seed(seed) mx.random.seed(seed)
# Load text encoder # Load text encoder
print("\nLoading text encoder...") print("Loading text encoder...")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH)) text_encoder.load(str(model_path))
text_encoder.load(str(LTX2_PATH))
mx.eval(text_encoder.parameters()) mx.eval(text_encoder.parameters())
# Encode positive prompt text_embeddings, _ = text_encoder(prompt)
print("Encoding text...")
text_embeddings, attention_mask = text_encoder(prompt)
mx.eval(text_embeddings) mx.eval(text_embeddings)
print(f" Positive embeddings: {text_embeddings.shape}")
# Encode negative prompt for CFG
negative_embeddings, _ = text_encoder(negative_prompt)
mx.eval(negative_embeddings)
print(f" Negative embeddings: {negative_embeddings.shape}")
# Free text encoder memory
del text_encoder del text_encoder
mx.clear_cache() mx.clear_cache()
# Load transformer # Load transformer
print("\nLoading transformer...") print("Loading transformer...")
raw_weights = mx.load(MODEL_PATH) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig( config = LTXModelConfig(
@@ -231,164 +156,166 @@ def main():
transformer = LTXModel(config) transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False) transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters()) mx.eval(transformer.parameters())
print(" Transformer loaded!")
# ========================================
# Stage 1: Generate at half resolution # Stage 1: Generate at half resolution
# ======================================== print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
print("\n" + "="*60)
print("Stage 1: Generating at half resolution")
print("="*60)
mx.random.seed(seed) mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width)) latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents) mx.eval(latents)
print(f" Initial latents: {latents.shape}")
positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions) mx.eval(positions)
latents = denoise_loop( latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
latents=latents,
positions=positions,
text_embeddings=text_embeddings,
transformer=transformer,
sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
stage_name="Stage 1",
negative_embeddings=negative_embeddings,
cfg_scale=cfg_scale,
)
print(f"\nStage 1 latents: {latents.shape}") # Upsample latents
latents_np = np.array(latents) print("Upsampling latents 2x...")
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}") upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
# ========================================
# Upsample latents 2x
# ========================================
print("\n" + "="*60)
print("Upsampling latents 2x")
print("="*60)
# Load upsampler
print(" Loading spatial upsampler...")
upsampler = load_upsampler(UPSAMPLER_PATH)
mx.eval(upsampler.parameters()) mx.eval(upsampler.parameters())
# Load latent statistics for normalization vae_decoder = load_vae_decoder(
vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True) str(model_path / 'ltx-2-19b-distilled.safetensors'),
# EXPERIMENT: Disable VAE decode noise for sharper output timestep_conditioning=True
# vae_decoder.decode_noise_scale = 0.0 )
# print(f" VAE decode_noise_scale set to {vae_decoder.decode_noise_scale}")
latent_mean = vae_decoder.latents_mean
latent_std = vae_decoder.latents_std
# Upsample latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
print(" Upsampling...")
latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
mx.eval(latents) mx.eval(latents)
print(f" Upsampled latents: {latents.shape}")
# Free upsampler memory
del upsampler del upsampler
mx.clear_cache() mx.clear_cache()
# ========================================
# Stage 2: Refine at full resolution # Stage 2: Refine at full resolution
# ======================================== print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
print("\n" + "="*60) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
print("Stage 2: Refining at full resolution")
print("="*60)
# Debug: Print upsampled latent stats before adding noise
latents_np = np.array(latents)
print(f" Upsampled latents (before noise): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# Create new position grid for full resolution
positions = create_position_grid(1, latent_frames, latent_height, latent_width)
mx.eval(positions) mx.eval(positions)
# Add noise at initial sigma for stage 2 # Add noise for refinement
# PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale) noise_scale = STAGE_2_SIGMAS[0]
# NOT addition: noisy = clean + scale * noise
noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
noise = mx.random.normal(latents.shape) noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale) latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents) mx.eval(latents)
# Debug: Print latents after adding noise latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
latents_np = np.array(latents)
print(f" After adding noise (sigma={noise_scale}): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
latents = denoise_loop(
latents=latents,
positions=positions,
text_embeddings=text_embeddings,
transformer=transformer,
sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
stage_name="Stage 2",
)
print(f"\nFinal latents: {latents.shape}")
latents_np = np.array(latents)
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# Save latents for PyTorch comparison
np.save("mlx_final_latents.npy", latents_np)
print(" Saved latents to mlx_final_latents.npy")
# Free transformer memory
del transformer del transformer
mx.clear_cache() mx.clear_cache()
# ========================================
# Decode to video # Decode to video
# ======================================== print("Decoding video...")
print("\n" + "="*60) video = vae_decoder(latents)
print("Decoding with VAE")
print("="*60)
# Decode latents to video
video = vae_decoder(latents, debug=True)
mx.eval(video) mx.eval(video)
print(f" Video shape: {video.shape}") mx.clear_cache()
# Convert to frames # Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W) video = mx.squeeze(video, axis=0) # (C, F, H, W)
# Debug: check raw RGB values before conversion
video_raw = np.array(video)
print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = (video + 1.0) / 2.0 # [-1, 1] -> [0, 1] video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = mx.clip(video, 0.0, 1.0)
video = (video * 255).astype(mx.uint8) video = (video * 255).astype(mx.uint8)
video_np = np.array(video) video_np = np.array(video)
print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}") # Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Save first frame
output_path = Path("mlx_output_frame0_2.png")
Image.fromarray(video_np[0]).save(output_path)
print(f"\nSaved first frame to {output_path}")
# Save video
try: try:
import imageio import imageio
video_path = "mlx_output_video_2.mp4" imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
imageio.mimwrite(video_path, video_np, fps=24, codec='libx264') print(f"Saved video to {output_path}")
print(f"Saved video to {video_path}")
except Exception as e: except Exception as e:
print(f"Could not save video: {e}") print(f"Could not save video: {e}")
print("\nDone!") if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"Saved {len(video_np)} frames to {frames_dir}")
elapsed = time.time() - start_time
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
return video_np
def _build_parser():
    """Build the command-line parser for the LTX-2 video generation script."""
    parser = argparse.ArgumentParser(
        description="Generate videos with MLX LTX-2",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python main.py --prompt "A cat walking on grass"
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
""",
    )
    parser.add_argument(
        "--prompt", "-p",
        type=str,
        required=True,
        help="Text description of the video to generate",
    )
    # generate_video() requires height/width to be multiples of 64 (stage 1
    # runs at half resolution and the latent grid is a further //32), so the
    # help text advertises 64 rather than 32.
    parser.add_argument(
        "--height", "-H",
        type=int,
        default=512,
        help="Output video height (default: 512, must be divisible by 64)",
    )
    parser.add_argument(
        "--width", "-W",
        type=int,
        default=512,
        help="Output video width (default: 512, must be divisible by 64)",
    )
    parser.add_argument(
        "--num-frames", "-n",
        type=int,
        default=33,
        help="Number of frames (default: 33, must be 1 + 8*k)",
    )
    parser.add_argument(
        "--seed", "-s",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=24,
        help="Frames per second for output video (default: 24)",
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="output.mp4",
        help="Output video path (default: output.mp4)",
    )
    parser.add_argument(
        "--save-frames",
        action="store_true",
        help="Save individual frames as images",
    )
    parser.add_argument(
        "--model-repo",
        type=str,
        default="Lightricks/LTX-2",
        help="Model repository to use (default: Lightricks/LTX-2)",
    )
    return parser


def main(argv: "list[str] | None" = None):
    """Parse command-line options and run video generation.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behavior of the script entry
            point; passing a list allows programmatic invocation.

    Raises:
        SystemExit: With status 2 when the arguments are invalid, including
            a height or width that is not a positive multiple of 64.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Reject bad sizes here, before any model weights are loaded, with a
    # regular usage error instead of an AssertionError from generate_video()
    # (which would also be skipped entirely under ``python -O``).
    for flag, value in (("--height", args.height), ("--width", args.width)):
        if value <= 0 or value % 64 != 0:
            parser.error(f"{flag} must be a positive multiple of 64, got {value}")

    generate_video(
        model_repo=args.model_repo,
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        seed=args.seed,
        fps=args.fps,
        output_path=args.output,
        save_frames=args.save_frames,
    )
if __name__ == "__main__": if __name__ == "__main__":
import time
start_time = time.time()
main() main()
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")

View File

@@ -161,7 +161,7 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
weights: Dictionary of weights with PyTorch naming weights: Dictionary of weights with PyTorch naming
Returns: Returns:
Dictionary with MLX-compatible naming for VAE Dictionary with MLX-compatible naming for VAE decoder
""" """
sanitized = {} sanitized = {}
@@ -172,17 +172,40 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if "position_ids" in key: if "position_ids" in key:
continue continue
# Only process VAE decoder weights (skip audio_vae, etc.)
if not key.startswith("vae."):
continue
# Handle per-channel statistics key mapping
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
continue
elif key.startswith("vae.decoder."):
# Strip the vae.decoder. prefix for decoder weights
new_key = key.replace("vae.decoder.", "")
else:
# Skip other vae.* keys that are not decoder weights
continue
# Handle Conv3d weight shape conversion # Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W) # PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels) # MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5: if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
# Transpose from (O, I, D, H, W) to (O, D, H, W, I) # Transpose from (O, I, D, H, W) to (O, D, H, W, I)
value = mx.transpose(value, (0, 2, 3, 4, 1)) value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion # Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W) # PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels) # MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4: if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1)) value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value sanitized[new_key] = value

View File

@@ -381,8 +381,7 @@ def precompute_freqs_cis(
if max_pos is None: if max_pos is None:
max_pos = [20, 2048, 2048] max_pos = [20, 2048, 2048]
# For double precision, compute in numpy (float64) then convert back to MLX
# MLX GPU doesn't support float64, so we use numpy for high precision computation
if double_precision: if double_precision:
return _precompute_freqs_cis_double_precision( return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid, indices_grid, dim, theta, max_pos, use_middle_indices_grid,
@@ -418,10 +417,7 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int, num_attention_heads: int,
rope_type: LTXRopeType, rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]: ) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies in double precision using numpy.
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
"""
# Convert to numpy float64 # Convert to numpy float64
indices_grid_np = np.array(indices_grid).astype(np.float64) indices_grid_np = np.array(indices_grid).astype(np.float64)

View File

@@ -9,7 +9,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx_video.utils import rms_norm from mlx_video.utils import rms_norm
from mlx_video.models.ltx.rope import apply_rotary_emb_1d from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
@dataclass @dataclass
class Gemma3Config: class Gemma3Config:
@@ -240,7 +240,7 @@ class ConnectorAttention(nn.Module):
self, self,
x: mx.array, x: mx.array,
attention_mask: Optional[mx.array] = None, attention_mask: Optional[mx.array] = None,
pe: Optional[mx.array] = None, pe: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array: ) -> mx.array:
batch_size, seq_len, _ = x.shape batch_size, seq_len, _ = x.shape
@@ -255,25 +255,16 @@ class ConnectorAttention(nn.Module):
if pe is not None: if pe is not None:
# pe: (1, seq_len, num_heads, head_dim, 2) # pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)) k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
q, k = apply_rotary_emb_1d(q, k, pe)
# Reshape back for attention computation
q = mx.reshape(q, (batch_size, seq_len, -1))
k = mx.reshape(k, (batch_size, seq_len, -1))
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3) v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype) # No mask needed for connector - after register replacement, all positions are valid
if attention_mask is not None: out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1) out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
return self.to_out[0](out) return self.to_out[0](out)
@@ -365,29 +356,53 @@ class Embeddings1DConnector(nn.Module):
if num_learnable_registers > 0: if num_learnable_registers > 0:
self.learnable_registers = mx.zeros((num_learnable_registers, dim)) self.learnable_registers = mx.zeros((num_learnable_registers, dim))
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array: def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies for connector (INTERLEAVED type).
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
"""
import math import math
dim = self.num_heads * self.head_dim dim = self.num_heads * self.head_dim # inner_dim = 3840
theta = self.positional_embedding_theta theta = self.positional_embedding_theta
n_elem = 2 max_pos = [1] # Default for connector
n_elem = 2 * len(max_pos) # = 2
# Generate frequency indices (matches generate_freq_grid_pytorch)
start = 1.0
end = theta
num_indices = dim // n_elem # 1920
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem) log_start = math.log(start) / math.log(theta) # = 0
indices = (theta ** linspace_vals) * (math.pi / 2) log_end = math.log(end) / math.log(theta) # = 1
lin_space = mx.linspace(log_start, log_end, num_indices)
indices = (theta ** lin_space) * (math.pi / 2)
# Generate positions and compute freqs (matches generate_freqs)
positions = mx.arange(seq_len).astype(mx.float32) positions = mx.arange(seq_len).astype(mx.float32)
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2) # fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
cos = mx.cos(freqs) # (seq_len, dim//2) # freqs = indices * scaled_positions (outer product)
sin = mx.sin(freqs) # Shape: (seq_len, num_indices)
freqs = scaled_positions[:, None] * indices[None, :]
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
cos_freq = mx.cos(freqs)
sin_freq = mx.sin(freqs)
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim) # repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim) # Pattern: [c0, c0, c1, c1, c2, c2, ...]
cos_full = mx.repeat(cos_freq, 2, axis=-1)
sin_full = mx.repeat(sin_freq, 2, axis=-1)
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2) # Add batch dimension: (1, seq_len, dim)
return freqs_cis.astype(dtype) cos_full = cos_full[None, :, :]
sin_full = sin_full[None, :, :]
return cos_full.astype(dtype), sin_full.astype(dtype)
def _replace_padded_with_registers( def _replace_padded_with_registers(
self, self,
@@ -502,9 +517,8 @@ def norm_and_concat_hidden_states(
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps) mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer # Compute masked min/max per layer
large_val = 1e9 x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype)) x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True) x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True) x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min range_val = x_max - x_min

View File

@@ -16,7 +16,7 @@ class PixArtAlphaTextProjection(nn.Module):
out_features = out_features or hidden_size out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias) self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise") self.act = nn.GELU(approx="tanh") # Must match PyTorch's approximate="tanh"
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias) self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array:

View File

@@ -10,13 +10,19 @@ class PixArtAlphaTextProjection(nn.Module):
hidden_size: int, hidden_size: int,
out_features: int | None = None, out_features: int | None = None,
bias: bool = True, bias: bool = True,
act_fn: str = "gelu_tanh",
): ):
super().__init__() super().__init__()
out_features = out_features or hidden_size out_features = out_features or hidden_size
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias) self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
self.act = nn.GELU(approx="precise") if act_fn == "gelu_tanh":
self.act = nn.GELU(approx="tanh")
elif act_fn == "silu":
self.act = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias) self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
def __call__(self, x: mx.array) -> mx.array: def __call__(self, x: mx.array) -> mx.array: