- Refactor video generation script
- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions. - Updated VAE weight sanitization for compatibility and improved activation function handling in text projection. - Added support for saving individual frames and refined output video generation process.
This commit is contained in:
451
main.py
451
main.py
@@ -1,6 +1,9 @@
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
@@ -10,205 +13,127 @@ from mlx_video.convert import sanitize_transformer_weights
|
||||
from mlx_video.generate import create_position_grid
|
||||
from mlx_video.utils import to_denoised
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
|
||||
# Paths
|
||||
from huggingface_hub import snapshot_download
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
LTX2_REPO = "Lightricks/LTX-2"
|
||||
|
||||
def get_ltx2_cache_dir():
|
||||
# Try to get local cache (local_only), will not download files
|
||||
# Distilled sigma schedules
|
||||
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
ref_path = snapshot_download(
|
||||
repo_id=LTX2_REPO,
|
||||
local_files_only=True,
|
||||
allow_patterns=["*"],
|
||||
ignore_patterns=[],
|
||||
# leave as default revision and cache_dir, only local
|
||||
)
|
||||
return ref_path
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
# If not present locally, download from hub
|
||||
return snapshot_download(
|
||||
repo_id=LTX2_REPO,
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
ignore_patterns=[]
|
||||
)
|
||||
|
||||
LTX2_PATH = Path(get_ltx2_cache_dir())
|
||||
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
|
||||
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
|
||||
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
|
||||
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')
|
||||
|
||||
# Distilled sigma schedules (from PyTorch)
|
||||
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0] # Refinement steps
|
||||
))
|
||||
|
||||
|
||||
def denoise_loop(
|
||||
def denoise(
|
||||
latents: mx.array,
|
||||
positions: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigma_schedule: list,
|
||||
stage_name: str = "Stage",
|
||||
negative_embeddings: mx.array = None,
|
||||
cfg_scale: float = 1.0,
|
||||
sigmas: list,
|
||||
) -> mx.array:
|
||||
"""Run denoising loop for given sigma schedule.
|
||||
|
||||
Args:
|
||||
latents: Noisy latent tensor
|
||||
positions: Position embeddings
|
||||
text_embeddings: Positive prompt embeddings
|
||||
transformer: The transformer model
|
||||
sigma_schedule: List of sigma values for each step
|
||||
stage_name: Name for logging
|
||||
negative_embeddings: Negative prompt embeddings for CFG (optional)
|
||||
cfg_scale: Classifier-free guidance scale (1.0 = no guidance)
|
||||
"""
|
||||
use_cfg = negative_embeddings is not None and cfg_scale > 1.0
|
||||
|
||||
for i in range(len(sigma_schedule) - 1):
|
||||
sigma = sigma_schedule[i]
|
||||
sigma_next = sigma_schedule[i + 1]
|
||||
|
||||
print(f" {stage_name} step {i+1}/{len(sigma_schedule)-1}: sigma={sigma:.4f} -> {sigma_next:.4f}")
|
||||
"""Run denoising loop."""
|
||||
for i in range(len(sigmas) - 1):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
|
||||
b, c, f, h, w = latents.shape
|
||||
latents_flat = mx.reshape(latents, (b, c, -1))
|
||||
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
|
||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||
|
||||
timesteps = mx.full((1,), sigma)
|
||||
|
||||
# Positive (conditioned) prediction
|
||||
video_modality = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
timesteps=mx.full((1,), sigma),
|
||||
positions=positions,
|
||||
context=text_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
vx_cond, _ = transformer(video=video_modality, audio=None)
|
||||
mx.eval(vx_cond)
|
||||
velocity, _ = transformer(video=video_modality, audio=None)
|
||||
mx.eval(velocity)
|
||||
|
||||
if use_cfg:
|
||||
# Negative (unconditioned) prediction
|
||||
video_modality_neg = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
positions=positions,
|
||||
context=negative_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
vx_uncond, _ = transformer(video=video_modality_neg, audio=None)
|
||||
mx.eval(vx_uncond)
|
||||
|
||||
# CFG: output = uncond + cfg_scale * (cond - uncond)
|
||||
vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
|
||||
else:
|
||||
vx = vx_cond
|
||||
|
||||
vx_reshaped = mx.transpose(vx, (0, 2, 1))
|
||||
vx_reshaped = mx.reshape(vx_reshaped, (b, c, f, h, w))
|
||||
|
||||
# Debug: Print velocity stats
|
||||
vx_np = np.array(vx_reshaped)
|
||||
print(f" Velocity: min={vx_np.min():.4f}, max={vx_np.max():.4f}, mean={vx_np.mean():.4f}")
|
||||
|
||||
# Get denoised prediction: x_0 = x_t - sigma * velocity
|
||||
denoised = to_denoised(latents, vx_reshaped, sigma)
|
||||
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||
denoised = to_denoised(latents, velocity, sigma)
|
||||
mx.eval(denoised)
|
||||
|
||||
# Debug: Print denoised stats
|
||||
denoised_np = np.array(denoised)
|
||||
print(f" Denoised: min={denoised_np.min():.4f}, max={denoised_np.max():.4f}, mean={denoised_np.mean():.4f}")
|
||||
|
||||
# Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
|
||||
if sigma_next > 0:
|
||||
velocity = (latents - denoised) / sigma
|
||||
latents = denoised + sigma_next * velocity
|
||||
latents = denoised + sigma_next * (latents - denoised) / sigma
|
||||
else:
|
||||
latents = denoised
|
||||
mx.eval(latents)
|
||||
|
||||
# Debug: Print latents after step
|
||||
latents_np = np.array(latents)
|
||||
print(f" Latents after step: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
def main():
|
||||
print("="*60)
|
||||
print("MLX LTX-2 Video Generation (Two-Stage)")
|
||||
print("="*60)
|
||||
def generate_video(
|
||||
model_repo: str,
|
||||
prompt: str,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_frames: int = 33,
|
||||
seed: int = 42,
|
||||
fps: int = 24,
|
||||
output_path: str = "output.mp4",
|
||||
save_frames: bool = False,
|
||||
):
|
||||
"""Generate video from text prompt.
|
||||
|
||||
# Config - same as PyTorch reference
|
||||
prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic"
|
||||
negative_prompt = "" # PyTorch script doesn't use negative prompt
|
||||
cfg_scale = 1.0 # No CFG in the distilled pipeline
|
||||
height, width, num_frames = 512, 512, 500 # Must be divisible by 64 for two-stage
|
||||
seed = 123
|
||||
Args:
|
||||
prompt: Text description of the video to generate
|
||||
height: Output video height (must be divisible by 64)
|
||||
width: Output video width (must be divisible by 64)
|
||||
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
|
||||
seed: Random seed for reproducibility
|
||||
fps: Frames per second for output video
|
||||
output_path: Path to save the output video
|
||||
save_frames: Whether to save individual frames as images
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Stage 1: Half resolution
|
||||
stage1_height = height // 2
|
||||
stage1_width = width // 2
|
||||
stage1_latent_height = stage1_height // 32
|
||||
stage1_latent_width = stage1_width // 32
|
||||
# Validate dimensions
|
||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||
|
||||
print(f"Generating {width}x{height} video with {num_frames} frames")
|
||||
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
|
||||
|
||||
# Get model path
|
||||
model_path = get_model_path(model_repo)
|
||||
|
||||
# Calculate latent dimensions
|
||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||
stage2_h, stage2_w = height // 32, width // 32
|
||||
latent_frames = 1 + (num_frames - 1) // 8
|
||||
|
||||
# Stage 2: Full resolution
|
||||
latent_height = height // 32
|
||||
latent_width = width // 32
|
||||
|
||||
print(f"\nConfig:")
|
||||
print(f" Prompt: {prompt}")
|
||||
print(f" Negative prompt: '{negative_prompt}'")
|
||||
print(f" CFG scale: {cfg_scale}")
|
||||
print(f" Final resolution: {width}x{height}, {num_frames} frames")
|
||||
print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
|
||||
print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
|
||||
print(f" Seed: {seed}")
|
||||
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Load text encoder
|
||||
print("\nLoading text encoder...")
|
||||
print("Loading text encoder...")
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
|
||||
text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH))
|
||||
text_encoder.load(str(LTX2_PATH))
|
||||
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
||||
text_encoder.load(str(model_path))
|
||||
mx.eval(text_encoder.parameters())
|
||||
|
||||
# Encode positive prompt
|
||||
print("Encoding text...")
|
||||
text_embeddings, attention_mask = text_encoder(prompt)
|
||||
text_embeddings, _ = text_encoder(prompt)
|
||||
mx.eval(text_embeddings)
|
||||
print(f" Positive embeddings: {text_embeddings.shape}")
|
||||
|
||||
# Encode negative prompt for CFG
|
||||
negative_embeddings, _ = text_encoder(negative_prompt)
|
||||
mx.eval(negative_embeddings)
|
||||
print(f" Negative embeddings: {negative_embeddings.shape}")
|
||||
|
||||
# Free text encoder memory
|
||||
del text_encoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Load transformer
|
||||
print("\nLoading transformer...")
|
||||
raw_weights = mx.load(MODEL_PATH)
|
||||
print("Loading transformer...")
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
|
||||
config = LTXModelConfig(
|
||||
@@ -231,164 +156,166 @@ def main():
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
mx.eval(transformer.parameters())
|
||||
print(" Transformer loaded!")
|
||||
|
||||
# ========================================
|
||||
# Stage 1: Generate at half resolution
|
||||
# ========================================
|
||||
print("\n" + "="*60)
|
||||
print("Stage 1: Generating at half resolution")
|
||||
print("="*60)
|
||||
|
||||
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
|
||||
mx.random.seed(seed)
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width))
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||
mx.eval(latents)
|
||||
print(f" Initial latents: {latents.shape}")
|
||||
|
||||
positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width)
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
latents = denoise_loop(
|
||||
latents=latents,
|
||||
positions=positions,
|
||||
text_embeddings=text_embeddings,
|
||||
transformer=transformer,
|
||||
sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
|
||||
stage_name="Stage 1",
|
||||
negative_embeddings=negative_embeddings,
|
||||
cfg_scale=cfg_scale,
|
||||
)
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
||||
|
||||
print(f"\nStage 1 latents: {latents.shape}")
|
||||
latents_np = np.array(latents)
|
||||
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
|
||||
# ========================================
|
||||
# Upsample latents 2x
|
||||
# ========================================
|
||||
print("\n" + "="*60)
|
||||
print("Upsampling latents 2x")
|
||||
print("="*60)
|
||||
|
||||
# Load upsampler
|
||||
print(" Loading spatial upsampler...")
|
||||
upsampler = load_upsampler(UPSAMPLER_PATH)
|
||||
# Upsample latents
|
||||
print("Upsampling latents 2x...")
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
# Load latent statistics for normalization
|
||||
vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True)
|
||||
# EXPERIMENT: Disable VAE decode noise for sharper output
|
||||
# vae_decoder.decode_noise_scale = 0.0
|
||||
# print(f" VAE decode_noise_scale set to {vae_decoder.decode_noise_scale}")
|
||||
latent_mean = vae_decoder.latents_mean
|
||||
latent_std = vae_decoder.latents_std
|
||||
vae_decoder = load_vae_decoder(
|
||||
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||
timestep_conditioning=True
|
||||
)
|
||||
|
||||
# Upsample
|
||||
print(" Upsampling...")
|
||||
latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||
mx.eval(latents)
|
||||
print(f" Upsampled latents: {latents.shape}")
|
||||
|
||||
# Free upsampler memory
|
||||
del upsampler
|
||||
mx.clear_cache()
|
||||
|
||||
# ========================================
|
||||
# Stage 2: Refine at full resolution
|
||||
# ========================================
|
||||
print("\n" + "="*60)
|
||||
print("Stage 2: Refining at full resolution")
|
||||
print("="*60)
|
||||
|
||||
# Debug: Print upsampled latent stats before adding noise
|
||||
latents_np = np.array(latents)
|
||||
print(f" Upsampled latents (before noise): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
|
||||
# Create new position grid for full resolution
|
||||
positions = create_position_grid(1, latent_frames, latent_height, latent_width)
|
||||
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
# Add noise at initial sigma for stage 2
|
||||
# PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale)
|
||||
# NOT addition: noisy = clean + scale * noise
|
||||
noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
|
||||
# Add noise for refinement
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
noise = mx.random.normal(latents.shape)
|
||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||
mx.eval(latents)
|
||||
|
||||
# Debug: Print latents after adding noise
|
||||
latents_np = np.array(latents)
|
||||
print(f" After adding noise (sigma={noise_scale}): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
||||
|
||||
latents = denoise_loop(
|
||||
latents=latents,
|
||||
positions=positions,
|
||||
text_embeddings=text_embeddings,
|
||||
transformer=transformer,
|
||||
sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
|
||||
stage_name="Stage 2",
|
||||
)
|
||||
|
||||
print(f"\nFinal latents: {latents.shape}")
|
||||
latents_np = np.array(latents)
|
||||
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
|
||||
# Save latents for PyTorch comparison
|
||||
np.save("mlx_final_latents.npy", latents_np)
|
||||
print(" Saved latents to mlx_final_latents.npy")
|
||||
|
||||
# Free transformer memory
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
# ========================================
|
||||
# Decode to video
|
||||
# ========================================
|
||||
print("\n" + "="*60)
|
||||
print("Decoding with VAE")
|
||||
print("="*60)
|
||||
|
||||
# Decode latents to video
|
||||
video = vae_decoder(latents, debug=True)
|
||||
print("Decoding video...")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
print(f" Video shape: {video.shape}")
|
||||
mx.clear_cache()
|
||||
|
||||
# Convert to frames
|
||||
# Convert to uint8 frames
|
||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||
|
||||
# Debug: check raw RGB values before conversion
|
||||
video_raw = np.array(video)
|
||||
print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
|
||||
print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")
|
||||
|
||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||
video = (video + 1.0) / 2.0 # [-1, 1] -> [0, 1]
|
||||
video = mx.clip(video, 0.0, 1.0)
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
|
||||
print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}")
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save first frame
|
||||
output_path = Path("mlx_output_frame0_2.png")
|
||||
Image.fromarray(video_np[0]).save(output_path)
|
||||
print(f"\nSaved first frame to {output_path}")
|
||||
|
||||
# Save video
|
||||
try:
|
||||
import imageio
|
||||
video_path = "mlx_output_video_2.mp4"
|
||||
imageio.mimwrite(video_path, video_np, fps=24, codec='libx264')
|
||||
print(f"Saved video to {video_path}")
|
||||
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
|
||||
print(f"Saved video to {output_path}")
|
||||
except Exception as e:
|
||||
print(f"Could not save video: {e}")
|
||||
|
||||
print("\nDone!")
|
||||
if save_frames:
|
||||
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
||||
frames_dir.mkdir(exist_ok=True)
|
||||
for i, frame in enumerate(video_np):
|
||||
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
|
||||
print(f"Saved {len(video_np)} frames to {frames_dir}")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
|
||||
|
||||
return video_np
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate videos with MLX LTX-2",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python main.py --prompt "A cat walking on grass"
|
||||
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
|
||||
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt", "-p",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Text description of the video to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height", "-H",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output video height (default: 512, must be divisible by 32)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width", "-W",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output video width (default: 512, must be divisible by 32)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-frames", "-n",
|
||||
type=int,
|
||||
default=33,
|
||||
help="Number of frames (default: 33, must be 1 + 8*k)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", "-s",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Random seed for reproducibility (default: 42)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
default=24,
|
||||
help="Frames per second for output video (default: 24)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o",
|
||||
type=str,
|
||||
default="output.mp4",
|
||||
help="Output video path (default: output.mp4)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
action="store_true",
|
||||
help="Save individual frames as images"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-repo",
|
||||
type=str,
|
||||
default="Lightricks/LTX-2",
|
||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video(
|
||||
model_repo=args.model_repo,
|
||||
prompt=args.prompt,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
num_frames=args.num_frames,
|
||||
seed=args.seed,
|
||||
fps=args.fps,
|
||||
output_path=args.output,
|
||||
save_frames=args.save_frames,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
start_time = time.time()
|
||||
main()
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
|
||||
@@ -161,7 +161,7 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for VAE
|
||||
Dictionary with MLX-compatible naming for VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
@@ -172,17 +172,40 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Only process VAE decoder weights (skip audio_vae, etc.)
|
||||
if not key.startswith("vae."):
|
||||
continue
|
||||
|
||||
# Handle per-channel statistics key mapping
|
||||
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
|
||||
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
|
||||
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
|
||||
if "vae.per_channel_statistics" in key:
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
|
||||
continue
|
||||
elif key.startswith("vae.decoder."):
|
||||
# Strip the vae.decoder. prefix for decoder weights
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
else:
|
||||
# Skip other vae.* keys that are not decoder weights
|
||||
continue
|
||||
|
||||
# Handle Conv3d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||
# MLX: (out_channels, D, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
@@ -381,8 +381,7 @@ def precompute_freqs_cis(
|
||||
if max_pos is None:
|
||||
max_pos = [20, 2048, 2048]
|
||||
|
||||
# For double precision, compute in numpy (float64) then convert back to MLX
|
||||
# MLX GPU doesn't support float64, so we use numpy for high precision computation
|
||||
|
||||
if double_precision:
|
||||
return _precompute_freqs_cis_double_precision(
|
||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||
@@ -418,10 +417,7 @@ def _precompute_freqs_cis_double_precision(
|
||||
num_attention_heads: int,
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies in double precision using numpy.
|
||||
|
||||
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
|
||||
"""
|
||||
# Convert to numpy float64
|
||||
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.utils import rms_norm
|
||||
from mlx_video.models.ltx.rope import apply_rotary_emb_1d
|
||||
from mlx_video.models.ltx.rope import apply_interleaved_rotary_emb
|
||||
|
||||
@dataclass
|
||||
class Gemma3Config:
|
||||
@@ -240,7 +240,7 @@ class ConnectorAttention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
pe: Optional[mx.array] = None,
|
||||
pe: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> mx.array:
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
@@ -255,25 +255,16 @@ class ConnectorAttention(nn.Module):
|
||||
|
||||
|
||||
if pe is not None:
|
||||
# pe: (1, seq_len, num_heads, head_dim, 2)
|
||||
# q, k: (B, seq, inner_dim) - need to reshape for RoPE then reshape back
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
|
||||
q, k = apply_rotary_emb_1d(q, k, pe)
|
||||
# Reshape back for attention computation
|
||||
q = mx.reshape(q, (batch_size, seq_len, -1))
|
||||
k = mx.reshape(k, (batch_size, seq_len, -1))
|
||||
|
||||
# pe: tuple of (cos, sin) each with shape (1, seq_len, inner_dim)
|
||||
q = apply_interleaved_rotary_emb(q, pe[0], pe[1])
|
||||
k = apply_interleaved_rotary_emb(k, pe[0], pe[1])
|
||||
|
||||
q = mx.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
k = mx.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
v = mx.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim)).transpose(0, 2, 1, 3)
|
||||
|
||||
mask = mx.full((batch_size, seq_len, seq_len), -1e9, dtype=q.dtype)
|
||||
if attention_mask is not None:
|
||||
mask = mask + (1.0 - attention_mask[:, None, None, :].astype(q.dtype)) * -1e9
|
||||
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attention_mask)
|
||||
# No mask needed for connector - after register replacement, all positions are valid
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
|
||||
|
||||
return self.to_out[0](out)
|
||||
@@ -365,29 +356,53 @@ class Embeddings1DConnector(nn.Module):
|
||||
if num_learnable_registers > 0:
|
||||
self.learnable_registers = mx.zeros((num_learnable_registers, dim))
|
||||
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> mx.array:
|
||||
def _precompute_freqs_cis(self, seq_len: int, dtype: mx.Dtype) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies for connector (INTERLEAVED type).
|
||||
|
||||
Matches PyTorch: generate_freq_grid_pytorch + generate_freqs + interleaved_freqs_cis
|
||||
Returns tuple of (cos, sin) each with shape (1, seq_len, inner_dim).
|
||||
"""
|
||||
import math
|
||||
|
||||
dim = self.num_heads * self.head_dim
|
||||
dim = self.num_heads * self.head_dim # inner_dim = 3840
|
||||
theta = self.positional_embedding_theta
|
||||
n_elem = 2
|
||||
max_pos = [1] # Default for connector
|
||||
n_elem = 2 * len(max_pos) # = 2
|
||||
|
||||
|
||||
linspace_vals = mx.linspace(0.0, 1.0, dim // n_elem)
|
||||
indices = (theta ** linspace_vals) * (math.pi / 2)
|
||||
# Generate frequency indices (matches generate_freq_grid_pytorch)
|
||||
start = 1.0
|
||||
end = theta
|
||||
num_indices = dim // n_elem # 1920
|
||||
|
||||
log_start = math.log(start) / math.log(theta) # = 0
|
||||
log_end = math.log(end) / math.log(theta) # = 1
|
||||
lin_space = mx.linspace(log_start, log_end, num_indices)
|
||||
indices = (theta ** lin_space) * (math.pi / 2)
|
||||
|
||||
# Generate positions and compute freqs (matches generate_freqs)
|
||||
positions = mx.arange(seq_len).astype(mx.float32)
|
||||
freqs = positions[:, None] * indices[None, :] # (seq_len, dim//2)
|
||||
# fractional_positions = positions / max_pos[0] = positions (since max_pos[0]=1)
|
||||
# scaled_positions = fractional_positions * 2 - 1 = positions * 2 - 1
|
||||
scaled_positions = positions * 2 - 1 # Shape: (seq_len,)
|
||||
|
||||
cos = mx.cos(freqs) # (seq_len, dim//2)
|
||||
sin = mx.sin(freqs)
|
||||
# freqs = indices * scaled_positions (outer product)
|
||||
# Shape: (seq_len, num_indices)
|
||||
freqs = scaled_positions[:, None] * indices[None, :]
|
||||
|
||||
# Compute cos/sin with interleaved pattern (matches interleaved_freqs_cis)
|
||||
cos_freq = mx.cos(freqs)
|
||||
sin_freq = mx.sin(freqs)
|
||||
|
||||
cos_full = mx.repeat(cos, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||
sin_full = mx.repeat(sin, 2, axis=-1).reshape(1, seq_len, self.num_heads, self.head_dim)
|
||||
# repeat_interleave: (seq_len, num_indices) -> (seq_len, dim)
|
||||
# Pattern: [c0, c0, c1, c1, c2, c2, ...]
|
||||
cos_full = mx.repeat(cos_freq, 2, axis=-1)
|
||||
sin_full = mx.repeat(sin_freq, 2, axis=-1)
|
||||
|
||||
freqs_cis = mx.stack([cos_full, sin_full], axis=-1) # (1, seq_len, num_heads, head_dim, 2)
|
||||
return freqs_cis.astype(dtype)
|
||||
# Add batch dimension: (1, seq_len, dim)
|
||||
cos_full = cos_full[None, :, :]
|
||||
sin_full = sin_full[None, :, :]
|
||||
|
||||
return cos_full.astype(dtype), sin_full.astype(dtype)
|
||||
|
||||
def _replace_padded_with_registers(
|
||||
self,
|
||||
@@ -502,9 +517,8 @@ def norm_and_concat_hidden_states(
|
||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||
|
||||
# Compute masked min/max per layer
|
||||
large_val = 1e9
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, large_val, dtype=stacked.dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, -large_val, dtype=stacked.dtype))
|
||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
|
||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
|
||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||
range_val = x_max - x_min
|
||||
|
||||
@@ -16,7 +16,7 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
self.act = nn.GELU(approx="tanh") # Must match PyTorch's approximate="tanh"
|
||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
|
||||
@@ -10,13 +10,19 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
hidden_size: int,
|
||||
out_features: int | None = None,
|
||||
bias: bool = True,
|
||||
act_fn: str = "gelu_tanh",
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or hidden_size
|
||||
self.linear1 = nn.Linear(in_features, hidden_size, bias=bias)
|
||||
self.act = nn.GELU(approx="precise")
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act = nn.GELU(approx="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation function: {act_fn}")
|
||||
self.linear2 = nn.Linear(hidden_size, out_features, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user