"""Two-stage LTX-2 video generation with MLX.

Pipeline, as implemented by ``main``:

1. Encode the prompt with the LTX-2 text encoder.
2. Stage 1: denoise random latents at half resolution.
3. Upsample the latents 2x with the spatial upsampler.
4. Stage 2: re-noise and refine the latents at full resolution.
5. Decode with the VAE and write the first frame (PNG) and the video (MP4).

NOTE: importing this module resolves the LTX-2 snapshot directory, which may
download the model files from the Hugging Face Hub.
"""

import os
from pathlib import Path
from typing import Optional

import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download
from PIL import Image

from mlx_video.convert import sanitize_transformer_weights
from mlx_video.generate import create_position_grid
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.utils import to_denoised

# Paths
LTX2_REPO = "Lightricks/LTX-2"


def get_ltx2_cache_dir() -> str:
    """Return the local snapshot directory of the LTX-2 repository.

    The local Hugging Face cache is queried first (``local_files_only=True``,
    so nothing is downloaded). If that lookup raises, the ``*.safetensors``
    and ``*.json`` files are downloaded from the hub instead.

    Returns:
        The path returned by ``snapshot_download``.
    """
    try:
        # Offline lookup only; default revision and cache_dir are used.
        return snapshot_download(
            repo_id=LTX2_REPO,
            local_files_only=True,
            allow_patterns=["*"],
            ignore_patterns=[],
        )
    except Exception:
        # Deliberately broad: any failure of the offline lookup falls back to
        # a download. `resume_download` is intentionally not passed any more:
        # huggingface_hub deprecated it because interrupted downloads are
        # resumed automatically.
        return snapshot_download(
            repo_id=LTX2_REPO,
            local_files_only=False,
            allow_patterns=["*.safetensors", "*.json"],
            ignore_patterns=[],
        )


LTX2_PATH = Path(get_ltx2_cache_dir())
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')

# Distilled sigma schedules (from PyTorch)
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0]  # Refinement steps


def _print_stats(label: str, array) -> np.ndarray:
    """Print min/max/mean of ``array`` after ``label`` and return it as NumPy."""
    values = np.array(array)
    print(f"{label}: min={values.min():.4f}, max={values.max():.4f}, mean={values.mean():.4f}")
    return values


def _predict_velocity(transformer, latents_flat, timesteps, positions, context):
    """Run one video-only transformer pass and return the evaluated video output."""
    modality = Modality(
        latent=latents_flat,
        timesteps=timesteps,
        positions=positions,
        context=context,
        context_mask=None,
        enabled=True,
    )
    prediction, _ = transformer(video=modality, audio=None)
    mx.eval(prediction)
    return prediction


def denoise_loop(
    latents: mx.array,
    positions: mx.array,
    text_embeddings: mx.array,
    transformer: LTXModel,
    sigma_schedule: list,
    stage_name: str = "Stage",
    negative_embeddings: Optional[mx.array] = None,
    cfg_scale: float = 1.0,
) -> mx.array:
    """Run denoising loop for given sigma schedule.

    Each consecutive pair ``(sigma, sigma_next)`` of the schedule is one Euler
    step; a schedule with fewer than two entries performs no steps and returns
    ``latents`` unchanged.

    Args:
        latents: Noisy latent tensor; unpacked as a 5-D ``(b, c, f, h, w)`` array
        positions: Position embeddings
        text_embeddings: Positive prompt embeddings
        transformer: The transformer model
        sigma_schedule: List of sigma values for each step
        stage_name: Name for logging
        negative_embeddings: Negative prompt embeddings for CFG (optional)
        cfg_scale: Classifier-free guidance scale (1.0 = no guidance)

    Returns:
        The latents after the last step, with the same shape as the input.
    """
    # Guidance needs both a negative prompt and a scale above 1.
    use_cfg = negative_embeddings is not None and cfg_scale > 1.0
    num_steps = len(sigma_schedule) - 1

    for step, (sigma, sigma_next) in enumerate(zip(sigma_schedule, sigma_schedule[1:]), start=1):
        print(f" {stage_name} step {step}/{num_steps}: sigma={sigma:.4f} -> {sigma_next:.4f}")

        # Flatten (b, c, f, h, w) -> (b, f*h*w, c) for the transformer.
        b, c, f, h, w = latents.shape
        latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
        # NOTE(review): the timestep has shape (1,) regardless of b — confirm
        # the transformer broadcasts it if batch sizes above 1 are ever used.
        timesteps = mx.full((1,), sigma)

        # Positive (conditioned) prediction
        vx_cond = _predict_velocity(transformer, latents_flat, timesteps, positions, text_embeddings)

        if use_cfg:
            # Negative (unconditioned) prediction
            vx_uncond = _predict_velocity(
                transformer, latents_flat, timesteps, positions, negative_embeddings
            )
            # CFG: output = uncond + cfg_scale * (cond - uncond)
            vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
        else:
            vx = vx_cond

        # Back to (b, c, f, h, w).
        vx_reshaped = mx.reshape(mx.transpose(vx, (0, 2, 1)), (b, c, f, h, w))

        # Debug: Print velocity stats
        _print_stats(" Velocity", vx_reshaped)

        # Get denoised prediction: x_0 = x_t - sigma * velocity
        denoised = to_denoised(latents, vx_reshaped, sigma)
        mx.eval(denoised)

        # Debug: Print denoised stats
        _print_stats(" Denoised", denoised)

        # Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
        if sigma_next > 0:
            velocity = (latents - denoised) / sigma
            latents = denoised + sigma_next * velocity
        else:
            # Final step: the denoised estimate is the result.
            latents = denoised
        mx.eval(latents)

        # Debug: Print latents after step
        _print_stats(" Latents after step", latents)

    return latents


