- Refactor video generation script

- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
Prince Canuma
2026-01-12 14:04:53 +01:00
parent d1ca36a315
commit 7114b023bd
6 changed files with 270 additions and 304 deletions

451
main.py
View File

@@ -1,6 +1,9 @@
import argparse
import time
from pathlib import Path
import mlx.core as mx
import numpy as np
from pathlib import Path
from PIL import Image
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
@@ -10,205 +13,127 @@ from mlx_video.convert import sanitize_transformer_weights
from mlx_video.generate import create_position_grid
from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
# Paths
from huggingface_hub import snapshot_download
from pathlib import Path
import os
LTX2_REPO = "Lightricks/LTX-2"
def get_ltx2_cache_dir():
# Try to get local cache (local_only), will not download files
# Distilled sigma schedules
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
ref_path = snapshot_download(
repo_id=LTX2_REPO,
local_files_only=True,
allow_patterns=["*"],
ignore_patterns=[],
# leave as default revision and cache_dir, only local
)
return ref_path
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
# If not present locally, download from hub
return snapshot_download(
repo_id=LTX2_REPO,
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
ignore_patterns=[]
)
LTX2_PATH = Path(get_ltx2_cache_dir())
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')
# Distilled sigma schedules (from PyTorch)
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0] # Refinement steps
))
def denoise_loop(
def denoise(
latents: mx.array,
positions: mx.array,
text_embeddings: mx.array,
transformer: LTXModel,
sigma_schedule: list,
stage_name: str = "Stage",
negative_embeddings: mx.array = None,
cfg_scale: float = 1.0,
sigmas: list,
) -> mx.array:
"""Run denoising loop for given sigma schedule.
Args:
latents: Noisy latent tensor
positions: Position embeddings
text_embeddings: Positive prompt embeddings
transformer: The transformer model
sigma_schedule: List of sigma values for each step
stage_name: Name for logging
negative_embeddings: Negative prompt embeddings for CFG (optional)
cfg_scale: Classifier-free guidance scale (1.0 = no guidance)
"""
use_cfg = negative_embeddings is not None and cfg_scale > 1.0
for i in range(len(sigma_schedule) - 1):
sigma = sigma_schedule[i]
sigma_next = sigma_schedule[i + 1]
print(f" {stage_name} step {i+1}/{len(sigma_schedule)-1}: sigma={sigma:.4f} -> {sigma_next:.4f}")
"""Run denoising loop."""
for i in range(len(sigmas) - 1):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
latents_flat = mx.reshape(latents, (b, c, -1))
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
timesteps = mx.full((1,), sigma)
# Positive (conditioned) prediction
video_modality = Modality(
latent=latents_flat,
timesteps=timesteps,
timesteps=mx.full((1,), sigma),
positions=positions,
context=text_embeddings,
context_mask=None,
enabled=True,
)
vx_cond, _ = transformer(video=video_modality, audio=None)
mx.eval(vx_cond)
velocity, _ = transformer(video=video_modality, audio=None)
mx.eval(velocity)
if use_cfg:
# Negative (unconditioned) prediction
video_modality_neg = Modality(
latent=latents_flat,
timesteps=timesteps,
positions=positions,
context=negative_embeddings,
context_mask=None,
enabled=True,
)
vx_uncond, _ = transformer(video=video_modality_neg, audio=None)
mx.eval(vx_uncond)
# CFG: output = uncond + cfg_scale * (cond - uncond)
vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
else:
vx = vx_cond
vx_reshaped = mx.transpose(vx, (0, 2, 1))
vx_reshaped = mx.reshape(vx_reshaped, (b, c, f, h, w))
# Debug: Print velocity stats
vx_np = np.array(vx_reshaped)
print(f" Velocity: min={vx_np.min():.4f}, max={vx_np.max():.4f}, mean={vx_np.mean():.4f}")
# Get denoised prediction: x_0 = x_t - sigma * velocity
denoised = to_denoised(latents, vx_reshaped, sigma)
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
mx.eval(denoised)
# Debug: Print denoised stats
denoised_np = np.array(denoised)
print(f" Denoised: min={denoised_np.min():.4f}, max={denoised_np.max():.4f}, mean={denoised_np.mean():.4f}")
# Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
if sigma_next > 0:
velocity = (latents - denoised) / sigma
latents = denoised + sigma_next * velocity
latents = denoised + sigma_next * (latents - denoised) / sigma
else:
latents = denoised
mx.eval(latents)
# Debug: Print latents after step
latents_np = np.array(latents)
print(f" Latents after step: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
return latents
def main():
print("="*60)
print("MLX LTX-2 Video Generation (Two-Stage)")
print("="*60)
def generate_video(
model_repo: str,
prompt: str,
height: int = 512,
width: int = 512,
num_frames: int = 33,
seed: int = 42,
fps: int = 24,
output_path: str = "output.mp4",
save_frames: bool = False,
):
"""Generate video from text prompt.
# Config - same as PyTorch reference
prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic"
negative_prompt = "" # PyTorch script doesn't use negative prompt
cfg_scale = 1.0 # No CFG in the distilled pipeline
height, width, num_frames = 512, 512, 500 # Must be divisible by 64 for two-stage
seed = 123
Args:
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
seed: Random seed for reproducibility
fps: Frames per second for output video
output_path: Path to save the output video
save_frames: Whether to save individual frames as images
"""
start_time = time.time()
# Stage 1: Half resolution
stage1_height = height // 2
stage1_width = width // 2
stage1_latent_height = stage1_height // 32
stage1_latent_width = stage1_width // 32
# Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
print(f"Generating {width}x{height} video with {num_frames} frames")
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
# Get model path
model_path = get_model_path(model_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8
# Stage 2: Full resolution
latent_height = height // 32
latent_width = width // 32
print(f"\nConfig:")
print(f" Prompt: {prompt}")
print(f" Negative prompt: '{negative_prompt}'")
print(f" CFG scale: {cfg_scale}")
print(f" Final resolution: {width}x{height}, {num_frames} frames")
print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
print(f" Seed: {seed}")
mx.random.seed(seed)
# Load text encoder
print("\nLoading text encoder...")
print("Loading text encoder...")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH))
text_encoder.load(str(LTX2_PATH))
text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path))
mx.eval(text_encoder.parameters())
# Encode positive prompt
print("Encoding text...")
text_embeddings, attention_mask = text_encoder(prompt)
text_embeddings, _ = text_encoder(prompt)
mx.eval(text_embeddings)
print(f" Positive embeddings: {text_embeddings.shape}")
# Encode negative prompt for CFG
negative_embeddings, _ = text_encoder(negative_prompt)
mx.eval(negative_embeddings)
print(f" Negative embeddings: {negative_embeddings.shape}")
# Free text encoder memory
del text_encoder
mx.clear_cache()
# Load transformer
print("\nLoading transformer...")
raw_weights = mx.load(MODEL_PATH)
print("Loading transformer...")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig(
@@ -231,164 +156,166 @@ def main():
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
print(" Transformer loaded!")
# ========================================
# Stage 1: Generate at half resolution
# ========================================
print("\n" + "="*60)
print("Stage 1: Generating at half resolution")
print("="*60)
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width))
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents)
print(f" Initial latents: {latents.shape}")
positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
latents = denoise_loop(
latents=latents,
positions=positions,
text_embeddings=text_embeddings,
transformer=transformer,
sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
stage_name="Stage 1",
negative_embeddings=negative_embeddings,
cfg_scale=cfg_scale,
)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
print(f"\nStage 1 latents: {latents.shape}")
latents_np = np.array(latents)
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# ========================================
# Upsample latents 2x
# ========================================
print("\n" + "="*60)
print("Upsampling latents 2x")
print("="*60)
# Load upsampler
print(" Loading spatial upsampler...")
upsampler = load_upsampler(UPSAMPLER_PATH)
# Upsample latents
print("Upsampling latents 2x...")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
# Load latent statistics for normalization
vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True)
# EXPERIMENT: Disable VAE decode noise for sharper output
# vae_decoder.decode_noise_scale = 0.0
# print(f" VAE decode_noise_scale set to {vae_decoder.decode_noise_scale}")
latent_mean = vae_decoder.latents_mean
latent_std = vae_decoder.latents_std
vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'),
timestep_conditioning=True
)
# Upsample
print(" Upsampling...")
latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(latents)
print(f" Upsampled latents: {latents.shape}")
# Free upsampler memory
del upsampler
mx.clear_cache()
# ========================================
# Stage 2: Refine at full resolution
# ========================================
print("\n" + "="*60)
print("Stage 2: Refining at full resolution")
print("="*60)
# Debug: Print upsampled latent stats before adding noise
latents_np = np.array(latents)
print(f" Upsampled latents (before noise): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# Create new position grid for full resolution
positions = create_position_grid(1, latent_frames, latent_height, latent_width)
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
# Add noise at initial sigma for stage 2
# PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale)
# NOT addition: noisy = clean + scale * noise
noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
# Add noise for refinement
noise_scale = STAGE_2_SIGMAS[0]
noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
# Debug: Print latents after adding noise
latents_np = np.array(latents)
print(f" After adding noise (sigma={noise_scale}): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
latents = denoise_loop(
latents=latents,
positions=positions,
text_embeddings=text_embeddings,
transformer=transformer,
sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
stage_name="Stage 2",
)
print(f"\nFinal latents: {latents.shape}")
latents_np = np.array(latents)
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
# Save latents for PyTorch comparison
np.save("mlx_final_latents.npy", latents_np)
print(" Saved latents to mlx_final_latents.npy")
# Free transformer memory
del transformer
mx.clear_cache()
# ========================================
# Decode to video
# ========================================
print("\n" + "="*60)
print("Decoding with VAE")
print("="*60)
# Decode latents to video
video = vae_decoder(latents, debug=True)
print("Decoding video...")
video = vae_decoder(latents)
mx.eval(video)
print(f" Video shape: {video.shape}")
mx.clear_cache()
# Convert to frames
# Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
# Debug: check raw RGB values before conversion
video_raw = np.array(video)
print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = (video + 1.0) / 2.0 # [-1, 1] -> [0, 1]
video = mx.clip(video, 0.0, 1.0)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}")
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Save first frame
output_path = Path("mlx_output_frame0_2.png")
Image.fromarray(video_np[0]).save(output_path)
print(f"\nSaved first frame to {output_path}")
# Save video
try:
import imageio
video_path = "mlx_output_video_2.mp4"
imageio.mimwrite(video_path, video_np, fps=24, codec='libx264')
print(f"Saved video to {video_path}")
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
print(f"Saved video to {output_path}")
except Exception as e:
print(f"Could not save video: {e}")
print("\nDone!")
if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"Saved {len(video_np)} frames to {frames_dir}")
elapsed = time.time() - start_time
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
return video_np
def main():
parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python main.py --prompt "A cat walking on grass"
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
"""
)
parser.add_argument(
"--prompt", "-p",
type=str,
required=True,
help="Text description of the video to generate"
)
parser.add_argument(
"--height", "-H",
type=int,
default=512,
help="Output video height (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--width", "-W",
type=int,
default=512,
help="Output video width (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--num-frames", "-n",
type=int,
default=33,
help="Number of frames (default: 33, must be 1 + 8*k)"
)
parser.add_argument(
"--seed", "-s",
type=int,
default=42,
help="Random seed for reproducibility (default: 42)"
)
parser.add_argument(
"--fps",
type=int,
default=24,
help="Frames per second for output video (default: 24)"
)
parser.add_argument(
"--output", "-o",
type=str,
default="output.mp4",
help="Output video path (default: output.mp4)"
)
parser.add_argument(
"--save-frames",
action="store_true",
help="Save individual frames as images"
)
parser.add_argument(
"--model-repo",
type=str,
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
args = parser.parse_args()
generate_video(
model_repo=args.model_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
)
if __name__ == "__main__":
import time
start_time = time.time()
main()
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")