import argparse
import time
from pathlib import Path

import mlx.core as mx
import numpy as np
from PIL import Image

from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.generate import create_position_grid
from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from huggingface_hub import snapshot_download

# Distilled sigma schedules. Each list is a descending noise schedule and
# `denoise` takes one step per consecutive pair, i.e. len(schedule) - 1 steps.
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]

# Stage 1 runs at half resolution and latents are 1/32 of pixel size
# (see `height // 2 // 32` in generate_video), so output sizes must be
# multiples of 2 * 32.
_SPATIAL_MULTIPLE = 64
# Frame counts must be 1 + 8*k; latent frames are computed as
# 1 + (num_frames - 1) // 8 (presumably the VAE's temporal compression).
_FRAME_STRIDE = 8


def _validate_dimensions(height: int, width: int, num_frames: int) -> None:
    """Raise ValueError if the requested output size cannot be generated.

    Uses explicit exceptions rather than ``assert`` so the checks still run
    under ``python -O``.
    """
    if height <= 0 or height % _SPATIAL_MULTIPLE != 0:
        raise ValueError(
            f"Height must be a positive multiple of {_SPATIAL_MULTIPLE}, got {height}"
        )
    if width <= 0 or width % _SPATIAL_MULTIPLE != 0:
        raise ValueError(
            f"Width must be a positive multiple of {_SPATIAL_MULTIPLE}, got {width}"
        )
    if num_frames < 1 or (num_frames - 1) % _FRAME_STRIDE != 0:
        raise ValueError(
            f"Number of frames must be 1 + {_FRAME_STRIDE}*k "
            f"(e.g. 33, 65, 97), got {num_frames}"
        )


def get_model_path(model_repo: str) -> Path:
    """Return the local snapshot path of ``model_repo``, downloading if needed.

    First asks the Hugging Face cache only. If that lookup fails for any
    reason, downloads the repository's ``*.safetensors`` and ``*.json`` files.
    """
    try:
        return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
    except Exception:
        # Deliberately broad: any failure of the cache-only lookup falls back
        # to a network download; errors from the download itself propagate.
        print("Downloading LTX-2 model weights...")
        return Path(snapshot_download(
            repo_id=model_repo,
            local_files_only=False,
            resume_download=True,
            allow_patterns=["*.safetensors", "*.json"],
        ))


def denoise(
    latents: mx.array,
    positions: mx.array,
    text_embeddings: mx.array,
    transformer: LTXModel,
    sigmas: list,
) -> mx.array:
    """Run the denoising loop over a descending sigma schedule.

    Args:
        latents: Noisy latents shaped (batch, channels, frames, height, width).
        positions: Position grid passed to the transformer unchanged.
        text_embeddings: Prompt embeddings used as the transformer context.
        transformer: Video transformer that predicts a velocity per token.
        sigmas: Noise levels; one step is taken for each consecutive pair.

    Returns:
        The latents after the last step, with the same shape as the input.
    """
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        b, c, f, h, w = latents.shape
        # Flatten to a token sequence: (b, c, f, h, w) -> (b, f*h*w, c).
        latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))

        video_modality = Modality(
            latent=latents_flat,
            timesteps=mx.full((1,), sigma),
            positions=positions,
            context=text_embeddings,
            context_mask=None,
            enabled=True,
        )

        velocity, _ = transformer(video=video_modality, audio=None)
        mx.eval(velocity)
        # Restore the 5-D latent layout.
        velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))

        denoised = to_denoised(latents, velocity, sigma)
        mx.eval(denoised)

        if sigma_next > 0:
            # Keep the denoised estimate and rescale the residual
            # (latents - denoised) from noise level sigma to sigma_next.
            latents = denoised + sigma_next * (latents - denoised) / sigma
        else:
            latents = denoised
        mx.eval(latents)

    return latents


