Add image-to-video (I2V) conditioning support

- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation.
- Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning.
- Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`.
- Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning.
- Implemented VAE encoder loading and image encoding for conditioning in the video generation process.
This commit is contained in:
Prince Canuma
2026-01-17 00:19:52 +01:00
parent 5f86e881d7
commit 146f5d2981
11 changed files with 937 additions and 67 deletions

View File

@@ -0,0 +1,3 @@
"""Conditioning modules for LTX-2 video generation."""
from mlx_video.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning

View File

@@ -0,0 +1,196 @@
"""Latent-based conditioning for I2V (Image-to-Video) generation.
This module provides conditioning that injects encoded image latents into
the video generation process at specific frame positions.
"""
from dataclasses import dataclass
from typing import Optional, List, Tuple
import mlx.core as mx
@dataclass
class VideoConditionByLatentIndex:
    """Condition video generation by injecting latents at a specific frame index.

    ``apply_conditioning`` overwrites the target latent frames, starting at
    ``frame_idx``, with ``latent`` and sets their denoise mask to
    ``1.0 - strength``.

    Args:
        latent: Encoded image latent of shape (B, C, F_cond, H, W); F_cond is
            1 for a single image.
        frame_idx: Latent frame index where the conditioning starts
            (0 = first frame).
        strength: Conditioning strength. 1.0 gives the conditioned frames a
            denoise mask of 0.0 (they are kept as ``latent``); 0.0 gives them
            a denoise mask of 1.0 (they are denoised like any other frame).
    """

    latent: mx.array
    frame_idx: int = 0
    strength: float = 1.0

    def get_num_latent_frames(self) -> int:
        """Return the length of the conditioning latent's frame axis (axis 2)."""
        return self.latent.shape[2]
@dataclass
class LatentState:
    """Bundle of tensors tracked while denoising with conditioning.

    Attributes:
        latent: Current noisy latent (B, C, F, H, W)
        clean_latent: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
            1.0 = full denoise, 0.0 = keep clean
    """

    latent: mx.array
    clean_latent: mx.array
    denoise_mask: mx.array

    def clone(self) -> "LatentState":
        """Return a new LatentState that references the same three arrays.

        The arrays themselves are not copied; only the container is new, so
        reassigning a field on the clone leaves the source state untouched.
        """
        fields = (self.latent, self.clean_latent, self.denoise_mask)
        return LatentState(*fields)
def create_initial_state(
    shape: Tuple[int, ...],
    seed: Optional[int] = None,
    noise_scale: float = 1.0,
) -> LatentState:
    """Build the starting LatentState: scaled Gaussian noise, no conditioning.

    Args:
        shape: Shape of latent (B, C, F, H, W)
        seed: Optional random seed; when given, it reseeds MLX's global
            random state before the noise is drawn
        noise_scale: Scale for initial noise (sigma)

    Returns:
        LatentState whose latent is ``noise * noise_scale``, whose clean
        latent is all zeros, and whose denoise mask is all ones
        (every frame fully denoised).
    """
    if seed is not None:
        mx.random.seed(seed)

    scaled_noise = mx.random.normal(shape) * noise_scale

    # One mask value per (batch, frame); broadcast over channels and space.
    batch, num_frames = shape[0], shape[2]
    full_denoise = mx.ones((batch, 1, num_frames, 1, 1))

    return LatentState(
        latent=scaled_noise,
        clean_latent=mx.zeros(shape),
        denoise_mask=full_denoise,
    )
def _replace_frame_span(base, patch, start, end):
    """Return ``base`` with frames ``[start, end)`` on axis 2 replaced by ``patch``."""
    parts = []
    if start > 0:
        parts.append(base[:, :, :start])
    parts.append(patch)
    if end < base.shape[2]:
        parts.append(base[:, :, end:])
    return mx.concatenate(parts, axis=2)


def apply_conditioning(
    state: "LatentState",
    conditionings: List["VideoConditionByLatentIndex"],
) -> "LatentState":
    """Apply conditioning items to a latent state.

    For each item, the frames ``[frame_idx, frame_idx + F_cond)`` (clipped to
    the number of frames in the state) are overwritten in both ``latent`` and
    ``clean_latent`` with the conditioning latent, and their ``denoise_mask``
    is set to ``1.0 - strength``. Items are applied in order, so a later item
    overrides an earlier one where they overlap. The input state is not
    modified; a clone is updated and returned.

    Args:
        state: Current latent state
        conditionings: List of conditioning items to apply

    Returns:
        Updated LatentState with conditioning applied

    Raises:
        ValueError: If a conditioning latent's (C, H, W) differs from the
            state's, or if ``frame_idx`` is negative or not less than the
            number of frames.
    """
    state = state.clone()
    b, c, f, h, w = state.latent.shape

    for cond in conditionings:
        cond_latent = cond.latent
        frame_idx = cond.frame_idx

        # Validate shapes
        _, cond_c, cond_f, cond_h, cond_w = cond_latent.shape
        if (cond_c, cond_h, cond_w) != (c, h, w):
            raise ValueError(
                f"Conditioning latent spatial shape ({cond_c}, {cond_h}, {cond_w}) "
                f"does not match target shape ({c}, {h}, {w})"
            )

        # A negative index used to slip through and silently condition nothing.
        if not 0 <= frame_idx < f:
            raise ValueError(
                f"Frame index {frame_idx} is out of bounds for latent with {f} frames"
            )

        # Clip the conditioning so it never runs past the last frame.
        end_idx = min(frame_idx + cond_f, f)
        num_frames = end_idx - frame_idx
        if num_frames == 0:
            # Conditioning latent has no frames: nothing to inject.
            continue

        cond_frames = cond_latent[:, :, :num_frames]
        # 1.0 - strength: stronger conditioning means less denoising.
        cond_mask = mx.full((b, 1, num_frames, 1, 1), 1.0 - cond.strength)

        state.latent = _replace_frame_span(state.latent, cond_frames, frame_idx, end_idx)
        state.clean_latent = _replace_frame_span(
            state.clean_latent, cond_frames, frame_idx, end_idx
        )
        state.denoise_mask = _replace_frame_span(
            state.denoise_mask, cond_mask, frame_idx, end_idx
        )

    return state
def apply_denoise_mask(
    denoised: mx.array,
    clean: mx.array,
    denoise_mask: mx.array,
) -> mx.array:
    """Mix the model's denoised prediction with the clean conditioning latent.

    Args:
        denoised: Denoised latent (B, C, F, H, W)
        clean: Clean conditioning latent (B, C, F, H, W)
        denoise_mask: Mask where 1.0 = use denoised, 0.0 = use clean;
            values in between blend the two linearly

    Returns:
        ``denoised * denoise_mask + clean * (1.0 - denoise_mask)``
    """
    generated_part = denoised * denoise_mask
    preserved_part = clean * (1.0 - denoise_mask)
    return generated_part + preserved_part
def add_noise_with_state(
    state: LatentState,
    noise_scale: float,
) -> LatentState:
    """Re-noise a state's latent, scaling the noise level by the denoise mask.

    The per-frame noise level is ``noise_scale * denoise_mask``: frames with
    mask 1.0 receive the full ``noise_scale`` and frames with mask 0.0 are
    left unchanged. The new latent is
    ``noise * level + latent * (1 - level)``.

    Args:
        state: Current latent state
        noise_scale: Scale for noise (sigma)

    Returns:
        A clone of ``state`` whose latent has the noise mixed in; the input
        state object is not reassigned.
    """
    noisy_state = state.clone()

    fresh_noise = mx.random.normal(noisy_state.latent.shape)

    # Sigma shrinks towards zero on conditioned frames so they stay close to
    # their conditioning latent.
    per_frame_sigma = noise_scale * noisy_state.denoise_mask
    noise_term = fresh_noise * per_frame_sigma
    signal_term = noisy_state.latent * (1.0 - per_frame_sigma)
    noisy_state.latent = noise_term + signal_term

    return noisy_state

View File

@@ -291,6 +291,59 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
return sanitized
def sanitize_vae_encoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Sanitize VAE encoder weight names from PyTorch format to MLX format.

    Kept entries:
      * ``vae.encoder.*`` keys, with the leading ``vae.encoder.`` prefix removed.
      * ``vae.per_channel_statistics.mean-of-means`` and ``...std-of-means``,
        renamed to ``per_channel_statistics._mean_of_means`` and
        ``per_channel_statistics._std_of_means``.

    Everything else is dropped, including any key containing ``position_ids``
    and all other ``per_channel_statistics`` entries.

    Convolution weights (output key contains "conv" and "weight") are moved
    from channels-first to channels-last layout:
      * 5-D: (O, I, D, H, W) -> (O, D, H, W, I)
      * 4-D: (O, I, H, W) -> (O, H, W, I)

    Args:
        weights: Dictionary of weights with PyTorch naming

    Returns:
        Dictionary with MLX-compatible naming for VAE encoder
    """
    encoder_prefix = "vae.encoder."
    stats_renames = {
        "vae.per_channel_statistics.mean-of-means": "per_channel_statistics._mean_of_means",
        "vae.per_channel_statistics.std-of-means": "per_channel_statistics._std_of_means",
    }

    sanitized = {}
    for key, value in weights.items():
        # Skip position_ids (not needed)
        if "position_ids" in key:
            continue

        if "vae.per_channel_statistics" in key:
            new_key = stats_renames.get(key)
            if new_key is None:
                # Skip other per_channel_statistics keys
                continue
        elif key.startswith(encoder_prefix):
            # Strip only the leading prefix; str.replace would also remove
            # any later occurrence of the same text inside the key.
            new_key = key[len(encoder_prefix):]
        else:
            # Not an encoder weight
            continue

        if "conv" in new_key.lower() and "weight" in new_key:
            if value.ndim == 5:
                value = mx.transpose(value, (0, 2, 3, 4, 1))
            elif value.ndim == 4:
                value = mx.transpose(value, (0, 2, 3, 1))

        sanitized[new_key] = value

    return sanitized
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
"""Sanitize audio VAE weight names from PyTorch format to MLX format.

View File

@@ -1,6 +1,7 @@
import argparse
import time
from pathlib import Path
from typing import Optional, List, Tuple
import mlx.core as mx
import numpy as np
@@ -22,10 +23,13 @@ class Colors:
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.utils import to_denoised
from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder_weights
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state
from mlx_video.utils import get_model_path
@@ -115,17 +119,49 @@ def denoise(
transformer: LTXModel,
sigmas: list,
verbose: bool = True,
state: Optional[LatentState] = None,
) -> mx.array:
"""Run denoising loop."""
"""Run denoising loop with optional conditioning.
Args:
latents: Noisy latent tensor (B, C, F, H, W)
positions: Position embeddings
text_embeddings: Text conditioning embeddings
transformer: LTX model
sigmas: List of sigma values for denoising schedule
verbose: Whether to show progress bar
state: Optional LatentState for I2V conditioning
Returns:
Denoised latent tensor
"""
# If state is provided, use its latent (which may have conditioning applied)
if state is not None:
latents = state.latent
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
num_tokens = f * h * w
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
# Compute per-token timesteps
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
if state is not None:
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
# Per-token timesteps: sigma * mask
timesteps = sigma * denoise_mask_flat
else:
# All tokens get the same timestep
timesteps = mx.full((b, num_tokens), sigma)
video_modality = Modality(
latent=latents_flat,
timesteps=mx.full((1,), sigma),
timesteps=timesteps,
positions=positions,
context=text_embeddings,
context_mask=None,
@@ -137,6 +173,11 @@ def denoise(
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
# Apply conditioning mask if state is provided
if state is not None:
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
mx.eval(denoised)
if sigma_next > 0:
@@ -163,10 +204,15 @@ def generate_video(
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
image: Optional[str] = None,
image_strength: float = 1.0,
image_frame_idx: int = 0,
):
"""Generate video from text prompt.
"""Generate video from text prompt, optionally conditioned on an image.
Args:
model_repo: Model repository ID
text_encoder_repo: Text encoder repository ID
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
@@ -175,6 +221,13 @@ def generate_video(
fps: Frames per second for output video
output_path: Path to save the output video
save_frames: Whether to save individual frames as images
verbose: Whether to print progress
enhance_prompt: Whether to enhance prompt using Gemma
max_tokens: Max tokens for prompt enhancement
temperature: Temperature for prompt enhancement
image: Path to conditioning image for I2V (Image-to-Video)
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
image_frame_idx: Frame index to condition (0 = first frame)
"""
start_time = time.time()
@@ -188,8 +241,12 @@ def generate_video(
num_frames = adjusted_num_frames
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
is_i2v = image is not None
mode_str = "I2V" if is_i2v else "T2V"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
# Get model path
model_path = get_model_path(model_repo)
@@ -247,6 +304,31 @@ def generate_video(
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
# Load VAE encoder and encode image for I2V conditioning
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
del vae_encoder
mx.clear_cache()
# Stage 1: Generate at half resolution
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
@@ -256,7 +338,25 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
# Apply I2V conditioning if provided
state1 = None
if is_i2v and stage1_image_latent is not None:
# Create state with conditioning
state1 = LatentState(
latent=latents,
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
state1 = apply_conditioning(state1, [conditioning])
latents = state1.latent
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
# Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
@@ -285,7 +385,24 @@ def generate_video(
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
# Apply I2V conditioning for stage 2 if provided
state2 = None
if is_i2v and stage2_image_latent is not None:
state2 = LatentState(
latent=latents,
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
state2 = apply_conditioning(state2, [conditioning])
latents = state2.latent
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
del transformer
mx.clear_cache()
@@ -335,13 +452,18 @@ def generate_video(
def main():
parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2",
description="Generate videos with MLX LTX-2 (T2V and I2V)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Text-to-Video (T2V)
python -m mlx_video.generate --prompt "A cat walking on grass"
python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768
python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
# Image-to-Video (I2V)
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8
"""
)
@@ -426,6 +548,24 @@ Examples:
default=0.7,
help="Temperature for prompt enhancement (default: 0.7)"
)
parser.add_argument(
"--image", "-i",
type=str,
default=None,
help="Path to conditioning image for I2V (Image-to-Video) generation"
)
parser.add_argument(
"--image-strength",
type=float,
default=1.0,
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)"
)
parser.add_argument(
"--image-frame-idx",
type=int,
default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)"
)
args = parser.parse_args()
generate_video(

View File

@@ -3,7 +3,7 @@
import argparse
import time
from pathlib import Path
from typing import Optional
from typing import Optional, List
import mlx.core as mx
import numpy as np
@@ -27,9 +27,12 @@ from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeTyp
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
from mlx_video.utils import to_denoised, get_model_path
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
# Distilled sigma schedules
@@ -141,13 +144,35 @@ def denoise_av(
transformer: LTXModel,
sigmas: list,
verbose: bool = True,
video_state: Optional[LatentState] = None,
) -> tuple[mx.array, mx.array]:
"""Run denoising loop for audio-video generation."""
"""Run denoising loop for audio-video generation with optional I2V conditioning.
Args:
video_latents: Video latent tensor (B, C, F, H, W)
audio_latents: Audio latent tensor (B, C, T, F)
video_positions: Video position embeddings
audio_positions: Audio position embeddings
video_embeddings: Video text embeddings
audio_embeddings: Audio text embeddings
transformer: LTX model
sigmas: List of sigma values
verbose: Whether to show progress bar
video_state: Optional LatentState for I2V conditioning
Returns:
Tuple of (video_latents, audio_latents)
"""
# If video state is provided, use its latent
if video_state is not None:
video_latents = video_state.latent
for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
# Flatten video latents
b, c, f, h, w = video_latents.shape
num_video_tokens = f * h * w
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
@@ -155,9 +180,22 @@ def denoise_av(
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
# Compute per-token timesteps for video
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
if video_state is not None:
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
# Per-token timesteps: sigma * mask
video_timesteps = sigma * denoise_mask_flat
else:
# All tokens get the same timestep
video_timesteps = mx.full((b, num_video_tokens), sigma)
video_modality = Modality(
latent=video_flat,
timesteps=mx.full((1,), sigma),
timesteps=video_timesteps,
positions=video_positions,
context=video_embeddings,
context_mask=None,
@@ -166,7 +204,7 @@ def denoise_av(
audio_modality = Modality(
latent=audio_flat,
timesteps=mx.full((1,), sigma),
timesteps=mx.full((ab, at), sigma),
positions=audio_positions,
context=audio_embeddings,
context_mask=None,
@@ -184,6 +222,11 @@ def denoise_av(
# Compute denoised
video_denoised = to_denoised(video_latents, video_velocity, sigma)
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
# Apply conditioning mask for video if state is provided
if video_state is not None:
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
mx.eval(video_denoised, audio_denoised)
# Euler step
@@ -317,8 +360,31 @@ def generate_video_with_audio(
enhance_prompt: bool = False,
max_tokens: int = 512,
temperature: float = 0.7,
image: Optional[str] = None,
image_strength: float = 1.0,
image_frame_idx: int = 0,
):
"""Generate video with synchronized audio from text prompt."""
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image.
Args:
model_repo: Model repository ID
text_encoder_repo: Text encoder repository ID
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
num_frames: Number of frames
seed: Random seed
fps: Frames per second
output_path: Output video path
output_audio_path: Output audio path
verbose: Whether to print progress
enhance_prompt: Whether to enhance prompt using Gemma
max_tokens: Max tokens for prompt enhancement
temperature: Temperature for prompt enhancement
image: Path to conditioning image for I2V
image_strength: Conditioning strength (1.0 = full denoise)
image_frame_idx: Frame index to condition (0 = first frame)
"""
start_time = time.time()
# Validate dimensions
@@ -333,9 +399,13 @@ def generate_video_with_audio(
# Calculate audio frames
audio_frames = compute_audio_frames(num_frames, fps)
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
is_i2v = image is not None
mode_str = "I2V+Audio" if is_i2v else "T2V+Audio"
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
if is_i2v:
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
model_path = get_model_path(model_repo)
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
@@ -400,6 +470,31 @@ def generate_video_with_audio(
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
# Load VAE encoder and encode image for I2V conditioning
stage1_image_latent = None
stage2_image_latent = None
if is_i2v:
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
del vae_encoder
mx.clear_cache()
# Initialize latents
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
@@ -412,12 +507,30 @@ def generate_video_with_audio(
audio_positions = create_audio_position_grid(1, audio_frames)
mx.eval(video_positions, audio_positions)
# Apply I2V conditioning for stage 1 if provided
video_state1 = None
if is_i2v and stage1_image_latent is not None:
video_state1 = LatentState(
latent=video_latents,
clean_latent=mx.zeros_like(video_latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
video_state1 = apply_conditioning(video_state1, [conditioning])
video_latents = video_state1.latent
mx.eval(video_latents)
# Stage 1 denoising
video_latents, audio_latents = denoise_av(
video_latents, audio_latents,
video_positions, audio_positions,
video_embeddings, audio_embeddings,
transformer, STAGE_1_SIGMAS, verbose=verbose
transformer, STAGE_1_SIGMAS, verbose=verbose,
video_state=video_state1
)
# Upsample video latents
@@ -449,11 +562,29 @@ def generate_video_with_audio(
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
mx.eval(video_latents, audio_latents)
# Apply I2V conditioning for stage 2 if provided
video_state2 = None
if is_i2v and stage2_image_latent is not None:
video_state2 = LatentState(
latent=video_latents,
clean_latent=mx.zeros_like(video_latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
)
conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
video_state2 = apply_conditioning(video_state2, [conditioning])
video_latents = video_state2.latent
mx.eval(video_latents)
video_latents, audio_latents = denoise_av(
video_latents, audio_latents,
video_positions, audio_positions,
video_embeddings, audio_embeddings,
transformer, STAGE_2_SIGMAS, verbose=verbose
transformer, STAGE_2_SIGMAS, verbose=verbose,
video_state=video_state2
)
del transformer
@@ -549,13 +680,18 @@ def generate_video_with_audio(
def main():
parser = argparse.ArgumentParser(
description="Generate videos with synchronized audio using MLX LTX-2",
description="Generate videos with synchronized audio using MLX LTX-2 (T2V and I2V)",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Text-to-Video with Audio (T2V+Audio)
python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach"
python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt
python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav
# Image-to-Video with Audio (I2V+Audio)
python -m mlx_video.generate_av --prompt "A person dancing" --image photo.jpg
python -m mlx_video.generate_av --prompt "Waves crashing" --image beach.png --image-strength 0.8
"""
)
@@ -587,6 +723,12 @@ Examples:
help="Max tokens for prompt enhancement")
parser.add_argument("--temperature", type=float, default=0.7,
help="Temperature for prompt enhancement")
parser.add_argument("--image", "-i", type=str, default=None,
help="Path to conditioning image for I2V (Image-to-Video) generation")
parser.add_argument("--image-strength", type=float, default=1.0,
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)")
parser.add_argument("--image-frame-idx", type=int, default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)")
args = parser.parse_args()
@@ -605,6 +747,9 @@ Examples:
enhance_prompt=args.enhance_prompt,
max_tokens=args.max_tokens,
temperature=args.temperature,
image=args.image,
image_strength=args.image_strength,
image_frame_idx=args.image_frame_idx,
)

View File

@@ -1 +1,3 @@
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder

View File

@@ -0,0 +1,187 @@
"""Video VAE Encoder for LTX-2 Image-to-Video.
The encoder compresses input images/videos to latent representations.
Used for I2V (image-to-video) conditioning by encoding the input image
to latent space, which can then be used to condition video generation.
"""
from pathlib import Path
from typing import List, Tuple, Any, Optional
import json
import mlx.core as mx
import mlx.nn as nn
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
def load_vae_encoder(model_path: str) -> VideoEncoder:
    """Load VAE encoder from safetensors file.

    The weights file is resolved in this order: ``model_path`` itself when it
    is a ``.safetensors`` file, then ``ltx-2-19b-distilled.safetensors`` and
    ``vae/diffusion_pytorch_model.safetensors`` inside ``model_path``.

    The encoder layout is read from the ``"config"`` JSON entry of the
    safetensors metadata (``vae.encoder_blocks``, ``norm_layer``,
    ``latent_log_var``, ``patch_size``). If the metadata cannot be read, or
    it yields no encoder blocks, a built-in default block layout is used so
    the encoder is never constructed with an empty block list.

    Progress is reported with ``print``.

    Args:
        model_path: Path to the model weights (safetensors file or directory)

    Returns:
        Loaded VideoEncoder instance

    Raises:
        FileNotFoundError: If no weights file is found at ``model_path``.
    """
    from safetensors import safe_open

    model_path = Path(model_path)

    # Try to find the weights file
    if model_path.is_file() and model_path.suffix == ".safetensors":
        weights_path = model_path
    elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
        weights_path = model_path / "ltx-2-19b-distilled.safetensors"
    elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
        weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
    else:
        raise FileNotFoundError(f"VAE weights not found at {model_path}")

    print(f"Loading VAE encoder from {weights_path}...")

    # Layout used when the checkpoint does not describe its own encoder.
    default_encoder_blocks = [
        ("res_x", {"num_layers": 4}),
        ("compress_space_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 6}),
        ("compress_time_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 6}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 2}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 2}),
    ]
    log_var_types = {
        "uniform": LogVarianceType.UNIFORM,
        "per_channel": LogVarianceType.PER_CHANNEL,
        "constant": LogVarianceType.CONSTANT,
    }

    # Read config from safetensors metadata
    encoder_blocks = []
    norm_layer = NormLayerType.PIXEL_NORM
    latent_log_var = LogVarianceType.UNIFORM
    patch_size = 4

    try:
        with safe_open(str(weights_path), framework="numpy") as f:
            metadata = f.metadata()
            if metadata and "config" in metadata:
                configs = json.loads(metadata["config"])
                vae_config = configs.get("vae", {})

                # Parse encoder blocks; entries that are not [name, params]
                # pairs are ignored.
                for block in vae_config.get("encoder_blocks", []):
                    if isinstance(block, list) and len(block) == 2:
                        name, params = block
                        encoder_blocks.append((name, params))

                # Parse other config
                norm_str = vae_config.get("norm_layer", "pixel_norm")
                norm_layer = (
                    NormLayerType.PIXEL_NORM
                    if norm_str == "pixel_norm"
                    else NormLayerType.GROUP_NORM
                )

                var_str = vae_config.get("latent_log_var", "uniform")
                latent_log_var = log_var_types.get(var_str, LogVarianceType.NONE)

                patch_size = vae_config.get("patch_size", 4)

                print(f"  Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
    except Exception as e:
        # Metadata is best-effort: report the problem and fall back below.
        print(f"  Could not read config from metadata: {e}")
        encoder_blocks = []

    if not encoder_blocks:
        # Covers both a failed read and metadata without encoder blocks.
        encoder_blocks = list(default_encoder_blocks)
        print(f"  Using default encoder config with {len(encoder_blocks)} blocks")

    # Create encoder
    encoder = VideoEncoder(
        convolution_dimensions=3,
        in_channels=3,
        out_channels=128,
        encoder_blocks=encoder_blocks,
        patch_size=patch_size,
        norm_layer=norm_layer,
        latent_log_var=latent_log_var,
        encoder_spatial_padding_mode=PaddingModeType.ZEROS,
    )

    # Load weights
    weights = mx.load(str(weights_path))

    # Determine prefix based on weight keys
    has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
    if has_vae_prefix:
        prefix = "vae.encoder."
        stats_prefix = "vae.per_channel_statistics."
    else:
        prefix = "encoder."
        stats_prefix = "per_channel_statistics."

    # Load per-channel statistics for normalization
    mean_key = f"{stats_prefix}mean-of-means"
    std_key = f"{stats_prefix}std-of-means"

    if mean_key in weights:
        encoder.per_channel_statistics.mean = weights[mean_key]
        print(f"  Loaded latent mean: shape {weights[mean_key].shape}")
    if std_key in weights:
        encoder.per_channel_statistics.std = weights[std_key]
        print(f"  Loaded latent std: shape {weights[std_key].shape}")

    # Build encoder weights dict with key remapping
    encoder_weights = {}
    for key, value in weights.items():
        if not key.startswith(prefix):
            continue

        # Remove prefix
        new_key = key[len(prefix):]

        # Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
        if ".weight" in key and value.ndim == 5:
            value = mx.transpose(value, (0, 2, 3, 4, 1))

        encoder_weights[new_key] = value

    print(f"  Found {len(encoder_weights)} encoder weights")

    # Load weights; strict=False as in the original, so key mismatches
    # between checkpoint and module are not treated as errors here.
    encoder.load_weights(list(encoder_weights.items()), strict=False)

    print("VAE encoder loaded successfully")
    return encoder
def encode_image(
    image: mx.array,
    encoder: VideoEncoder,
) -> mx.array:
    """Run a still image through the VAE encoder as a one-frame video.

    Args:
        image: Channels-last image, either a single (H, W, 3) picture or a
            (B, H, W, 3) batch. Values are expected in [0, 1]; if the maximum
            exceeds 1.0 the data is treated as 0-255 and divided by 255.
        encoder: Loaded VAE encoder; it is called with a (B, 3, 1, H, W)
            tensor scaled to [-1, 1].

    Returns:
        Whatever the encoder returns for that input. The original
        documentation gives the shape as (1, 128, 1, H//32, W//32);
        that depends on the encoder configuration and is not checked here.
    """
    # A lone image gets a leading batch axis so both layouts share one path.
    batched = mx.expand_dims(image, axis=0) if image.ndim == 3 else image

    # Channels-last -> channels-first: (B, H, W, C) -> (B, C, H, W).
    channels_first = mx.transpose(batched, (0, 3, 1, 2))

    # Heuristic range detection: anything above 1.0 is assumed to be 0-255.
    if channels_first.max() > 1.0:
        channels_first = channels_first / 255.0

    # Map [0, 1] onto [-1, 1].
    signed = channels_first * 2.0 - 1.0

    # Insert a singleton time axis: (B, C, H, W) -> (B, C, 1, H, W).
    single_frame = mx.expand_dims(signed, axis=2)

    return encoder(single_frame)

View File

@@ -141,9 +141,10 @@ class UNetMidBlock3D(nn.Module):
self.num_layers = num_layers
# Create ResNet blocks
self.resnets = [
ResnetBlock3D(
# Create ResNet blocks - use dict for MLX parameter tracking
# Named res_blocks to match PyTorch weight keys
self.res_blocks = {
i: ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
@@ -154,8 +155,8 @@ class UNetMidBlock3D(nn.Module):
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
for i in range(num_layers)
}
def __call__(
self,
@@ -165,7 +166,7 @@ class UNetMidBlock3D(nn.Module):
generator: Optional[int] = None,
) -> mx.array:
for resnet in self.resnets:
for resnet in self.res_blocks.values():
x = resnet(x, causal=causal, generator=generator)
return x

View File

@@ -9,6 +9,15 @@ from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingMode
class SpaceToDepthDownsample(nn.Module):
"""Space-to-depth downsampling with 3x3 conv and skip connection.
PyTorch-compatible implementation:
1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
2. Space-to-depth on conv output: channels * prod(stride)
3. Space-to-depth on input with group averaging for skip connection
4. Add skip connection
"""
def __init__(
self,
dims: int,
@@ -17,7 +26,6 @@ class SpaceToDepthDownsample(nn.Module):
stride: Union[int, Tuple[int, int, int]],
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
):
super().__init__()
if isinstance(stride, int):
@@ -25,61 +33,74 @@ class SpaceToDepthDownsample(nn.Module):
self.stride = stride
self.dims = dims
self.out_channels = out_channels
# Calculate the multiplier for channels
# Calculate channels
multiplier = stride[0] * stride[1] * stride[2]
intermediate_channels = in_channels * multiplier
self.group_size = in_channels * multiplier // out_channels
conv_out_channels = out_channels // multiplier
# 1x1x1 convolution to adjust channels
# 3x3 convolution (not 1x1)
self.conv = CausalConv3d(
in_channels=intermediate_channels,
out_channels=out_channels,
kernel_size=1,
in_channels=in_channels,
out_channels=conv_out_channels,
kernel_size=3,
stride=1,
padding=0,
padding=1,
spatial_padding_mode=spatial_padding_mode,
)
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
def _space_to_depth(self, x: mx.array) -> mx.array:
    """Fold stride-sized blocks of (D, H, W) into the channel axis.

    Equivalent to the rearrangement
    ``b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w`` with
    ``(p1, p2, p3) = self.stride``.

    Args:
        x: 5-D tensor unpacked as (B, C, D, H, W). The reshape below only
            succeeds when D, H and W are multiples of the matching stride
            entries.

    Returns:
        Tensor of shape (B, C * p1 * p2 * p3, D // p1, H // p2, W // p3).
    """
    batch, channels, depth, height, width = x.shape
    p1, p2, p3 = self.stride
    out_d, out_h, out_w = depth // p1, height // p2, width // p3

    # Split each of D/H/W into (coarse position, offset inside the block).
    blocks = mx.reshape(x, (batch, channels, out_d, p1, out_h, p2, out_w, p3))

    # Bring the three block offsets next to the channel axis:
    # (B, C, D', p1, H', p2, W', p3) -> (B, C, p1, p2, p3, D', H', W').
    blocks = mx.transpose(blocks, (0, 1, 3, 5, 7, 2, 4, 6))

    # Merge channel and block offsets into a single channel axis.
    return mx.reshape(blocks, (batch, channels * p1 * p2 * p3, out_d, out_h, out_w))
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
b, c, d, h, w = x.shape
st, sh, sw = self.stride
# Temporal padding for causal mode
if st == 2:
# Duplicate first frame for padding
x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)
d = d + 1
# Pad if necessary to make dimensions divisible by stride
pad_d = (st - d % st) % st
pad_h = (sh - h % sh) % sh
pad_w = (sw - w % sw) % sw
if pad_d > 0 or pad_h > 0 or pad_w > 0:
# For causal, pad at the end of temporal dimension
if causal:
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
else:
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
b, c, d, h, w = x.shape
# Skip connection: space-to-depth on input, then group mean
x_in = self._space_to_depth(x)
# Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
b2, c2, d2, h2, w2 = x_in.shape
x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w)
# Reshape to group spatial elements
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
# Conv branch: apply conv then space-to-depth
x_conv = self.conv(x, causal=causal)
x_conv = self._space_to_depth(x_conv)
# Permute to move stride elements to channel dim
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
# Reshape to combine channels
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
new_c = c * st * sh * sw
new_d = d // st
new_h = h // sh
new_w = w // sw
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
# Apply 1x1 conv to adjust channels
x = self.conv(x, causal=causal)
return x
# Add skip connection
return x_conv + x_in
class DepthToSpaceUpsample(nn.Module):

View File

@@ -273,9 +273,9 @@ class VideoEncoder(nn.Module):
spatial_padding_mode=encoder_spatial_padding_mode,
)
# Build encoder blocks
self.down_blocks = []
for block_name, block_params in encoder_blocks:
# Build encoder blocks - use dict with int keys for MLX parameter tracking
self.down_blocks = {}
for i, (block_name, block_params) in enumerate(encoder_blocks):
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
block, feature_channels = _make_encoder_block(
@@ -287,7 +287,7 @@ class VideoEncoder(nn.Module):
norm_num_groups=self._norm_num_groups,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.down_blocks.append(block)
self.down_blocks[i] = block
# Output normalization and convolution
if norm_layer == NormLayerType.GROUP_NORM:
@@ -341,7 +341,7 @@ class VideoEncoder(nn.Module):
sample = self.conv_in(sample, causal=True)
# Process through encoder blocks
for down_block in self.down_blocks:
for down_block in self.down_blocks.values():
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
sample = down_block(sample, causal=True)
else:

View File

@@ -1,10 +1,13 @@
import math
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from functools import partial
from pathlib import Path
from huggingface_hub import snapshot_download
from PIL import Image
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
@@ -160,3 +163,122 @@ def get_timestep_embedding(
emb = mx.pad(emb, [(0, 0), (0, 1)])
return emb
def load_image(
    image_path: Union[str, Path],
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> mx.array:
    """Load an image from disk and preprocess it for I2V conditioning.

    The image is converted to RGB and resized as follows:

    * ``height`` and ``width`` both given: resized to exactly that size.
    * only one given: that side is used as-is and the other is scaled to keep
      the aspect ratio, then rounded DOWN to a multiple of 32.
    * neither given: both sides are rounded DOWN to a multiple of 32.

    Derived sides are never allowed to drop below 32, so very small or very
    elongated images are upscaled to a valid size instead of producing a
    zero-sized resize target.

    Args:
        image_path: Path to the image file.
        height: Target height. Should be divisible by 32; this is not
            validated here. If None, derived as described above.
        width: Target width. Should be divisible by 32; this is not
            validated here. If None, derived as described above.

    Returns:
        Image tensor of shape (H, W, 3), float32, in range [0, 1].
    """
    # Image.open is lazy and keeps the file handle open; convert() returns an
    # independent in-memory copy, so the handle can be closed right away.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    orig_w, orig_h = image.size

    if height is not None and width is not None:
        target_size = (width, height)
    elif height is not None:
        # Only the height is fixed: derive the width from the aspect ratio.
        scale = height / orig_h
        new_w = max(32, (int(orig_w * scale) // 32) * 32)
        target_size = (new_w, height)
    elif width is not None:
        # Only the width is fixed: derive the height from the aspect ratio.
        scale = width / orig_w
        new_h = max(32, (int(orig_h * scale) // 32) * 32)
        target_size = (width, new_h)
    else:
        # No target given: snap the native size down to multiples of 32.
        target_size = (
            max(32, (orig_w // 32) * 32),
            max(32, (orig_h // 32) * 32),
        )

    # Resizing to the current size would only produce a copy, so skip it.
    if target_size != image.size:
        image = image.resize(target_size, Image.Resampling.LANCZOS)

    # uint8 [0, 255] -> float32 [0, 1], then hand over to MLX.
    image_np = np.array(image).astype(np.float32) / 255.0
    return mx.array(image_np)
def resize_image_aspect_ratio(
    image: mx.array,
    long_side: int = 512,
) -> mx.array:
    """Resize an image so its longer side is about ``long_side`` pixels.

    The aspect ratio is preserved up to rounding: both output sides are
    rounded DOWN to a multiple of 32 (so ``long_side`` itself is floored when
    it is not a multiple of 32) and are never smaller than 32.

    Args:
        image: Image tensor of shape (H, W, 3). Values are interpreted as
            [0, 1] when the maximum is <= 1.0, otherwise as 0-255.
        long_side: Target size for the longer dimension.

    Returns:
        Resized image tensor of shape (H', W', 3), float32, in range [0, 1].
    """
    h, w = image.shape[:2]

    if h > w:
        new_h = long_side
        new_w = int(w * long_side / h)
    else:
        new_w = long_side
        new_h = int(h * long_side / w)

    # Floor to a multiple of 32, but keep at least 32 so extreme aspect
    # ratios or a small long_side cannot yield a zero-sized target.
    new_h = max(32, (new_h // 32) * 32)
    new_w = max(32, (new_w // 32) * 32)

    # Convert to 8-bit for PIL. Rounding (not truncating) keeps values such
    # as 253.99998 at 254 after a float round trip; clipping prevents
    # out-of-range values from wrapping around in uint8.
    pixels = np.array(image).astype(np.float32)
    if pixels.max() <= 1.0:
        pixels = pixels * 255.0
    pixels_u8 = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

    # Use PIL for high-quality resampling.
    pil_image = Image.fromarray(pixels_u8)
    pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)

    return mx.array(np.array(pil_image).astype(np.float32) / 255.0)
def prepare_image_for_encoding(
    image: mx.array,
    target_height: int,
    target_width: int,
) -> mx.array:
    """Resize an image to the video size and lay it out for the VAE encoder.

    Args:
        image: Image tensor of shape (H, W, 3) in range [0, 1].
        target_height: Target height for the video.
        target_width: Target width for the video.

    Returns:
        Image tensor of shape (1, 3, 1, target_height, target_width) in
        range [-1, 1], assuming the input honoured the [0, 1] contract.
    """
    h, w = image.shape[:2]

    if h != target_height or w != target_width:
        # PIL needs 8-bit data. Round instead of truncating so that values
        # like 253.99998 do not lose a level, and clip so out-of-range
        # values cannot wrap around in uint8.
        pixels = np.array(image).astype(np.float32)
        if pixels.max() <= 1.0:
            pixels = pixels * 255.0
        pixels_u8 = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

        pil_image = Image.fromarray(pixels_u8)
        pil_image = pil_image.resize(
            (target_width, target_height), Image.Resampling.LANCZOS
        )
        image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)

    # Map [0, 1] onto [-1, 1].
    image = image * 2.0 - 1.0

    # (H, W, 3) -> (3, H, W) -> (1, 3, H, W) -> (1, 3, 1, H, W)
    image = mx.transpose(image, (2, 0, 1))
    image = mx.expand_dims(image, axis=0)
    image = mx.expand_dims(image, axis=2)

    return image