diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 55f0965..17f3770 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1,9 +1,10 @@ import argparse import time from pathlib import Path -from typing import Optional, List, Tuple +from typing import Optional import mlx.core as mx +import mlx.nn as nn import numpy as np from PIL import Image from tqdm import tqdm @@ -110,6 +111,7 @@ def create_position_grid( # Convert temporal to time in seconds by dividing by fps pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps + # Always return float32 for RoPE precision - bfloat16 causes quality degradation return mx.array(pixel_coords, dtype=mx.float32) @@ -137,6 +139,7 @@ def denoise( Denoised latent tensor """ # If state is provided, use its latent (which may have conditioning applied) + dtype = latents.dtype if state is not None: latents = state.latent @@ -154,11 +157,11 @@ def denoise( denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) - # Per-token timesteps: sigma * mask - timesteps = sigma * denoise_mask_flat + # Per-token timesteps: sigma * mask (preserve dtype) + timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat else: - # All tokens get the same timestep - timesteps = mx.full((b, num_tokens), sigma) + # All tokens get the same timestep (use latent dtype) + timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) video_modality = Modality( latent=latents_flat, @@ -181,8 +184,11 @@ def denoise( mx.eval(denoised) + # Euler step (preserve dtype by converting Python floats to arrays) if sigma_next > 0: - latents = denoised + sigma_next * (latents - denoised) / sigma + sigma_next_arr = mx.array(sigma_next, dtype=dtype) + sigma_arr = mx.array(sigma, dtype=dtype) + latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr else: latents = denoised mx.eval(latents) @@ -283,6 +289,7 @@ def generate_video( print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) + model_dtype = text_embeddings.dtype # bfloat16 from text encoder mx.eval(text_embeddings) del text_encoder @@ -292,6 +299,8 @@ def generate_video( print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}") raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) sanitized = sanitize_transformer_weights(raw_weights) + # Convert transformer weights to bfloat16 for memory efficiency + sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} config = LTXModelConfig( model_type=LTXModelType.VideoOnly, @@ -310,7 +319,7 @@ def generate_video( timestep_scale_multiplier=1000, ) - transformer = LTXModel(config) + transformer = LTXModel(config) transformer.load_weights(list(sanitized.items()), strict=False) mx.eval(transformer.parameters()) @@ -323,15 +332,15 @@ def generate_video( mx.eval(vae_encoder.parameters()) # Load and prepare image for stage 1 (half resolution) - input_image = load_image(image, height=height // 2, width=width // 2) - stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) stage1_image_latent = vae_encoder(stage1_image_tensor) mx.eval(stage1_image_latent) print(f" Stage 1 image latent: {stage1_image_latent.shape}") # Load and prepare image for stage 2 (full resolution) - input_image = load_image(image, height=height, width=width) - stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) stage2_image_latent = vae_encoder(stage2_image_tensor) mx.eval(stage2_image_latent) print(f" Stage 2 image latent: {stage2_image_latent.shape}") @@ -343,6 +352,7 @@ def generate_video( print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) + # Position grids stay float32 for RoPE precision positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) @@ -353,24 +363,26 @@ def generate_video( # Create initial state with zeros latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) state1 = LatentState( - latent=mx.zeros(latent_shape), - clean_latent=mx.zeros(latent_shape), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + latent=mx.zeros(latent_shape, dtype=model_dtype), + clean_latent=mx.zeros(latent_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength, ) + state1 = apply_conditioning(state1, [conditioning]) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # For Stage 1, noise_scale = 1.0 (first sigma) - noise = mx.random.normal(latent_shape) - noise_scale = STAGE_1_SIGMAS[0] # 1.0 + noise = mx.random.normal(latent_shape, dtype=model_dtype) + noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0 scaled_mask = state1.denoise_mask * noise_scale + state1 = LatentState( - latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask), + latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state1.clean_latent, denoise_mask=state1.denoise_mask, ) @@ -378,7 +390,7 @@ def generate_video( mx.eval(latents) else: # T2V: just use random noise - latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) + latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype) mx.eval(latents) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1) @@ -401,6 +413,7 @@ def generate_video( # Stage 2: Refine at full resolution print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") + # Position grids stay float32 for RoPE precision positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) @@ -411,7 +424,7 @@ def generate_video( state2 = LatentState( latent=latents, # Start with upscaled latent clean_latent=mx.zeros_like(latents), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( latent=stage2_image_latent, @@ -423,11 +436,11 @@ def generate_video( # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # For Stage 2, noise_scale = stage_2_sigmas[0] # Conditioned frames (mask=0) keep image latent, unconditioned get partial noise - noise = mx.random.normal(latents.shape) - noise_scale = STAGE_2_SIGMAS[0] + noise = mx.random.normal(latents.shape).astype(model_dtype) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) scaled_mask = state2.denoise_mask * noise_scale state2 = LatentState( - latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask), + latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state2.clean_latent, denoise_mask=state2.denoise_mask, ) @@ -435,9 +448,10 @@ def generate_video( mx.eval(latents) else: # T2V: add noise to all frames for refinement - noise_scale = STAGE_2_SIGMAS[0] - noise = mx.random.normal(latents.shape) - latents = noise * noise_scale + latents * (1 - noise_scale) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) + noise = mx.random.normal(latents.shape).astype(model_dtype) + latents = noise * noise_scale + latents * one_minus_scale mx.eval(latents) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2) diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py index b7cba4a..e0fb22b 100644 --- a/mlx_video/generate_av.py +++ b/mlx_video/generate_av.py @@ -3,7 +3,7 @@ import argparse import time from pathlib import Path -from typing import Optional, List +from typing import Optional import mlx.core as mx import numpy as np @@ -164,6 +164,7 @@ def denoise_av( Returns: Tuple of (video_latents, audio_latents) """ + dtype = video_latents.dtype # If video state is provided, use its latent if video_state is not None: video_latents = video_state.latent @@ -189,10 +190,10 @@ def denoise_av( denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) # Per-token timesteps: sigma * mask - video_timesteps = sigma * denoise_mask_flat + video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat else: # All tokens get the same timestep - video_timesteps = mx.full((b, num_video_tokens), sigma) + video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) video_modality = Modality( latent=video_flat, @@ -205,7 +206,7 @@ def denoise_av( audio_modality = Modality( latent=audio_flat, - timesteps=mx.full((ab, at), sigma), + timesteps=mx.full((ab, at), sigma, dtype=dtype), positions=audio_positions, context=audio_embeddings, context_mask=None, @@ -230,10 +231,12 @@ def denoise_av( mx.eval(video_denoised, audio_denoised) - # Euler step + # Euler step - use dtype-preserving arrays to avoid float32 promotion if sigma_next > 0: - video_latents = video_denoised + sigma_next * (video_latents - video_denoised) / sigma - audio_latents = audio_denoised + sigma_next * (audio_latents - audio_denoised) / sigma + sigma_next_arr = mx.array(sigma_next, dtype=dtype) + sigma_arr = mx.array(sigma, dtype=dtype) + video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr + audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr else: video_latents = video_denoised audio_latents = audio_denoised @@ -435,6 +438,7 @@ def generate_video_with_audio( # Get both video and audio embeddings video_embeddings, audio_embeddings = text_encoder(prompt) + model_dtype = video_embeddings.dtype # bfloat16 from text encoder mx.eval(video_embeddings, audio_embeddings) del text_encoder @@ -445,6 +449,9 @@ def generate_video_with_audio( raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) sanitized = sanitize_transformer_weights(raw_weights) + # Convert transformer weights to bfloat16 for memory efficiency + sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} + config = LTXModelConfig( model_type=LTXModelType.AudioVideo, num_attention_heads=32, @@ -482,18 +489,16 @@ def generate_video_with_audio( mx.eval(vae_encoder.parameters()) # Load and prepare image for stage 1 (half resolution) - input_image = load_image(image, height=height // 2, width=width // 2) - stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) stage1_image_latent = vae_encoder(stage1_image_tensor) mx.eval(stage1_image_latent) - print(f" Stage 1 image latent: {stage1_image_latent.shape}") # Load and prepare image for stage 2 (full resolution) - input_image = load_image(image, height=height, width=width) - stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) stage2_image_latent = vae_encoder(stage2_image_tensor) mx.eval(stage2_image_latent) - print(f" Stage 2 image latent: {stage2_image_latent.shape}") del vae_encoder mx.clear_cache() @@ -502,9 +507,10 @@ def generate_video_with_audio( print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) - # Create position grids - video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) - audio_positions = create_audio_position_grid(1, audio_frames) + # Create position grids - MUST stay float32 for RoPE precision + # bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations + video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32 + audio_positions = create_audio_position_grid(1, audio_frames) # float32 mx.eval(video_positions, audio_positions) # Apply I2V conditioning for stage 1 if provided @@ -513,9 +519,9 @@ def generate_video_with_audio( if is_i2v and stage1_image_latent is not None: # PyTorch flow: create zeros -> apply conditioning -> apply noiser video_state1 = LatentState( - latent=mx.zeros(video_latent_shape), - clean_latent=mx.zeros(video_latent_shape), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + latent=mx.zeros(video_latent_shape, dtype=model_dtype), + clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( latent=stage1_image_latent, @@ -525,11 +531,11 @@ def generate_video_with_audio( video_state1 = apply_conditioning(video_state1, [conditioning]) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) - noise = mx.random.normal(video_latent_shape) - noise_scale = STAGE_1_SIGMAS[0] # 1.0 + noise = mx.random.normal(video_latent_shape).astype(model_dtype) + noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0 scaled_mask = video_state1.denoise_mask * noise_scale video_state1 = LatentState( - latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask), + latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=video_state1.clean_latent, denoise_mask=video_state1.denoise_mask, ) @@ -537,11 +543,11 @@ def generate_video_with_audio( mx.eval(video_latents) else: # T2V: just use random noise - video_latents = mx.random.normal(video_latent_shape) + video_latents = mx.random.normal(video_latent_shape).astype(model_dtype) mx.eval(video_latents) # Audio always uses pure noise (no I2V for audio) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) mx.eval(audio_latents) # Stage 1 denoising @@ -571,7 +577,8 @@ def generate_video_with_audio( # Stage 2: Refine at full resolution print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") - video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) + # Position grids stay float32 for RoPE precision + video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32 mx.eval(video_positions) # Apply I2V conditioning for stage 2 if provided @@ -581,7 +588,7 @@ def generate_video_with_audio( video_state2 = LatentState( latent=video_latents, # Start with upscaled latent clean_latent=mx.zeros_like(video_latents), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( latent=stage2_image_latent, @@ -591,11 +598,11 @@ def generate_video_with_audio( video_state2 = apply_conditioning(video_state2, [conditioning]) # Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise - video_noise = mx.random.normal(video_latents.shape) - noise_scale = STAGE_2_SIGMAS[0] + video_noise = mx.random.normal(video_latents.shape).astype(model_dtype) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) scaled_mask = video_state2.denoise_mask * noise_scale video_state2 = LatentState( - latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask), + latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=video_state2.clean_latent, denoise_mask=video_state2.denoise_mask, ) @@ -603,16 +610,18 @@ def generate_video_with_audio( mx.eval(video_latents) # Audio still gets noise (no I2V for audio) - audio_noise = mx.random.normal(audio_latents.shape) - audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale mx.eval(audio_latents) else: # T2V: add noise to all frames for refinement - noise_scale = STAGE_2_SIGMAS[0] - video_noise = mx.random.normal(video_latents.shape) - audio_noise = mx.random.normal(audio_latents.shape) - video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale) - audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale + video_noise = mx.random.normal(video_latents.shape).astype(model_dtype) + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + video_latents = video_noise * noise_scale + video_latents * one_minus_scale + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale mx.eval(video_latents, audio_latents) video_latents, audio_latents = denoise_av( @@ -671,27 +680,13 @@ def generate_video_with_audio( vocoder = load_vocoder(model_path) mx.eval(audio_decoder.parameters(), vocoder.parameters()) - # Debug: check per-channel statistics are loaded - pcs = audio_decoder.per_channel_statistics - print(f"Per-channel stats: mean_of_means range=[{pcs._mean_of_means.min():.4f}, {pcs._mean_of_means.max():.4f}], std_of_means range=[{pcs._std_of_means.min():.4f}, {pcs._std_of_means.max():.4f}]") - - # Debug: check audio latent statistics - print(f"Audio latents shape: {audio_latents.shape}") - print(f"Audio latents stats: min={audio_latents.min():.4f}, max={audio_latents.max():.4f}, mean={audio_latents.mean():.4f}, std={mx.std(audio_latents):.4f}") - mel_spectrogram = audio_decoder(audio_latents) mx.eval(mel_spectrogram) - print(f"Mel spectrogram shape: {mel_spectrogram.shape}") - print(f"Mel spectrogram stats: min={mel_spectrogram.min():.4f}, max={mel_spectrogram.max():.4f}, mean={mel_spectrogram.mean():.4f}") - # Audio decoder output is already in vocoder format (B, C, T, F) audio_waveform = vocoder(mel_spectrogram) mx.eval(audio_waveform) - print(f"Audio waveform shape: {audio_waveform.shape}") - print(f"Audio waveform stats: min={audio_waveform.min():.4f}, max={audio_waveform.max():.4f}, mean={audio_waveform.mean():.4f}") - audio_np = np.array(audio_waveform) if audio_np.ndim == 3: audio_np = audio_np[0] # Remove batch dim diff --git a/mlx_video/utils.py b/mlx_video/utils.py index aff48ed..cebbed7 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -171,6 +171,7 @@ def load_image( image_path: Union[str, Path], height: Optional[int] = None, width: Optional[int] = None, + dtype: mx.Dtype = mx.float32, ) -> mx.array: """Load and preprocess an image for I2V conditioning. @@ -210,7 +211,7 @@ def load_image( # Convert to numpy then MLX image_np = np.array(image).astype(np.float32) / 255.0 - return mx.array(image_np) + return mx.array(image_np, dtype=dtype) def resize_image_aspect_ratio(