def generate_video(
    model_repo: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_frames: int = 33,
    seed: int = 42,
    fps: int = 24,
    output_path: str = "output.mp4",
    save_frames: bool = False,
) -> np.ndarray:
    """Generate a video from a text prompt with a two-stage distilled pipeline.

    Stage 1 denoises at half resolution. The latents are then upsampled 2x
    and refined at full resolution in stage 2 before VAE decoding.

    Args:
        model_repo: Hugging Face repository holding the LTX-2 weights
        prompt: Text description of the video to generate
        height: Output video height (must be a positive multiple of 64)
        width: Output video width (must be a positive multiple of 64)
        num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
        seed: Random seed for reproducibility
        fps: Frames per second for output video
        output_path: Path to save the output video
        save_frames: Whether to save individual frames as images

    Returns:
        The decoded frames as a uint8 array laid out (frames, height, width,
        channels).

    Raises:
        ValueError: If height, width or num_frames violate the constraints above.
    """
    start_time = time.time()

    _validate_dimensions(height, width, num_frames)

    print(f"Generating {width}x{height} video with {num_frames} frames")
    print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")

    # Get model path
    model_path = get_model_path(model_repo)

    # Latent dimensions: stage 1 is half resolution, stage 2 full resolution.
    stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
    stage2_h, stage2_w = height // 32, width // 32
    latent_frames = 1 + (num_frames - 1) // _FRAME_STRIDE

    mx.random.seed(seed)

    # Load text encoder, embed the prompt, then free it before the transformer.
    print("Loading text encoder...")
    from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
    text_encoder = LTX2TextEncoder(model_path=str(model_path))
    text_encoder.load(str(model_path))
    mx.eval(text_encoder.parameters())

    text_embeddings, _ = text_encoder(prompt)
    mx.eval(text_embeddings)

    del text_encoder
    mx.clear_cache()

    # Load transformer
    print("Loading transformer...")
    raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
    sanitized = sanitize_transformer_weights(raw_weights)

    config = LTXModelConfig(
        model_type=LTXModelType.VideoOnly,
        num_attention_heads=32,
        attention_head_dim=128,
        in_channels=128,
        out_channels=128,
        num_layers=48,
        cross_attention_dim=4096,
        caption_channels=3840,
        rope_type=LTXRopeType.SPLIT,
        double_precision_rope=True,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        use_middle_indices_grid=True,
        timestep_scale_multiplier=1000,
    )

    transformer = LTXModel(config)
    transformer.load_weights(list(sanitized.items()), strict=False)
    mx.eval(transformer.parameters())

    # Stage 1: Generate at half resolution
    print(f"Stage 1: Generating at {width//2}x{height//2} ({len(STAGE_1_SIGMAS) - 1} steps)...")
    mx.random.seed(seed)
    latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
    mx.eval(latents)

    positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
    mx.eval(positions)

    latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)

    # Upsample latents
    print("Upsampling latents 2x...")
    upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
    mx.eval(upsampler.parameters())

    # The VAE decoder is loaded here because its latent statistics are needed
    # by the upsampler; it is reused for the final decode.
    vae_decoder = load_vae_decoder(
        str(model_path / 'ltx-2-19b-distilled.safetensors'),
        timestep_conditioning=True
    )

    latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
    mx.eval(latents)

    del upsampler
    mx.clear_cache()

    # Stage 2: Refine at full resolution
    print(f"Stage 2: Refining at {width}x{height} ({len(STAGE_2_SIGMAS) - 1} steps)...")
    positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
    mx.eval(positions)

    # Re-noise the upsampled latents to the first stage-2 sigma so the
    # refinement schedule starts from a matching noise level.
    noise_scale = STAGE_2_SIGMAS[0]
    noise = mx.random.normal(latents.shape)
    latents = noise * noise_scale + latents * (1 - noise_scale)
    mx.eval(latents)

    latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)

    del transformer
    mx.clear_cache()

    # Decode to video
    print("Decoding video...")
    video = vae_decoder(latents)
    mx.eval(video)
    mx.clear_cache()

    # Convert to uint8 frames: decoder output in [-1, 1] is mapped to [0, 255].
    video = mx.squeeze(video, axis=0)  # (C, F, H, W)
    video = mx.transpose(video, (1, 2, 3, 0))  # (F, H, W, C)
    video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
    video = (video * 255).astype(mx.uint8)
    video_np = np.array(video)

    # Save outputs
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Video writing is best-effort: a missing imageio/ffmpeg backend is
    # reported but does not discard the generated frames.
    try:
        import imageio
        imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
        print(f"Saved video to {output_path}")
    except Exception as e:
        print(f"Could not save video: {e}")

    if save_frames:
        frames_dir = output_path.parent / f"{output_path.stem}_frames"
        frames_dir.mkdir(exist_ok=True)
        for i, frame in enumerate(video_np):
            Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
        print(f"Saved {len(video_np)} frames to {frames_dir}")

    elapsed = time.time() - start_time
    print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")

    return video_np


def main():
    """Parse command-line arguments and run `generate_video`."""
    parser = argparse.ArgumentParser(
        description="Generate videos with MLX LTX-2",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python main.py --prompt "A cat walking on grass"
  python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
  python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
"""
    )

    parser.add_argument(
        "--prompt", "-p",
        type=str,
        required=True,
        help="Text description of the video to generate"
    )
    parser.add_argument(
        "--height", "-H",
        type=int,
        default=512,
        help="Output video height (default: 512, must be divisible by 64)"
    )
    parser.add_argument(
        "--width", "-W",
        type=int,
        default=512,
        help="Output video width (default: 512, must be divisible by 64)"
    )
    parser.add_argument(
        "--num-frames", "-n",
        type=int,
        default=33,
        help="Number of frames (default: 33, must be 1 + 8*k)"
    )
    parser.add_argument(
        "--seed", "-s",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)"
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=24,
        help="Frames per second for output video (default: 24)"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="output.mp4",
        help="Output video path (default: output.mp4)"
    )
    parser.add_argument(
        "--save-frames",
        action="store_true",
        help="Save individual frames as images"
    )
    parser.add_argument(
        "--model-repo",
        type=str,
        default="Lightricks/LTX-2",
        help="Model repository to use (default: Lightricks/LTX-2)"
    )

    args = parser.parse_args()

    # Reject bad sizes as a usage error before any model is downloaded/loaded.
    try:
        _validate_dimensions(args.height, args.width, args.num_frames)
    except ValueError as err:
        parser.error(str(err))

    generate_video(
        model_repo=args.model_repo,
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        seed=args.seed,
        fps=args.fps,
        output_path=args.output,
        save_frames=args.save_frames,
    )


if __name__ == "__main__":
    main()