diff --git a/mlx_video/conditioning/__init__.py b/mlx_video/conditioning/__init__.py new file mode 100644 index 0000000..f976035 --- /dev/null +++ b/mlx_video/conditioning/__init__.py @@ -0,0 +1,3 @@ +"""Conditioning modules for LTX-2 video generation.""" + +from mlx_video.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning diff --git a/mlx_video/conditioning/latent.py b/mlx_video/conditioning/latent.py new file mode 100644 index 0000000..1825e3d --- /dev/null +++ b/mlx_video/conditioning/latent.py @@ -0,0 +1,196 @@ +"""Latent-based conditioning for I2V (Image-to-Video) generation. + +This module provides conditioning that injects encoded image latents into +the video generation process at specific frame positions. +""" + +from dataclasses import dataclass +from typing import Optional, List, Tuple + +import mlx.core as mx + + +@dataclass +class VideoConditionByLatentIndex: + """Condition video generation by injecting latents at a specific frame index. + + This replaces the latent at the specified frame index with the conditioned + latent and controls how much denoising is applied via the strength parameter. + + Args: + latent: Encoded image latent of shape (B, C, 1, H, W) + frame_idx: Frame index to condition (0 = first frame) + strength: Denoising strength (1.0 = full denoise, 0.0 = keep original) + """ + latent: mx.array + frame_idx: int = 0 + strength: float = 1.0 + + def get_num_latent_frames(self) -> int: + """Get number of latent frames in the conditioning.""" + return self.latent.shape[2] + + +@dataclass +class LatentState: + """State for latent diffusion with conditioning support. + + Attributes: + latent: Current noisy latent (B, C, F, H, W) + clean_latent: Clean conditioning latent (B, C, F, H, W) + denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where + 1.0 = full denoise, 0.0 = keep clean + """ + latent: mx.array + clean_latent: mx.array + denoise_mask: mx.array + + def clone(self) -> "LatentState": + """Create a copy of the state.""" + return LatentState( + latent=self.latent, + clean_latent=self.clean_latent, + denoise_mask=self.denoise_mask, + ) + + +def create_initial_state( + shape: Tuple[int, ...], + seed: Optional[int] = None, + noise_scale: float = 1.0, +) -> LatentState: + """Create initial noisy latent state. + + Args: + shape: Shape of latent (B, C, F, H, W) + seed: Optional random seed + noise_scale: Scale for initial noise (sigma) + + Returns: + Initial LatentState with random noise + """ + if seed is not None: + mx.random.seed(seed) + + noise = mx.random.normal(shape) + + return LatentState( + latent=noise * noise_scale, + clean_latent=mx.zeros(shape), + denoise_mask=mx.ones((shape[0], 1, shape[2], 1, 1)), # Full denoise by default + ) + + +def apply_conditioning( + state: LatentState, + conditionings: List[VideoConditionByLatentIndex], +) -> LatentState: + """Apply conditioning items to a latent state. + + Args: + state: Current latent state + conditionings: List of conditioning items to apply + + Returns: + Updated LatentState with conditioning applied + """ + state = state.clone() + b, c, f, h, w = state.latent.shape + + for cond in conditionings: + cond_latent = cond.latent + frame_idx = cond.frame_idx + strength = cond.strength + + # Validate shapes + _, cond_c, cond_f, cond_h, cond_w = cond_latent.shape + if (cond_c, cond_h, cond_w) != (c, h, w): + raise ValueError( + f"Conditioning latent spatial shape ({cond_c}, {cond_h}, {cond_w}) " + f"does not match target shape ({c}, {h}, {w})" + ) + + if frame_idx >= f: + raise ValueError( + f"Frame index {frame_idx} is out of bounds for latent with {f} frames" + ) + + # Get the conditioning frames count + num_cond_frames = cond_f + end_idx = min(frame_idx + num_cond_frames, f) + + # Replace latent at conditioning position + # state.latent[:, :, frame_idx:end_idx] = cond_latent[:, :, :end_idx - frame_idx] + latent_list = [] + clean_list = [] + mask_list = [] + + for i in range(f): + if frame_idx <= i < end_idx: + # Use conditioning latent + cond_idx = i - frame_idx + latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) + clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) + # Set mask: 1.0 - strength means less denoising for conditioned frames + mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength)) + else: + # Keep original + latent_list.append(state.latent[:, :, i:i+1]) + clean_list.append(state.clean_latent[:, :, i:i+1]) + mask_list.append(state.denoise_mask[:, :, i:i+1]) + + state.latent = mx.concatenate(latent_list, axis=2) + state.clean_latent = mx.concatenate(clean_list, axis=2) + state.denoise_mask = mx.concatenate(mask_list, axis=2) + + return state + + +def apply_denoise_mask( + denoised: mx.array, + clean: mx.array, + denoise_mask: mx.array, +) -> mx.array: + """Blend denoised output with clean state based on mask. + + Args: + denoised: Denoised latent (B, C, F, H, W) + clean: Clean conditioning latent (B, C, F, H, W) + denoise_mask: Mask where 1.0 = use denoised, 0.0 = use clean + + Returns: + Blended latent + """ + return denoised * denoise_mask + clean * (1.0 - denoise_mask) + + +def add_noise_with_state( + state: LatentState, + noise_scale: float, +) -> LatentState: + """Add noise to state while respecting conditioning. + + For conditioned frames (mask < 1.0), adds noise proportionally + to allow some refinement while preserving the conditioning. + + Args: + state: Current latent state + noise_scale: Scale for noise (sigma) + + Returns: + Updated state with noise added + """ + state = state.clone() + + # Generate noise + noise = mx.random.normal(state.latent.shape) + + # For fully conditioned frames (mask=0), we want to add minimal noise + # For unconditioned frames (mask=1), we want full noise + # noisy = noise * sigma + latent * (1 - sigma) + # But we scale sigma by the mask for conditioned regions + + effective_scale = noise_scale * state.denoise_mask + state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale) + + return state diff --git a/mlx_video/convert.py b/mlx_video/convert.py index e251fc5..11491e0 100644 --- a/mlx_video/convert.py +++ b/mlx_video/convert.py @@ -291,6 +291,59 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: return sanitized +def sanitize_vae_encoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + """Sanitize VAE encoder weight names from PyTorch format to MLX format. + + Args: + weights: Dictionary of weights with PyTorch naming + + Returns: + Dictionary with MLX-compatible naming for VAE encoder + """ + sanitized = {} + + for key, value in weights.items(): + new_key = key + + # Skip position_ids (not needed) + if "position_ids" in key: + continue + + # Only process VAE encoder weights + if not key.startswith("vae."): + continue + + # Handle per-channel statistics key mapping + if "vae.per_channel_statistics" in key: + if key == "vae.per_channel_statistics.mean-of-means": + new_key = "per_channel_statistics._mean_of_means" + elif key == "vae.per_channel_statistics.std-of-means": + new_key = "per_channel_statistics._std_of_means" + else: + # Skip other per_channel_statistics keys + continue + elif key.startswith("vae.encoder."): + # Strip the vae.encoder. prefix for encoder weights + new_key = key.replace("vae.encoder.", "") + else: + # Skip other vae.* keys that are not encoder weights + continue + + # Handle Conv3d weight shape conversion + # PyTorch: (out_channels, in_channels, D, H, W) + # MLX: (out_channels, D, H, W, in_channels) + if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5: + value = mx.transpose(value, (0, 2, 3, 4, 1)) + + # Handle Conv2d weight shape conversion + if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4: + value = mx.transpose(value, (0, 2, 3, 1)) + + sanitized[new_key] = value + + return sanitized + + def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]: """Sanitize audio VAE weight names from PyTorch format to MLX format. diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 0c9e562..973fa60 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1,6 +1,7 @@ import argparse import time from pathlib import Path +from typing import Optional, List, Tuple import mlx.core as mx import numpy as np @@ -22,10 +23,13 @@ class Colors: from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.transformer import Modality -from mlx_video.convert import sanitize_transformer_weights -from mlx_video.utils import to_denoised +from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder_weights +from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder +from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents +from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning +from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state from mlx_video.utils import get_model_path @@ -115,17 +119,49 @@ def denoise( transformer: LTXModel, sigmas: list, verbose: bool = True, + state: Optional[LatentState] = None, ) -> mx.array: - """Run denoising loop.""" + """Run denoising loop with optional conditioning. + + Args: + latents: Noisy latent tensor (B, C, F, H, W) + positions: Position embeddings + text_embeddings: Text conditioning embeddings + transformer: LTX model + sigmas: List of sigma values for denoising schedule + verbose: Whether to show progress bar + state: Optional LatentState for I2V conditioning + + Returns: + Denoised latent tensor + """ + # If state is provided, use its latent (which may have conditioning applied) + if state is not None: + latents = state.latent + for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose): sigma, sigma_next = sigmas[i], sigmas[i + 1] b, c, f, h, w = latents.shape + num_tokens = f * h * w latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) + # Compute per-token timesteps + # For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1) + if state is not None: + # Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens) + denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) + denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) + denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) + # Per-token timesteps: sigma * mask + timesteps = sigma * denoise_mask_flat + else: + # All tokens get the same timestep + timesteps = mx.full((b, num_tokens), sigma) + video_modality = Modality( latent=latents_flat, - timesteps=mx.full((1,), sigma), + timesteps=timesteps, positions=positions, context=text_embeddings, context_mask=None, @@ -137,6 +173,11 @@ def denoise( velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) denoised = to_denoised(latents, velocity, sigma) + + # Apply conditioning mask if state is provided + if state is not None: + denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) + mx.eval(denoised) if sigma_next > 0: @@ -163,10 +204,15 @@ def generate_video( enhance_prompt: bool = False, max_tokens: int = 512, temperature: float = 0.7, + image: Optional[str] = None, + image_strength: float = 1.0, + image_frame_idx: int = 0, ): - """Generate video from text prompt. + """Generate video from text prompt, optionally conditioned on an image. Args: + model_repo: Model repository ID + text_encoder_repo: Text encoder repository ID prompt: Text description of the video to generate height: Output video height (must be divisible by 64) width: Output video width (must be divisible by 64) @@ -175,6 +221,13 @@ def generate_video( fps: Frames per second for output video output_path: Path to save the output video save_frames: Whether to save individual frames as images + verbose: Whether to print progress + enhance_prompt: Whether to enhance prompt using Gemma + max_tokens: Max tokens for prompt enhancement + temperature: Temperature for prompt enhancement + image: Path to conditioning image for I2V (Image-to-Video) + image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original) + image_frame_idx: Frame index to condition (0 = first frame) """ start_time = time.time() @@ -188,8 +241,12 @@ def generate_video( num_frames = adjusted_num_frames - print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") + is_i2v = image is not None + mode_str = "I2V" if is_i2v else "T2V" + print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") + if is_i2v: + print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") # Get model path model_path = get_model_path(model_repo) @@ -247,6 +304,31 @@ def generate_video( transformer.load_weights(list(sanitized.items()), strict=False) mx.eval(transformer.parameters()) + # Load VAE encoder and encode image for I2V conditioning + stage1_image_latent = None + stage2_image_latent = None + if is_i2v: + print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}") + vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors')) + mx.eval(vae_encoder.parameters()) + + # Load and prepare image for stage 1 (half resolution) + input_image = load_image(image, height=height // 2, width=width // 2) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) + stage1_image_latent = vae_encoder(stage1_image_tensor) + mx.eval(stage1_image_latent) + print(f" Stage 1 image latent: {stage1_image_latent.shape}") + + # Load and prepare image for stage 2 (full resolution) + input_image = load_image(image, height=height, width=width) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) + stage2_image_latent = vae_encoder(stage2_image_tensor) + mx.eval(stage2_image_latent) + print(f" Stage 2 image latent: {stage2_image_latent.shape}") + + del vae_encoder + mx.clear_cache() + # Stage 1: Generate at half resolution print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) @@ -256,7 +338,25 @@ def generate_video( positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) - latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose) + # Apply I2V conditioning if provided + state1 = None + if is_i2v and stage1_image_latent is not None: + # Create state with conditioning + state1 = LatentState( + latent=latents, + clean_latent=mx.zeros_like(latents), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + ) + conditioning = VideoConditionByLatentIndex( + latent=stage1_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) + state1 = apply_conditioning(state1, [conditioning]) + latents = state1.latent + mx.eval(latents) + + latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1) # Upsample latents print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}") @@ -285,7 +385,24 @@ def generate_video( latents = noise * noise_scale + latents * (1 - noise_scale) mx.eval(latents) - latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose) + # Apply I2V conditioning for stage 2 if provided + state2 = None + if is_i2v and stage2_image_latent is not None: + state2 = LatentState( + latent=latents, + clean_latent=mx.zeros_like(latents), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + ) + conditioning = VideoConditionByLatentIndex( + latent=stage2_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) + state2 = apply_conditioning(state2, [conditioning]) + latents = state2.latent + mx.eval(latents) + + latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2) del transformer mx.clear_cache() @@ -335,13 +452,18 @@ def generate_video( def main(): parser = argparse.ArgumentParser( - description="Generate videos with MLX LTX-2", + description="Generate videos with MLX LTX-2 (T2V and I2V)", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: + # Text-to-Video (T2V) python -m mlx_video.generate --prompt "A cat walking on grass" python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768 python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4 + + # Image-to-Video (I2V) + python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg + python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8 """ ) @@ -426,6 +548,24 @@ Examples: default=0.7, help="Temperature for prompt enhancement (default: 0.7)" ) + parser.add_argument( + "--image", "-i", + type=str, + default=None, + help="Path to conditioning image for I2V (Image-to-Video) generation" + ) + parser.add_argument( + "--image-strength", + type=float, + default=1.0, + help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)" + ) + parser.add_argument( + "--image-frame-idx", + type=int, + default=0, + help="Frame index to condition for I2V (0 = first frame, default: 0)" + ) args = parser.parse_args() generate_video( diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py index 7551c51..f73998d 100644 --- a/mlx_video/generate_av.py +++ b/mlx_video/generate_av.py @@ -3,7 +3,7 @@ import argparse import time from pathlib import Path -from typing import Optional +from typing import Optional, List import mlx.core as mx import numpy as np @@ -27,9 +27,12 @@ from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeTyp from mlx_video.models.ltx.ltx import LTXModel from mlx_video.models.ltx.transformer import Modality from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights -from mlx_video.utils import to_denoised, get_model_path +from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder +from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents +from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning +from mlx_video.conditioning.latent import LatentState, apply_denoise_mask # Distilled sigma schedules @@ -141,13 +144,35 @@ def denoise_av( transformer: LTXModel, sigmas: list, verbose: bool = True, + video_state: Optional[LatentState] = None, ) -> tuple[mx.array, mx.array]: - """Run denoising loop for audio-video generation.""" + """Run denoising loop for audio-video generation with optional I2V conditioning. + + Args: + video_latents: Video latent tensor (B, C, F, H, W) + audio_latents: Audio latent tensor (B, C, T, F) + video_positions: Video position embeddings + audio_positions: Audio position embeddings + video_embeddings: Video text embeddings + audio_embeddings: Audio text embeddings + transformer: LTX model + sigmas: List of sigma values + verbose: Whether to show progress bar + video_state: Optional LatentState for I2V conditioning + + Returns: + Tuple of (video_latents, audio_latents) + """ + # If video state is provided, use its latent + if video_state is not None: + video_latents = video_state.latent + for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose): sigma, sigma_next = sigmas[i], sigmas[i + 1] # Flatten video latents b, c, f, h, w = video_latents.shape + num_video_tokens = f * h * w video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1)) # Flatten audio latents: (B, C, T, F) -> (B, T, C*F) @@ -155,9 +180,22 @@ def denoise_av( audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F) audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) + # Compute per-token timesteps for video + # For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1) + if video_state is not None: + # Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens) + denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1)) + denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) + denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) + # Per-token timesteps: sigma * mask + video_timesteps = sigma * denoise_mask_flat + else: + # All tokens get the same timestep + video_timesteps = mx.full((b, num_video_tokens), sigma) + video_modality = Modality( latent=video_flat, - timesteps=mx.full((1,), sigma), + timesteps=video_timesteps, positions=video_positions, context=video_embeddings, context_mask=None, @@ -166,7 +204,7 @@ def denoise_av( audio_modality = Modality( latent=audio_flat, - timesteps=mx.full((1,), sigma), + timesteps=mx.full((ab, at), sigma), positions=audio_positions, context=audio_embeddings, context_mask=None, @@ -184,6 +222,11 @@ def denoise_av( # Compute denoised video_denoised = to_denoised(video_latents, video_velocity, sigma) audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) + + # Apply conditioning mask for video if state is provided + if video_state is not None: + video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask) + mx.eval(video_denoised, audio_denoised) # Euler step @@ -317,8 +360,31 @@ def generate_video_with_audio( enhance_prompt: bool = False, max_tokens: int = 512, temperature: float = 0.7, + image: Optional[str] = None, + image_strength: float = 1.0, + image_frame_idx: int = 0, ): - """Generate video with synchronized audio from text prompt.""" + """Generate video with synchronized audio from text prompt, optionally conditioned on an image. + + Args: + model_repo: Model repository ID + text_encoder_repo: Text encoder repository ID + prompt: Text description of the video to generate + height: Output video height (must be divisible by 64) + width: Output video width (must be divisible by 64) + num_frames: Number of frames + seed: Random seed + fps: Frames per second + output_path: Output video path + output_audio_path: Output audio path + verbose: Whether to print progress + enhance_prompt: Whether to enhance prompt using Gemma + max_tokens: Max tokens for prompt enhancement + temperature: Temperature for prompt enhancement + image: Path to conditioning image for I2V + image_strength: Conditioning strength (1.0 = full denoise) + image_frame_idx: Frame index to condition (0 = first frame) + """ start_time = time.time() # Validate dimensions @@ -333,9 +399,13 @@ def generate_video_with_audio( # Calculate audio frames audio_frames = compute_audio_frames(num_frames, fps) - print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}") + is_i2v = image is not None + mode_str = "I2V+Audio" if is_i2v else "T2V+Audio" + print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}") print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}") print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}") + if is_i2v: + print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}") model_path = get_model_path(model_repo) text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo) @@ -400,6 +470,31 @@ def generate_video_with_audio( transformer.load_weights(list(sanitized.items()), strict=False) mx.eval(transformer.parameters()) + # Load VAE encoder and encode image for I2V conditioning + stage1_image_latent = None + stage2_image_latent = None + if is_i2v: + print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}") + vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors')) + mx.eval(vae_encoder.parameters()) + + # Load and prepare image for stage 1 (half resolution) + input_image = load_image(image, height=height // 2, width=width // 2) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) + stage1_image_latent = vae_encoder(stage1_image_tensor) + mx.eval(stage1_image_latent) + print(f" Stage 1 image latent: {stage1_image_latent.shape}") + + # Load and prepare image for stage 2 (full resolution) + input_image = load_image(image, height=height, width=width) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) + stage2_image_latent = vae_encoder(stage2_image_tensor) + mx.eval(stage2_image_latent) + print(f" Stage 2 image latent: {stage2_image_latent.shape}") + + del vae_encoder + mx.clear_cache() + # Initialize latents print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) @@ -412,12 +507,30 @@ def generate_video_with_audio( audio_positions = create_audio_position_grid(1, audio_frames) mx.eval(video_positions, audio_positions) + # Apply I2V conditioning for stage 1 if provided + video_state1 = None + if is_i2v and stage1_image_latent is not None: + video_state1 = LatentState( + latent=video_latents, + clean_latent=mx.zeros_like(video_latents), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + ) + conditioning = VideoConditionByLatentIndex( + latent=stage1_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) + video_state1 = apply_conditioning(video_state1, [conditioning]) + video_latents = video_state1.latent + mx.eval(video_latents) + # Stage 1 denoising video_latents, audio_latents = denoise_av( video_latents, audio_latents, video_positions, audio_positions, video_embeddings, audio_embeddings, - transformer, STAGE_1_SIGMAS, verbose=verbose + transformer, STAGE_1_SIGMAS, verbose=verbose, + video_state=video_state1 ) # Upsample video latents @@ -449,11 +562,29 @@ def generate_video_with_audio( audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) mx.eval(video_latents, audio_latents) + # Apply I2V conditioning for stage 2 if provided + video_state2 = None + if is_i2v and stage2_image_latent is not None: + video_state2 = LatentState( + latent=video_latents, + clean_latent=mx.zeros_like(video_latents), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + ) + conditioning = VideoConditionByLatentIndex( + latent=stage2_image_latent, + frame_idx=image_frame_idx, + strength=image_strength, + ) + video_state2 = apply_conditioning(video_state2, [conditioning]) + video_latents = video_state2.latent + mx.eval(video_latents) + video_latents, audio_latents = denoise_av( video_latents, audio_latents, video_positions, audio_positions, video_embeddings, audio_embeddings, - transformer, STAGE_2_SIGMAS, verbose=verbose + transformer, STAGE_2_SIGMAS, verbose=verbose, + video_state=video_state2 ) del transformer @@ -549,13 +680,18 @@ def generate_video_with_audio( def main(): parser = argparse.ArgumentParser( - description="Generate videos with synchronized audio using MLX LTX-2", + description="Generate videos with synchronized audio using MLX LTX-2 (T2V and I2V)", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: + # Text-to-Video with Audio (T2V+Audio) python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach" python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav + + # Image-to-Video with Audio (I2V+Audio) + python -m mlx_video.generate_av --prompt "A person dancing" --image photo.jpg + python -m mlx_video.generate_av --prompt "Waves crashing" --image beach.png --image-strength 0.8 """ ) @@ -587,6 +723,12 @@ Examples: help="Max tokens for prompt enhancement") parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for prompt enhancement") + parser.add_argument("--image", "-i", type=str, default=None, + help="Path to conditioning image for I2V (Image-to-Video) generation") + parser.add_argument("--image-strength", type=float, default=1.0, + help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)") + parser.add_argument("--image-frame-idx", type=int, default=0, + help="Frame index to condition for I2V (0 = first frame, default: 0)") args = parser.parse_args() @@ -605,6 +747,9 @@ Examples: enhance_prompt=args.enhance_prompt, max_tokens=args.max_tokens, temperature=args.temperature, + image=args.image, + image_strength=args.image_strength, + image_frame_idx=args.image_frame_idx, ) diff --git a/mlx_video/models/ltx/video_vae/__init__.py b/mlx_video/models/ltx/video_vae/__init__.py index 208d8f5..d620f17 100644 --- a/mlx_video/models/ltx/video_vae/__init__.py +++ b/mlx_video/models/ltx/video_vae/__init__.py @@ -1 +1,3 @@ from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder +from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image +from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder diff --git a/mlx_video/models/ltx/video_vae/encoder.py b/mlx_video/models/ltx/video_vae/encoder.py new file mode 100644 index 0000000..6c90a4b --- /dev/null +++ b/mlx_video/models/ltx/video_vae/encoder.py @@ -0,0 +1,187 @@ +"""Video VAE Encoder for LTX-2 Image-to-Video. + +The encoder compresses input images/videos to latent representations. +Used for I2V (image-to-video) conditioning by encoding the input image +to latent space, which can then be used to condition video generation. +""" + +from pathlib import Path +from typing import List, Tuple, Any, Optional +import json + +import mlx.core as mx +import mlx.nn as nn + +from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType + + +def load_vae_encoder(model_path: str) -> VideoEncoder: + """Load VAE encoder from safetensors file. + + Args: + model_path: Path to the model weights (safetensors file or directory) + + Returns: + Loaded VideoEncoder instance + """ + from safetensors import safe_open + + model_path = Path(model_path) + + # Try to find the weights file + if model_path.is_file() and model_path.suffix == ".safetensors": + weights_path = model_path + elif (model_path / "ltx-2-19b-distilled.safetensors").exists(): + weights_path = model_path / "ltx-2-19b-distilled.safetensors" + elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists(): + weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors" + else: + raise FileNotFoundError(f"VAE weights not found at {model_path}") + + print(f"Loading VAE encoder from {weights_path}...") + + # Read config from safetensors metadata + encoder_blocks = [] + norm_layer = NormLayerType.PIXEL_NORM + latent_log_var = LogVarianceType.UNIFORM + patch_size = 4 + + try: + with safe_open(str(weights_path), framework="numpy") as f: + metadata = f.metadata() + if metadata and "config" in metadata: + configs = json.loads(metadata["config"]) + vae_config = configs.get("vae", {}) + + # Parse encoder blocks + raw_blocks = vae_config.get("encoder_blocks", []) + for block in raw_blocks: + if isinstance(block, list) and len(block) == 2: + name, params = block + encoder_blocks.append((name, params)) + + # Parse other config + norm_str = vae_config.get("norm_layer", "pixel_norm") + norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM + + var_str = vae_config.get("latent_log_var", "uniform") + if var_str == "uniform": + latent_log_var = LogVarianceType.UNIFORM + elif var_str == "per_channel": + latent_log_var = LogVarianceType.PER_CHANNEL + elif var_str == "constant": + latent_log_var = LogVarianceType.CONSTANT + else: + latent_log_var = LogVarianceType.NONE + + patch_size = vae_config.get("patch_size", 4) + + print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}") + except Exception as e: + print(f" Could not read config from metadata: {e}") + # Use default config + encoder_blocks = [ + ("res_x", {"num_layers": 4}), + ("compress_space_res", {"multiplier": 2}), + ("res_x", {"num_layers": 6}), + ("compress_time_res", {"multiplier": 2}), + ("res_x", {"num_layers": 6}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ("compress_all_res", {"multiplier": 2}), + ("res_x", {"num_layers": 2}), + ] + print(f" Using default encoder config with {len(encoder_blocks)} blocks") + + # Create encoder + encoder = VideoEncoder( + convolution_dimensions=3, + in_channels=3, + out_channels=128, + encoder_blocks=encoder_blocks, + patch_size=patch_size, + norm_layer=norm_layer, + latent_log_var=latent_log_var, + encoder_spatial_padding_mode=PaddingModeType.ZEROS, + ) + + # Load weights + weights = mx.load(str(weights_path)) + + # Determine prefix based on weight keys + has_vae_prefix = any(k.startswith("vae.") for k in weights.keys()) + + if has_vae_prefix: + prefix = "vae.encoder." + stats_prefix = "vae.per_channel_statistics." + else: + prefix = "encoder." + stats_prefix = "per_channel_statistics." + + # Load per-channel statistics for normalization + mean_key = f"{stats_prefix}mean-of-means" + std_key = f"{stats_prefix}std-of-means" + + if mean_key in weights: + encoder.per_channel_statistics.mean = weights[mean_key] + print(f" Loaded latent mean: shape {weights[mean_key].shape}") + if std_key in weights: + encoder.per_channel_statistics.std = weights[std_key] + print(f" Loaded latent std: shape {weights[std_key].shape}") + + # Build encoder weights dict with key remapping + encoder_weights = {} + for key, value in weights.items(): + if not key.startswith(prefix): + continue + + # Remove prefix + new_key = key[len(prefix):] + + # Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I) + if ".weight" in key and value.ndim == 5: + value = mx.transpose(value, (0, 2, 3, 4, 1)) + + encoder_weights[new_key] = value + + print(f" Found {len(encoder_weights)} encoder weights") + + # Load weights + encoder.load_weights(list(encoder_weights.items()), strict=False) + + print("VAE encoder loaded successfully") + return encoder + + +def encode_image( + image: mx.array, + encoder: VideoEncoder, +) -> mx.array: + """Encode a single image to latent space. + + Args: + image: Image tensor of shape (H, W, 3) in range [0, 1] or (B, H, W, 3) + encoder: Loaded VAE encoder + + Returns: + Latent tensor of shape (1, 128, 1, H//32, W//32) + """ + # Add batch dimension if needed + if image.ndim == 3: + image = mx.expand_dims(image, axis=0) # (1, H, W, 3) + + # Convert from (B, H, W, C) to (B, C, H, W) + image = mx.transpose(image, (0, 3, 1, 2)) # (B, 3, H, W) + + # Normalize to [-1, 1] + if image.max() > 1.0: + image = image / 255.0 + image = image * 2.0 - 1.0 + + # Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W) + image = mx.expand_dims(image, axis=2) # (B, 3, 1, H, W) + + # Encode + latent = encoder(image) + + return latent diff --git a/mlx_video/models/ltx/video_vae/resnet.py b/mlx_video/models/ltx/video_vae/resnet.py index f6d7c1b..d93754c 100644 --- a/mlx_video/models/ltx/video_vae/resnet.py +++ b/mlx_video/models/ltx/video_vae/resnet.py @@ -141,9 +141,10 @@ class UNetMidBlock3D(nn.Module): self.num_layers = num_layers - # Create ResNet blocks - self.resnets = [ - ResnetBlock3D( + # Create ResNet blocks - use dict for MLX parameter tracking + # Named res_blocks to match PyTorch weight keys + self.res_blocks = { + i: ResnetBlock3D( dims=dims, in_channels=in_channels, out_channels=in_channels, @@ -154,8 +155,8 @@ class UNetMidBlock3D(nn.Module): timestep_conditioning=timestep_conditioning, spatial_padding_mode=spatial_padding_mode, ) - for _ in range(num_layers) - ] + for i in range(num_layers) + } def __call__( self, @@ -164,8 +165,8 @@ class UNetMidBlock3D(nn.Module): timestep: Optional[mx.array] = None, generator: Optional[int] = None, ) -> mx.array: - - for resnet in self.resnets: + + for resnet in self.res_blocks.values(): x = resnet(x, causal=causal, generator=generator) return x diff --git a/mlx_video/models/ltx/video_vae/sampling.py b/mlx_video/models/ltx/video_vae/sampling.py index f672e4d..6ca3d41 100644 --- a/mlx_video/models/ltx/video_vae/sampling.py +++ b/mlx_video/models/ltx/video_vae/sampling.py @@ -9,6 +9,15 @@ from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingMode class SpaceToDepthDownsample(nn.Module): + """Space-to-depth downsampling with 3x3 conv and skip connection. + + PyTorch-compatible implementation: + 1. Apply 3x3 conv: in_channels -> out_channels // prod(stride) + 2. Space-to-depth on conv output: channels * prod(stride) + 3. Space-to-depth on input with group averaging for skip connection + 4. Add skip connection + """ + def __init__( self, dims: int, @@ -17,7 +26,6 @@ class SpaceToDepthDownsample(nn.Module): stride: Union[int, Tuple[int, int, int]], spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, ): - super().__init__() if isinstance(stride, int): @@ -25,61 +33,74 @@ class SpaceToDepthDownsample(nn.Module): self.stride = stride self.dims = dims + self.out_channels = out_channels - # Calculate the multiplier for channels + # Calculate channels multiplier = stride[0] * stride[1] * stride[2] - intermediate_channels = in_channels * multiplier + self.group_size = in_channels * multiplier // out_channels + conv_out_channels = out_channels // multiplier - # 1x1x1 convolution to adjust channels + # 3x3 convolution (not 1x1) self.conv = CausalConv3d( - in_channels=intermediate_channels, - out_channels=out_channels, - kernel_size=1, + in_channels=in_channels, + out_channels=conv_out_channels, + kernel_size=3, stride=1, - padding=0, + padding=1, spatial_padding_mode=spatial_padding_mode, ) - def __call__(self, x: mx.array, causal: bool = True) -> mx.array: - + def _space_to_depth(self, x: mx.array) -> mx.array: + """Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w""" b, c, d, h, w = x.shape st, sh, sw = self.stride + # Reshape to group spatial elements + x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw)) + + # Permute: (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W') + x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6)) + + # Reshape to combine channels + new_c = c * st * sh * sw + new_d = d // st + new_h = h // sh + new_w = w // sw + x = mx.reshape(x, (b, new_c, new_d, new_h, new_w)) + + return x + + def __call__(self, x: mx.array, causal: bool = True) -> mx.array: + b, c, d, h, w = x.shape + st, sh, sw = self.stride + + # Temporal padding for causal mode + if st == 2: + # Duplicate first frame for padding + x = mx.concatenate([x[:, :, :1, :, :], x], axis=2) + d = d + 1 + # Pad if necessary to make dimensions divisible by stride pad_d = (st - d % st) % st pad_h = (sh - h % sh) % sh pad_w = (sw - w % sw) % sw if pad_d > 0 or pad_h > 0 or pad_w > 0: - # For causal, pad at the end of temporal dimension - if causal: - x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)]) - else: - x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2), - (pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]) + x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)]) - b, c, d, h, w = x.shape + # Skip connection: space-to-depth on input, then group mean + x_in = self._space_to_depth(x) + # Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w) + b2, c2, d2, h2, w2 = x_in.shape + x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2)) + x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w) - # Reshape to group spatial elements - # (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw) - x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw)) + # Conv branch: apply conv then space-to-depth + x_conv = self.conv(x, causal=causal) + x_conv = self._space_to_depth(x_conv) - # Permute to move stride elements to channel dim - # (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W') - x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6)) - - # Reshape to combine channels - # (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W') - new_c = c * st * sh * sw - new_d = d // st - new_h = h // sh - new_w = w // sw - x = mx.reshape(x, (b, new_c, new_d, new_h, new_w)) - - # Apply 1x1 conv to adjust channels - x = self.conv(x, causal=causal) - - return x + # Add skip connection + return x_conv + x_in class DepthToSpaceUpsample(nn.Module): diff --git a/mlx_video/models/ltx/video_vae/video_vae.py b/mlx_video/models/ltx/video_vae/video_vae.py index 2375650..cc3ec3a 100644 --- a/mlx_video/models/ltx/video_vae/video_vae.py +++ b/mlx_video/models/ltx/video_vae/video_vae.py @@ -273,9 +273,9 @@ class VideoEncoder(nn.Module): spatial_padding_mode=encoder_spatial_padding_mode, ) - # Build encoder blocks - self.down_blocks = [] - for block_name, block_params in encoder_blocks: + # Build encoder blocks - use dict with int keys for MLX parameter tracking + self.down_blocks = {} + for i, (block_name, block_params) in enumerate(encoder_blocks): block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params block, feature_channels = _make_encoder_block( @@ -287,7 +287,7 @@ class VideoEncoder(nn.Module): norm_num_groups=self._norm_num_groups, spatial_padding_mode=encoder_spatial_padding_mode, ) - self.down_blocks.append(block) + self.down_blocks[i] = block # Output normalization and convolution if norm_layer == NormLayerType.GROUP_NORM: @@ -341,7 +341,7 @@ class VideoEncoder(nn.Module): sample = self.conv_in(sample, causal=True) # Process through encoder blocks - for down_block in self.down_blocks: + for down_block in self.down_blocks.values(): if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)): sample = down_block(sample, causal=True) else: diff --git a/mlx_video/utils.py b/mlx_video/utils.py index c6840bd..4b50536 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -1,10 +1,13 @@ import math +from typing import Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn +import numpy as np from functools import partial from pathlib import Path from huggingface_hub import snapshot_download +from PIL import Image def get_model_path(model_repo: str): """Get or download LTX-2 model path.""" @@ -160,3 +163,122 @@ def get_timestep_embedding( emb = mx.pad(emb, [(0, 0), (0, 1)]) return emb + + +def load_image( + image_path: Union[str, Path], + height: Optional[int] = None, + width: Optional[int] = None, +) -> mx.array: + """Load and preprocess an image for I2V conditioning. + + Args: + image_path: Path to the image file + height: Target height (must be divisible by 32). If None, uses original. + width: Target width (must be divisible by 32). If None, uses original. + + Returns: + Image tensor of shape (H, W, 3) in range [0, 1] + """ + image = Image.open(image_path).convert("RGB") + + # Resize if dimensions specified + if height is not None and width is not None: + image = image.resize((width, height), Image.Resampling.LANCZOS) + elif height is not None or width is not None: + # If only one dimension specified, resize preserving aspect ratio + orig_w, orig_h = image.size + if height is not None: + scale = height / orig_h + new_w = int(orig_w * scale) + new_w = (new_w // 32) * 32 # Round to nearest 32 + image = image.resize((new_w, height), Image.Resampling.LANCZOS) + else: + scale = width / orig_w + new_h = int(orig_h * scale) + new_h = (new_h // 32) * 32 # Round to nearest 32 + image = image.resize((width, new_h), Image.Resampling.LANCZOS) + else: + # Round to nearest 32 + orig_w, orig_h = image.size + new_w = (orig_w // 32) * 32 + new_h = (orig_h // 32) * 32 + if new_w != orig_w or new_h != orig_h: + image = image.resize((new_w, new_h), Image.Resampling.LANCZOS) + + # Convert to numpy then MLX + image_np = np.array(image).astype(np.float32) / 255.0 + return mx.array(image_np) + + +def resize_image_aspect_ratio( + image: mx.array, + long_side: int = 512, +) -> mx.array: + """Resize image preserving aspect ratio, making long side = long_side. + + Args: + image: Image tensor of shape (H, W, 3) + long_side: Target size for the longer dimension + + Returns: + Resized image tensor + """ + h, w = image.shape[:2] + + if h > w: + new_h = long_side + new_w = int(w * long_side / h) + else: + new_w = long_side + new_h = int(h * long_side / w) + + # Round to nearest 32 + new_h = (new_h // 32) * 32 + new_w = (new_w // 32) * 32 + + # Use PIL for high-quality resize + image_np = np.array(image) + if image_np.max() <= 1.0: + image_np = (image_np * 255).astype(np.uint8) + pil_image = Image.fromarray(image_np) + pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS) + + return mx.array(np.array(pil_image).astype(np.float32) / 255.0) + + +def prepare_image_for_encoding( + image: mx.array, + target_height: int, + target_width: int, +) -> mx.array: + """Prepare image for VAE encoding by resizing and normalizing. + + Args: + image: Image tensor of shape (H, W, 3) in range [0, 1] + target_height: Target height for the video + target_width: Target width for the video + + Returns: + Image tensor ready for encoding, shape (1, 3, 1, H, W) in range [-1, 1] + """ + h, w = image.shape[:2] + + # Resize if needed + if h != target_height or w != target_width: + image_np = np.array(image) + if image_np.max() <= 1.0: + image_np = (image_np * 255).astype(np.uint8) + pil_image = Image.fromarray(image_np) + pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS) + image = mx.array(np.array(pil_image).astype(np.float32) / 255.0) + + # Normalize to [-1, 1] + image = image * 2.0 - 1.0 + + # Convert to (B, C, 1, H, W) + image = mx.transpose(image, (2, 0, 1)) # (3, H, W) + image = mx.expand_dims(image, axis=0) # (1, 3, H, W) + image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W) + + return image