def main():
    """Generate a video with the two-stage LTX-2 pipeline and save the results.

    Writes ``mlx_final_latents.npy``, ``mlx_output_frame0_2.png`` and, when
    imageio can encode it, ``mlx_output_video_2.mp4`` to the working directory.

    Raises:
        ValueError: If the configured height or width is not divisible by 64.
    """
    print("="*60)
    print("MLX LTX-2 Video Generation (Two-Stage)")
    print("="*60)

    # Config - same as PyTorch reference
    prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic"
    negative_prompt = ""  # PyTorch script doesn't use negative prompt
    cfg_scale = 1.0  # No CFG in the distilled pipeline
    # Height and width must be divisible by 64 for two-stage.
    # NOTE(review): 500 frames is not of the form 8*k + 1; the floor division
    # below yields 63 latent frames — confirm this is intended.
    height, width, num_frames = 512, 512, 500
    seed = 123

    # Half resolution followed by the 32x latent reduction needs a factor of 64.
    if height % 64 or width % 64:
        raise ValueError(
            f"height and width must be divisible by 64 for two-stage generation, got {width}x{height}"
        )

    # Stage 1: Half resolution
    stage1_height = height // 2
    stage1_width = width // 2
    stage1_latent_height = stage1_height // 32
    stage1_latent_width = stage1_width // 32
    latent_frames = 1 + (num_frames - 1) // 8

    # Stage 2: Full resolution
    latent_height = height // 32
    latent_width = width // 32

    print("\nConfig:")
    print(f" Prompt: {prompt}")
    print(f" Negative prompt: '{negative_prompt}'")
    print(f" CFG scale: {cfg_scale}")
    print(f" Final resolution: {width}x{height}, {num_frames} frames")
    print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
    print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
    print(f" Seed: {seed}")

    mx.random.seed(seed)

    # Load text encoder
    print("\nLoading text encoder...")
    from mlx_video.models.ltx.text_encoder import LTX2TextEncoder

    text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH))
    text_encoder.load(str(LTX2_PATH))
    mx.eval(text_encoder.parameters())

    # Encode positive prompt
    print("Encoding text...")
    text_embeddings, _attention_mask = text_encoder(prompt)
    mx.eval(text_embeddings)
    print(f" Positive embeddings: {text_embeddings.shape}")

    # Encode negative prompt for CFG
    negative_embeddings, _ = text_encoder(negative_prompt)
    mx.eval(negative_embeddings)
    print(f" Negative embeddings: {negative_embeddings.shape}")

    # Free text encoder memory
    del text_encoder
    mx.clear_cache()

    # Load transformer
    print("\nLoading transformer...")
    raw_weights = mx.load(MODEL_PATH)
    sanitized = sanitize_transformer_weights(raw_weights)

    config = LTXModelConfig(
        model_type=LTXModelType.VideoOnly,
        num_attention_heads=32,
        attention_head_dim=128,
        in_channels=128,
        out_channels=128,
        num_layers=48,
        cross_attention_dim=4096,
        caption_channels=3840,
        rope_type=LTXRopeType.SPLIT,
        double_precision_rope=True,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        use_middle_indices_grid=True,
        timestep_scale_multiplier=1000,
    )

    transformer = LTXModel(config)
    transformer.load_weights(list(sanitized.items()), strict=False)
    mx.eval(transformer.parameters())
    print(" Transformer loaded!")

    # ========================================
    # Stage 1: Generate at half resolution
    # ========================================
    print("\n" + "="*60)
    print("Stage 1: Generating at half resolution")
    print("="*60)

    # Re-seed so the initial noise does not depend on anything drawn above.
    mx.random.seed(seed)
    latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width))
    mx.eval(latents)
    print(f" Initial latents: {latents.shape}")

    positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width)
    mx.eval(positions)

    latents = denoise_loop(
        latents=latents,
        positions=positions,
        text_embeddings=text_embeddings,
        transformer=transformer,
        sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
        stage_name="Stage 1",
        negative_embeddings=negative_embeddings,
        cfg_scale=cfg_scale,
    )

    print(f"\nStage 1 latents: {latents.shape}")
    _print_stats(" Stats", latents)

    # ========================================
    # Upsample latents 2x
    # ========================================
    print("\n" + "="*60)
    print("Upsampling latents 2x")
    print("="*60)

    # Load upsampler
    print(" Loading spatial upsampler...")
    upsampler = load_upsampler(UPSAMPLER_PATH)
    mx.eval(upsampler.parameters())

    # Load latent statistics for normalization (the decoder is reused for the
    # final decode below).
    vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True)
    # EXPERIMENT: setting vae_decoder.decode_noise_scale = 0.0 here disables
    # VAE decode noise for sharper output.
    latent_mean = vae_decoder.latents_mean
    latent_std = vae_decoder.latents_std

    # Upsample
    print(" Upsampling...")
    latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
    mx.eval(latents)
    print(f" Upsampled latents: {latents.shape}")

    # Free upsampler memory
    del upsampler
    mx.clear_cache()

    # ========================================
    # Stage 2: Refine at full resolution
    # ========================================
    print("\n" + "="*60)
    print("Stage 2: Refining at full resolution")
    print("="*60)

    # Debug: Print upsampled latent stats before adding noise
    _print_stats(" Upsampled latents (before noise)", latents)

    # Create new position grid for full resolution
    positions = create_position_grid(1, latent_frames, latent_height, latent_width)
    mx.eval(positions)

    # Add noise at initial sigma for stage 2
    # PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale)
    # NOT addition: noisy = clean + scale * noise
    noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
    noise = mx.random.normal(latents.shape)
    latents = noise * noise_scale + latents * (1 - noise_scale)
    mx.eval(latents)

    # Debug: Print latents after adding noise
    _print_stats(f" After adding noise (sigma={noise_scale})", latents)

    # NOTE(review): stage 2 runs without negative_embeddings/cfg_scale, so it
    # never applies CFG even if cfg_scale is raised above — confirm intended.
    latents = denoise_loop(
        latents=latents,
        positions=positions,
        text_embeddings=text_embeddings,
        transformer=transformer,
        sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
        stage_name="Stage 2",
    )

    print(f"\nFinal latents: {latents.shape}")
    latents_np = _print_stats(" Stats", latents)

    # Save latents for PyTorch comparison
    np.save("mlx_final_latents.npy", latents_np)
    print(" Saved latents to mlx_final_latents.npy")

    # Free transformer memory
    del transformer
    mx.clear_cache()

    # ========================================
    # Decode to video
    # ========================================
    print("\n" + "="*60)
    print("Decoding with VAE")
    print("="*60)

    # Decode latents to video
    video = vae_decoder(latents, debug=True)
    mx.eval(video)
    print(f" Video shape: {video.shape}")

    # Convert to frames
    video = mx.squeeze(video, axis=0)  # (C, F, H, W)

    # Debug: check raw RGB values before conversion
    video_raw = np.array(video)
    print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
    print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")

    video = mx.transpose(video, (1, 2, 3, 0))  # (F, H, W, C)
    video = (video + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    video = mx.clip(video, 0.0, 1.0)
    video = (video * 255).astype(mx.uint8)
    video_np = np.array(video)

    print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}")

    # Save first frame
    output_path = Path("mlx_output_frame0_2.png")
    Image.fromarray(video_np[0]).save(output_path)
    print(f"\nSaved first frame to {output_path}")

    # Save video (best effort: a missing imageio or encoder must not abort the
    # run after the frame and latents have already been written).
    try:
        import imageio

        video_path = "mlx_output_video_2.mp4"
        imageio.mimwrite(video_path, video_np, fps=24, codec='libx264')
        print(f"Saved video to {video_path}")
    except Exception as e:
        print(f"Could not save video: {e}")

    print("\nDone!")


if __name__ == "__main__":
    import time

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments during the (long) run.
    start_time = time.perf_counter()
    main()
    end_time = time.perf_counter()
    print(f"Time taken: {end_time - start_time} seconds")