diff --git a/main.py b/main.py deleted file mode 100644 index 2c4cabd..0000000 --- a/main.py +++ /dev/null @@ -1,321 +0,0 @@ -import argparse -import time -from pathlib import Path - -import mlx.core as mx -import numpy as np -from PIL import Image - -from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType -from mlx_video.models.ltx.ltx import LTXModel -from mlx_video.models.ltx.transformer import Modality -from mlx_video.convert import sanitize_transformer_weights -from mlx_video.generate import create_position_grid -from mlx_video.utils import to_denoised -from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder -from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents - -from huggingface_hub import snapshot_download - - -# Distilled sigma schedules -STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] -STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0] - - -def get_model_path(model_repo: str): - """Get or download LTX-2 model path.""" - try: - return Path(snapshot_download(repo_id=model_repo, local_files_only=True)) - except Exception: - print("Downloading LTX-2 model weights...") - return Path(snapshot_download( - repo_id=model_repo, - local_files_only=False, - resume_download=True, - allow_patterns=["*.safetensors", "*.json"], - )) - - -def denoise( - latents: mx.array, - positions: mx.array, - text_embeddings: mx.array, - transformer: LTXModel, - sigmas: list, -) -> mx.array: - """Run denoising loop.""" - for i in range(len(sigmas) - 1): - sigma, sigma_next = sigmas[i], sigmas[i + 1] - - b, c, f, h, w = latents.shape - latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) - - video_modality = Modality( - latent=latents_flat, - timesteps=mx.full((1,), sigma), - positions=positions, - context=text_embeddings, - context_mask=None, - enabled=True, - ) - - velocity, _ = transformer(video=video_modality, audio=None) - mx.eval(velocity) - - velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) - denoised = to_denoised(latents, velocity, sigma) - mx.eval(denoised) - - if sigma_next > 0: - latents = denoised + sigma_next * (latents - denoised) / sigma - else: - latents = denoised - mx.eval(latents) - - return latents - - -def generate_video( - model_repo: str, - prompt: str, - height: int = 512, - width: int = 512, - num_frames: int = 33, - seed: int = 42, - fps: int = 24, - output_path: str = "output.mp4", - save_frames: bool = False, -): - """Generate video from text prompt. - - Args: - prompt: Text description of the video to generate - height: Output video height (must be divisible by 64) - width: Output video width (must be divisible by 64) - num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97) - seed: Random seed for reproducibility - fps: Frames per second for output video - output_path: Path to save the output video - save_frames: Whether to save individual frames as images - """ - start_time = time.time() - - # Validate dimensions - assert height % 64 == 0, f"Height must be divisible by 64, got {height}" - assert width % 64 == 0, f"Width must be divisible by 64, got {width}" - - print(f"Generating {width}x{height} video with {num_frames} frames") - print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}") - - # Get model path - model_path = get_model_path(model_repo) - - # Calculate latent dimensions - stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 - stage2_h, stage2_w = height // 32, width // 32 - latent_frames = 1 + (num_frames - 1) // 8 - - mx.random.seed(seed) - - # Load text encoder - print("Loading text encoder...") - from mlx_video.models.ltx.text_encoder import LTX2TextEncoder - text_encoder = LTX2TextEncoder(model_path=str(model_path)) - text_encoder.load(str(model_path)) - mx.eval(text_encoder.parameters()) - - text_embeddings, _ = text_encoder(prompt) - mx.eval(text_embeddings) - - del text_encoder - mx.clear_cache() - - # Load transformer - print("Loading transformer...") - raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) - sanitized = sanitize_transformer_weights(raw_weights) - - config = LTXModelConfig( - model_type=LTXModelType.VideoOnly, - num_attention_heads=32, - attention_head_dim=128, - in_channels=128, - out_channels=128, - num_layers=48, - cross_attention_dim=4096, - caption_channels=3840, - rope_type=LTXRopeType.SPLIT, - double_precision_rope=True, - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - use_middle_indices_grid=True, - timestep_scale_multiplier=1000, - ) - - transformer = LTXModel(config) - transformer.load_weights(list(sanitized.items()), strict=False) - mx.eval(transformer.parameters()) - - # Stage 1: Generate at half resolution - print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...") - mx.random.seed(seed) - latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) - mx.eval(latents) - - positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) - mx.eval(positions) - - latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS) - - # Upsample latents - print("Upsampling latents 2x...") - upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) - mx.eval(upsampler.parameters()) - - vae_decoder = load_vae_decoder( - str(model_path / 'ltx-2-19b-distilled.safetensors'), - timestep_conditioning=True - ) - - latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) - mx.eval(latents) - - del upsampler - mx.clear_cache() - - # Stage 2: Refine at full resolution - print(f"Stage 2: Refining at {width}x{height} (3 steps)...") - positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) - mx.eval(positions) - - # Add noise for refinement - noise_scale = STAGE_2_SIGMAS[0] - noise = mx.random.normal(latents.shape) - latents = noise * noise_scale + latents * (1 - noise_scale) - mx.eval(latents) - - latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS) - - del transformer - mx.clear_cache() - - # Decode to video - print("Decoding video...") - video = vae_decoder(latents) - mx.eval(video) - mx.clear_cache() - - # Convert to uint8 frames - video = mx.squeeze(video, axis=0) # (C, F, H, W) - video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) - video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) - video = (video * 255).astype(mx.uint8) - video_np = np.array(video) - - # Save outputs - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - try: - import imageio - imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264') - print(f"Saved video to {output_path}") - except Exception as e: - print(f"Could not save video: {e}") - - if save_frames: - frames_dir = output_path.parent / f"{output_path.stem}_frames" - frames_dir.mkdir(exist_ok=True) - for i, frame in enumerate(video_np): - Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png") - print(f"Saved {len(video_np)} frames to {frames_dir}") - - elapsed = time.time() - start_time - print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)") - - return video_np - - -def main(): - parser = argparse.ArgumentParser( - description="Generate videos with MLX LTX-2", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - python main.py --prompt "A cat walking on grass" - python main.py --prompt "Ocean waves at sunset" --height 768 --width 768 - python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4 - """ - ) - - parser.add_argument( - "--prompt", "-p", - type=str, - required=True, - help="Text description of the video to generate" - ) - parser.add_argument( - "--height", "-H", - type=int, - default=512, - help="Output video height (default: 512, must be divisible by 32)" - ) - parser.add_argument( - "--width", "-W", - type=int, - default=512, - help="Output video width (default: 512, must be divisible by 32)" - ) - parser.add_argument( - "--num-frames", "-n", - type=int, - default=33, - help="Number of frames (default: 33, must be 1 + 8*k)" - ) - parser.add_argument( - "--seed", "-s", - type=int, - default=42, - help="Random seed for reproducibility (default: 42)" - ) - parser.add_argument( - "--fps", - type=int, - default=24, - help="Frames per second for output video (default: 24)" - ) - parser.add_argument( - "--output", "-o", - type=str, - default="output.mp4", - help="Output video path (default: output.mp4)" - ) - parser.add_argument( - "--save-frames", - action="store_true", - help="Save individual frames as images" - ) - parser.add_argument( - "--model-repo", - type=str, - default="Lightricks/LTX-2", - help="Model repository to use (default: Lightricks/LTX-2)" - ) - args = parser.parse_args() - - generate_video( - model_repo=args.model_repo, - prompt=args.prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - seed=args.seed, - fps=args.fps, - output_path=args.output, - save_frames=args.save_frames, - ) - - -if __name__ == "__main__": - main() diff --git a/mlx_video/__init__.py b/mlx_video/__init__.py index b6e61f1..3123d7f 100644 --- a/mlx_video/__init__.py +++ b/mlx_video/__init__.py @@ -1,13 +1,9 @@ - from mlx_video.models.ltx import LTXModel, LTXModelConfig -from mlx_video.generate import LTXVideoPipeline, GenerationConfig from mlx_video.convert import load_transformer_weights, load_vae_weights __all__ = [ "LTXModel", "LTXModelConfig", - "LTXVideoPipeline", - "GenerationConfig", "load_transformer_weights", "load_vae_weights", ] \ No newline at end of file diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 530582c..b6e4daa 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1,114 +1,25 @@ -from dataclasses import dataclass +import argparse +import time from pathlib import Path -from typing import List, Optional, Tuple, Iterator, Union import mlx.core as mx import numpy as np +from PIL import Image -from mlx_video.models.ltx.ltx import LTXModel, X0Model +from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType +from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.transformer import Modality -from mlx_video.models.ltx.video_vae import VideoEncoder, VideoDecoder -from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder +from mlx_video.convert import sanitize_transformer_weights +from mlx_video.utils import to_denoised +from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder +from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents + +from huggingface_hub import snapshot_download -@dataclass -class GenerationConfig: - """Configuration for video generation.""" - # Video dimensions - height: int = 512 - width: int = 512 - num_frames: int = 33 # Must be 1 + 8*k - - # Diffusion parameters - num_inference_steps: int = 8 # For distilled model (ignored if use_distilled=True) - guidance_scale: float = 3.0 - use_distilled: bool = True # Use hardcoded sigma values for distilled model - - # Latent dimensions (computed from video dimensions) - @property - def latent_height(self) -> int: - return self.height // 32 - - @property - def latent_width(self) -> int: - return self.width // 32 - - @property - def latent_frames(self) -> int: - return 1 + (self.num_frames - 1) // 8 - - -# Hardcoded sigma values for distilled model (from LTX-2 pipeline) -# These were tuned to match the distillation process -DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] -STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0] - -# Scheduler constants for dynamic sigma computation (non-distilled models) -BASE_SHIFT_ANCHOR = 1024 -MAX_SHIFT_ANCHOR = 4096 - - -def get_sigmas( - num_steps: int, - num_tokens: int, - max_shift: float = 2.05, - base_shift: float = 0.95, - stretch: bool = True, - terminal: float = 0.1, - use_distilled: bool = True, -) -> mx.array: - """Get sigma schedule for diffusion. - - Args: - num_steps: Number of diffusion steps - num_tokens: Number of latent tokens (T * H * W) - max_shift: Maximum shift for sigma schedule - base_shift: Base shift for sigma schedule - stretch: Whether to stretch sigmas to terminal value - terminal: Terminal value for stretching - use_distilled: If True, use hardcoded distilled sigma values - - Returns: - Array of sigma values - """ - import math - - # For distilled model, use hardcoded sigma values - if use_distilled: - return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32) - - # For non-distilled models, compute dynamically using LTX2Scheduler logic - # Linear base schedule - sigmas = mx.linspace(1.0, 0.0, num_steps + 1) - - # Compute token-dependent sigma shift - x1 = BASE_SHIFT_ANCHOR - x2 = MAX_SHIFT_ANCHOR - mm = (max_shift - base_shift) / (x2 - x1) - b = base_shift - mm * x1 - sigma_shift = num_tokens * mm + b - - # Apply exponential transformation - # sigmas = exp(sigma_shift) / (exp(sigma_shift) + (1/sigmas - 1)^1) - power = 1 - exp_shift = math.exp(sigma_shift) - - # Convert to numpy for computation then back to mx - sigmas_np = np.array(sigmas) - result = np.zeros_like(sigmas_np) - non_zero = sigmas_np != 0 - result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0) ** power) - - # Stretch sigmas so final value matches terminal - if stretch: - non_zero_mask = result != 0 - non_zero_sigmas = result[non_zero_mask] - one_minus_z = 1.0 - non_zero_sigmas - scale_factor = one_minus_z[-1] / (1.0 - terminal) - stretched = 1.0 - (one_minus_z / scale_factor) - result[non_zero_mask] = stretched - - return mx.array(result, dtype=mx.float32) +# Distilled sigma schedules +STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0] +STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0] def create_position_grid( @@ -141,7 +52,6 @@ def create_position_grid( patch_size_t, patch_size_h, patch_size_w = 1, 1, 1 # Generate grid coordinates for each dimension (frame, height, width) - # These are the starting coordinates for each patch in latent space t_coords = np.arange(0, num_frames, patch_size_t) h_coords = np.arange(0, height, patch_size_h) w_coords = np.arange(0, width, patch_size_w) @@ -173,7 +83,6 @@ def create_position_grid( # Apply causal fix for first frame temporal axis if causal_fix: # VAE temporal stride for first frame is 1 instead of temporal_scale - # Shift and clamp to keep first-frame timestamps non-negative pixel_coords[:, 0, :, :] = np.clip( pixel_coords[:, 0, :, :] + 1 - temporal_scale, a_min=0, @@ -186,307 +95,117 @@ def create_position_grid( return mx.array(pixel_coords, dtype=mx.float32) -class LTXVideoPipeline: +def get_model_path(model_repo: str): + """Get or download LTX-2 model path.""" + try: + return Path(snapshot_download(repo_id=model_repo, local_files_only=True)) + except Exception: + print("Downloading LTX-2 model weights...") + return Path(snapshot_download( + repo_id=model_repo, + local_files_only=False, + resume_download=True, + allow_patterns=["*.safetensors", "*.json"], + )) - def __init__( - self, - transformer: LTXModel, - text_encoder: Optional[LTX2TextEncoder] = None, - tokenizer: Optional[any] = None, - vae_encoder: Optional[VideoEncoder] = None, - vae_decoder: Optional[VideoDecoder] = None, - ): - """Initialize pipeline. - Args: - transformer: LTX transformer model - text_encoder: Optional LTX text encoder - tokenizer: Optional tokenizer for text encoding - vae_encoder: Optional VAE encoder - vae_decoder: Optional VAE decoder - """ - self.transformer = transformer - self.text_encoder = text_encoder - self.tokenizer = tokenizer - self.vae_encoder = vae_encoder - self.vae_decoder = vae_decoder - self.x0_model = X0Model(transformer) +def denoise( + latents: mx.array, + positions: mx.array, + text_embeddings: mx.array, + transformer: LTXModel, + sigmas: list, +) -> mx.array: + """Run denoising loop.""" + for i in range(len(sigmas) - 1): + sigma, sigma_next = sigmas[i], sigmas[i + 1] - def prepare_latents( - self, - batch_size: int, - num_frames: int, - height: int, - width: int, - dtype: mx.Dtype = mx.float16, - ) -> mx.array: - """Prepare initial noise latents. - - Args: - batch_size: Batch size - num_frames: Number of latent frames - height: Latent height - width: Latent width - dtype: Data type - - Returns: - Random latent noise - """ - # Use in_channels from transformer config - in_channels = self.transformer.config.in_channels - shape = (batch_size, in_channels, num_frames, height, width) - latents = mx.random.normal(shape).astype(dtype) - return latents - - def prepare_text_embeddings( - self, - prompt: Union[str, List[str]], - batch_size: int, - max_length: int = 1024, - ) -> Tuple[mx.array, Optional[mx.array]]: - """Prepare text embeddings. - - Args: - prompt: Text prompt or list of prompts - batch_size: Batch size - max_length: Maximum sequence length for tokenization - - Returns: - Tuple of (text_embeddings, attention_mask) - """ - # If text encoder is available, use it - if self.text_encoder is not None and self.tokenizer is not None: - # Handle single or multiple prompts - if isinstance(prompt, str): - prompts = [prompt] * batch_size - else: - prompts = prompt - - # Tokenize - tokens = self.tokenizer( - prompts, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="np", - ) - - input_ids = mx.array(tokens["input_ids"]) - attention_mask = mx.array(tokens["attention_mask"]) - - # Encode - embeddings = self.text_encoder(input_ids, attention_mask) - mx.eval(embeddings) - - return embeddings, None # Connector handles masking internally - - # Fallback: random embeddings (for testing without text encoder) - print("Warning: No text encoder provided, using random embeddings") - seq_len = max_length + 128 # Account for learnable registers - embed_dim = self.transformer.config.caption_channels - - embeddings = mx.random.normal((batch_size, seq_len, embed_dim)) - mask = mx.ones((batch_size, seq_len)) - - return embeddings, mask - - def denoise_step( - self, - latents: mx.array, - sigma: float, - sigma_next: float, - text_embeddings: mx.array, - positions: mx.array, - text_mask: Optional[mx.array] = None, - ) -> mx.array: - """Perform one denoising step. - - Args: - latents: Current noisy latents - sigma: Current noise level - sigma_next: Next noise level - text_embeddings: Text conditioning - positions: Position grid for RoPE - text_mask: Optional attention mask for text - - Returns: - Denoised latents - """ - batch_size = latents.shape[0] - - # Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C) b, c, f, h, w = latents.shape - latents_flat = mx.reshape(latents, (b, c, -1)) - latents_flat = mx.transpose(latents_flat, (0, 2, 1)) + latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) - # Create timestep tensor - timesteps = mx.full((batch_size,), sigma) - - # Create video modality input video_modality = Modality( latent=latents_flat, - timesteps=timesteps, + timesteps=mx.full((1,), sigma), positions=positions, context=text_embeddings, - context_mask=text_mask, + context_mask=None, enabled=True, ) - # Run denoising - denoised_video, _ = self.x0_model(video=video_modality, audio=None) + velocity, _ = transformer(video=video_modality, audio=None) + mx.eval(velocity) - # Reshape back: (B, F*H*W, C) -> (B, C, F, H, W) - denoised_video = mx.transpose(denoised_video, (0, 2, 1)) - denoised_video = mx.reshape(denoised_video, (b, c, f, h, w)) + velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) + denoised = to_denoised(latents, velocity, sigma) + mx.eval(denoised) - # Euler step if sigma_next > 0: - # x_next = x0 + sigma_next * (x - x0) / sigma - noise = (latents - denoised_video) / sigma - latents = denoised_video + sigma_next * noise + latents = denoised + sigma_next * (latents - denoised) / sigma else: - latents = denoised_video + latents = denoised + mx.eval(latents) - return latents - - def __call__( - self, - prompt: str, - config: Optional[GenerationConfig] = None, - seed: Optional[int] = None, - ) -> mx.array: - """Generate video from text prompt. - - Args: - prompt: Text prompt - config: Generation configuration - seed: Random seed - - Returns: - Generated video tensor of shape (B, C, F, H, W) - """ - if config is None: - config = GenerationConfig() - - if seed is not None: - mx.random.seed(seed) - - batch_size = 1 - - # Prepare text embeddings - text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size) - - # Prepare initial latents - latents = self.prepare_latents( - batch_size=batch_size, - num_frames=config.latent_frames, - height=config.latent_height, - width=config.latent_width, - ) - - # Prepare position grid - positions = create_position_grid( - batch_size=batch_size, - num_frames=config.latent_frames, - height=config.latent_height, - width=config.latent_width, - ) - - # Get sigma schedule - num_tokens = config.latent_frames * config.latent_height * config.latent_width - sigmas = get_sigmas( - config.num_inference_steps, - num_tokens, - use_distilled=config.use_distilled, - ) - - # Denoising loop - for i in range(len(sigmas) - 1): - sigma = float(sigmas[i]) - sigma_next = float(sigmas[i + 1]) - - latents = self.denoise_step( - latents=latents, - sigma=sigma, - sigma_next=sigma_next, - text_embeddings=text_embeddings, - positions=positions, - text_mask=text_mask, - ) - - mx.eval(latents) - - # Decode latents to video - if self.vae_decoder is not None: - video = self.vae_decoder(latents) - else: - video = latents - - return video + return latents def generate_video( + model_repo: str, prompt: str, - transformer: LTXModel, - text_encoder: Optional[LTX2TextEncoder] = None, - tokenizer: Optional[any] = None, - vae_decoder: Optional[VideoDecoder] = None, - config: Optional[GenerationConfig] = None, - seed: Optional[int] = None, -) -> mx.array: + height: int = 512, + width: int = 512, + num_frames: int = 33, + seed: int = 42, + fps: int = 24, + output_path: str = "output.mp4", + save_frames: bool = False, +): """Generate video from text prompt. Args: - prompt: Text prompt - transformer: LTX transformer model - text_encoder: Optional text encoder - tokenizer: Optional tokenizer - vae_decoder: Optional VAE decoder - config: Generation configuration - seed: Random seed - - Returns: - Generated video tensor + prompt: Text description of the video to generate + height: Output video height (must be divisible by 64) + width: Output video width (must be divisible by 64) + num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97) + seed: Random seed for reproducibility + fps: Frames per second for output video + output_path: Path to save the output video + save_frames: Whether to save individual frames as images """ - pipeline = LTXVideoPipeline( - transformer=transformer, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae_decoder=vae_decoder, - ) + start_time = time.time() - return pipeline(prompt, config, seed) + # Validate dimensions + assert height % 64 == 0, f"Height must be divisible by 64, got {height}" + assert width % 64 == 0, f"Width must be divisible by 64, got {width}" + print(f"Generating {width}x{height} video with {num_frames} frames") + print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}") -def load_pipeline( - model_path: str, - text_encoder_path: Optional[str] = None, - tokenizer_path: Optional[str] = None, - load_text_encoder_weights: bool = True, -) -> LTXVideoPipeline: - """Load complete LTX-2 video generation pipeline. + # Get model path + model_path = get_model_path(model_repo) - Args: - model_path: Path to LTX-2 model weights (safetensors) - text_encoder_path: Path to text encoder weights directory - tokenizer_path: Path to tokenizer directory - load_text_encoder_weights: Whether to load text encoder weights + # Calculate latent dimensions + stage1_h, stage1_w = height // 2 // 32, width // 2 // 32 + stage2_h, stage2_w = height // 32, width // 32 + latent_frames = 1 + (num_frames - 1) // 8 - Returns: - Configured LTXVideoPipeline - """ - from transformers import AutoTokenizer + mx.random.seed(seed) - from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType - from mlx_video.models.ltx.ltx import LTXModel - from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder - from mlx_video.convert import sanitize_transformer_weights + # Load text encoder + print("Loading text encoder...") + from mlx_video.models.ltx.text_encoder import LTX2TextEncoder + text_encoder = LTX2TextEncoder(model_path=str(model_path)) + text_encoder.load(str(model_path)) + mx.eval(text_encoder.parameters()) - print("Loading LTX-2 pipeline...") + text_embeddings, _ = text_encoder(prompt) + mx.eval(text_embeddings) + + del text_encoder + mx.clear_cache() # Load transformer - print(" Loading transformer...") - raw_weights = mx.load(model_path) + print("Loading transformer...") + raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) sanitized = sanitize_transformer_weights(raw_weights) config = LTXModelConfig( @@ -498,89 +217,177 @@ def load_pipeline( num_layers=48, cross_attention_dim=4096, caption_channels=3840, + rope_type=LTXRopeType.SPLIT, + double_precision_rope=True, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + timestep_scale_multiplier=1000, ) + transformer = LTXModel(config) transformer.load_weights(list(sanitized.items()), strict=False) - print(" Transformer loaded") + mx.eval(transformer.parameters()) - # Load VAE decoder - print(" Loading VAE decoder...") - vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True) - print(" VAE decoder loaded") + # Stage 1: Generate at half resolution + print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...") + mx.random.seed(seed) + latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) + mx.eval(latents) - # Load text encoder if paths provided - text_encoder = None - tokenizer = None + positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) + mx.eval(positions) - if load_text_encoder_weights and text_encoder_path is not None: - print(" Loading text encoder...") - text_encoder = load_text_encoder(model_path, text_encoder_path) - print(" Text encoder loaded") + latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS) - if tokenizer_path is not None: - print(" Loading tokenizer...") - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - print(" Tokenizer loaded") + # Upsample latents + print("Upsampling latents 2x...") + upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')) + mx.eval(upsampler.parameters()) - print("Pipeline ready!") - - return LTXVideoPipeline( - transformer=transformer, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae_decoder=vae_decoder, + vae_decoder = load_vae_decoder( + str(model_path / 'ltx-2-19b-distilled.safetensors'), + timestep_conditioning=True ) + latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std) + mx.eval(latents) -def video_to_numpy(video: mx.array) -> np.ndarray: - """Convert video tensor to numpy array. + del upsampler + mx.clear_cache() - Args: - video: Video tensor of shape (B, C, F, H, W) in range [-1, 1] + # Stage 2: Refine at full resolution + print(f"Stage 2: Refining at {width}x{height} (3 steps)...") + positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) + mx.eval(positions) - Returns: - Numpy array of shape (B, F, H, W, C) in range [0, 255] - """ - # Clamp to [-1, 1] - video = mx.clip(video, -1.0, 1.0) + # Add noise for refinement + noise_scale = STAGE_2_SIGMAS[0] + noise = mx.random.normal(latents.shape) + latents = noise * noise_scale + latents * (1 - noise_scale) + mx.eval(latents) - # Scale to [0, 255] - video = ((video + 1.0) / 2.0 * 255.0).astype(mx.uint8) + latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS) - # Rearrange: (B, C, F, H, W) -> (B, F, H, W, C) - video = mx.transpose(video, (0, 2, 3, 4, 1)) + del transformer + mx.clear_cache() - return np.array(video) + # Decode to video + print("Decoding video...") + video = vae_decoder(latents) + mx.eval(video) + mx.clear_cache() + + # Convert to uint8 frames + video = mx.squeeze(video, axis=0) # (C, F, H, W) + video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C) + video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0) + video = (video * 255).astype(mx.uint8) + video_np = np.array(video) + + # Save outputs + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + try: + import imageio + imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264') + print(f"Saved video to {output_path}") + except Exception as e: + print(f"Could not save video: {e}") + + if save_frames: + frames_dir = output_path.parent / f"{output_path.stem}_frames" + frames_dir.mkdir(exist_ok=True) + for i, frame in enumerate(video_np): + Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png") + print(f"Saved {len(video_np)} frames to {frames_dir}") + + elapsed = time.time() - start_time + print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)") + + return video_np + + +def main(): + parser = argparse.ArgumentParser( + description="Generate videos with MLX LTX-2", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m mlx_video.generate --prompt "A cat walking on grass" + python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768 + python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4 + """ + ) + + parser.add_argument( + "--prompt", "-p", + type=str, + required=True, + help="Text description of the video to generate" + ) + parser.add_argument( + "--height", "-H", + type=int, + default=512, + help="Output video height (default: 512, must be divisible by 32)" + ) + parser.add_argument( + "--width", "-W", + type=int, + default=512, + help="Output video width (default: 512, must be divisible by 32)" + ) + parser.add_argument( + "--num-frames", "-n", + type=int, + default=33, + help="Number of frames (default: 33, must be 1 + 8*k)" + ) + parser.add_argument( + "--seed", "-s", + type=int, + default=42, + help="Random seed for reproducibility (default: 42)" + ) + parser.add_argument( + "--fps", + type=int, + default=24, + help="Frames per second for output video (default: 24)" + ) + parser.add_argument( + "--output", "-o", + type=str, + default="output.mp4", + help="Output video path (default: output.mp4)" + ) + parser.add_argument( + "--save-frames", + action="store_true", + help="Save individual frames as images" + ) + parser.add_argument( + "--model-repo", + type=str, + default="Lightricks/LTX-2", + help="Model repository to use (default: Lightricks/LTX-2)" + ) + args = parser.parse_args() + + generate_video( + model_repo=args.model_repo, + prompt=args.prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + seed=args.seed, + fps=args.fps, + output_path=args.output, + save_frames=args.save_frames, + ) if __name__ == "__main__": - # Example usage - from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType - - # Create a small test config - config = LTXModelConfig( - model_type=LTXModelType.VideoOnly, - num_layers=2, # Reduced for testing - num_attention_heads=4, - attention_head_dim=32, - ) - - # Create model - model = LTXModel(config) - - # Generate video - gen_config = GenerationConfig( - height=256, - width=256, - num_frames=9, - num_inference_steps=4, - ) - - print("Testing generation pipeline...") - pipeline = LTXVideoPipeline(transformer=model) - - # This would require proper text embeddings in practice - # video = pipeline("A cat walking", gen_config, seed=42) - # print(f"Generated video shape: {video.shape}") - - print("Pipeline initialized successfully!") + main()