Add image-to-video (I2V) conditioning support
- Introduced `load_image`, `prepare_image_for_encoding`, and `apply_conditioning` functions for handling image inputs and conditioning during video generation. - Enhanced `generate_video` and `denoise_av` functions to accept optional image inputs for I2V conditioning. - Updated command-line interface to include parameters for image conditioning, such as `--image`, `--image-strength`, and `--image-frame-idx`. - Added new `VideoConditionByLatentIndex` and `LatentState` classes for managing latent states with conditioning. - Implemented VAE encoder loading and image encoding for conditioning in the video generation process.d
This commit is contained in:
3
mlx_video/conditioning/__init__.py
Normal file
3
mlx_video/conditioning/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Conditioning modules for LTX-2 video generation."""
|
||||||
|
|
||||||
|
from mlx_video.conditioning.latent import VideoConditionByLatentIndex, apply_conditioning
|
||||||
196
mlx_video/conditioning/latent.py
Normal file
196
mlx_video/conditioning/latent.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""Latent-based conditioning for I2V (Image-to-Video) generation.
|
||||||
|
|
||||||
|
This module provides conditioning that injects encoded image latents into
|
||||||
|
the video generation process at specific frame positions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VideoConditionByLatentIndex:
|
||||||
|
"""Condition video generation by injecting latents at a specific frame index.
|
||||||
|
|
||||||
|
This replaces the latent at the specified frame index with the conditioned
|
||||||
|
latent and controls how much denoising is applied via the strength parameter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latent: Encoded image latent of shape (B, C, 1, H, W)
|
||||||
|
frame_idx: Frame index to condition (0 = first frame)
|
||||||
|
strength: Denoising strength (1.0 = full denoise, 0.0 = keep original)
|
||||||
|
"""
|
||||||
|
latent: mx.array
|
||||||
|
frame_idx: int = 0
|
||||||
|
strength: float = 1.0
|
||||||
|
|
||||||
|
def get_num_latent_frames(self) -> int:
|
||||||
|
"""Get number of latent frames in the conditioning."""
|
||||||
|
return self.latent.shape[2]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LatentState:
|
||||||
|
"""State for latent diffusion with conditioning support.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
latent: Current noisy latent (B, C, F, H, W)
|
||||||
|
clean_latent: Clean conditioning latent (B, C, F, H, W)
|
||||||
|
denoise_mask: Per-frame denoising mask (B, 1, F, 1, 1) where
|
||||||
|
1.0 = full denoise, 0.0 = keep clean
|
||||||
|
"""
|
||||||
|
latent: mx.array
|
||||||
|
clean_latent: mx.array
|
||||||
|
denoise_mask: mx.array
|
||||||
|
|
||||||
|
def clone(self) -> "LatentState":
|
||||||
|
"""Create a copy of the state."""
|
||||||
|
return LatentState(
|
||||||
|
latent=self.latent,
|
||||||
|
clean_latent=self.clean_latent,
|
||||||
|
denoise_mask=self.denoise_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_initial_state(
|
||||||
|
shape: Tuple[int, ...],
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
noise_scale: float = 1.0,
|
||||||
|
) -> LatentState:
|
||||||
|
"""Create initial noisy latent state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: Shape of latent (B, C, F, H, W)
|
||||||
|
seed: Optional random seed
|
||||||
|
noise_scale: Scale for initial noise (sigma)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Initial LatentState with random noise
|
||||||
|
"""
|
||||||
|
if seed is not None:
|
||||||
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
noise = mx.random.normal(shape)
|
||||||
|
|
||||||
|
return LatentState(
|
||||||
|
latent=noise * noise_scale,
|
||||||
|
clean_latent=mx.zeros(shape),
|
||||||
|
denoise_mask=mx.ones((shape[0], 1, shape[2], 1, 1)), # Full denoise by default
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_conditioning(
|
||||||
|
state: LatentState,
|
||||||
|
conditionings: List[VideoConditionByLatentIndex],
|
||||||
|
) -> LatentState:
|
||||||
|
"""Apply conditioning items to a latent state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current latent state
|
||||||
|
conditionings: List of conditioning items to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated LatentState with conditioning applied
|
||||||
|
"""
|
||||||
|
state = state.clone()
|
||||||
|
b, c, f, h, w = state.latent.shape
|
||||||
|
|
||||||
|
for cond in conditionings:
|
||||||
|
cond_latent = cond.latent
|
||||||
|
frame_idx = cond.frame_idx
|
||||||
|
strength = cond.strength
|
||||||
|
|
||||||
|
# Validate shapes
|
||||||
|
_, cond_c, cond_f, cond_h, cond_w = cond_latent.shape
|
||||||
|
if (cond_c, cond_h, cond_w) != (c, h, w):
|
||||||
|
raise ValueError(
|
||||||
|
f"Conditioning latent spatial shape ({cond_c}, {cond_h}, {cond_w}) "
|
||||||
|
f"does not match target shape ({c}, {h}, {w})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if frame_idx >= f:
|
||||||
|
raise ValueError(
|
||||||
|
f"Frame index {frame_idx} is out of bounds for latent with {f} frames"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the conditioning frames count
|
||||||
|
num_cond_frames = cond_f
|
||||||
|
end_idx = min(frame_idx + num_cond_frames, f)
|
||||||
|
|
||||||
|
# Replace latent at conditioning position
|
||||||
|
# state.latent[:, :, frame_idx:end_idx] = cond_latent[:, :, :end_idx - frame_idx]
|
||||||
|
latent_list = []
|
||||||
|
clean_list = []
|
||||||
|
mask_list = []
|
||||||
|
|
||||||
|
for i in range(f):
|
||||||
|
if frame_idx <= i < end_idx:
|
||||||
|
# Use conditioning latent
|
||||||
|
cond_idx = i - frame_idx
|
||||||
|
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||||
|
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||||
|
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||||
|
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength))
|
||||||
|
else:
|
||||||
|
# Keep original
|
||||||
|
latent_list.append(state.latent[:, :, i:i+1])
|
||||||
|
clean_list.append(state.clean_latent[:, :, i:i+1])
|
||||||
|
mask_list.append(state.denoise_mask[:, :, i:i+1])
|
||||||
|
|
||||||
|
state.latent = mx.concatenate(latent_list, axis=2)
|
||||||
|
state.clean_latent = mx.concatenate(clean_list, axis=2)
|
||||||
|
state.denoise_mask = mx.concatenate(mask_list, axis=2)
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def apply_denoise_mask(
|
||||||
|
denoised: mx.array,
|
||||||
|
clean: mx.array,
|
||||||
|
denoise_mask: mx.array,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Blend denoised output with clean state based on mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
denoised: Denoised latent (B, C, F, H, W)
|
||||||
|
clean: Clean conditioning latent (B, C, F, H, W)
|
||||||
|
denoise_mask: Mask where 1.0 = use denoised, 0.0 = use clean
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Blended latent
|
||||||
|
"""
|
||||||
|
return denoised * denoise_mask + clean * (1.0 - denoise_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise_with_state(
|
||||||
|
state: LatentState,
|
||||||
|
noise_scale: float,
|
||||||
|
) -> LatentState:
|
||||||
|
"""Add noise to state while respecting conditioning.
|
||||||
|
|
||||||
|
For conditioned frames (mask < 1.0), adds noise proportionally
|
||||||
|
to allow some refinement while preserving the conditioning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current latent state
|
||||||
|
noise_scale: Scale for noise (sigma)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated state with noise added
|
||||||
|
"""
|
||||||
|
state = state.clone()
|
||||||
|
|
||||||
|
# Generate noise
|
||||||
|
noise = mx.random.normal(state.latent.shape)
|
||||||
|
|
||||||
|
# For fully conditioned frames (mask=0), we want to add minimal noise
|
||||||
|
# For unconditioned frames (mask=1), we want full noise
|
||||||
|
# noisy = noise * sigma + latent * (1 - sigma)
|
||||||
|
# But we scale sigma by the mask for conditioned regions
|
||||||
|
|
||||||
|
effective_scale = noise_scale * state.denoise_mask
|
||||||
|
state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale)
|
||||||
|
|
||||||
|
return state
|
||||||
@@ -291,6 +291,59 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
|||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_vae_encoder_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
"""Sanitize VAE encoder weight names from PyTorch format to MLX format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights: Dictionary of weights with PyTorch naming
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with MLX-compatible naming for VAE encoder
|
||||||
|
"""
|
||||||
|
sanitized = {}
|
||||||
|
|
||||||
|
for key, value in weights.items():
|
||||||
|
new_key = key
|
||||||
|
|
||||||
|
# Skip position_ids (not needed)
|
||||||
|
if "position_ids" in key:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Only process VAE encoder weights
|
||||||
|
if not key.startswith("vae."):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle per-channel statistics key mapping
|
||||||
|
if "vae.per_channel_statistics" in key:
|
||||||
|
if key == "vae.per_channel_statistics.mean-of-means":
|
||||||
|
new_key = "per_channel_statistics._mean_of_means"
|
||||||
|
elif key == "vae.per_channel_statistics.std-of-means":
|
||||||
|
new_key = "per_channel_statistics._std_of_means"
|
||||||
|
else:
|
||||||
|
# Skip other per_channel_statistics keys
|
||||||
|
continue
|
||||||
|
elif key.startswith("vae.encoder."):
|
||||||
|
# Strip the vae.encoder. prefix for encoder weights
|
||||||
|
new_key = key.replace("vae.encoder.", "")
|
||||||
|
else:
|
||||||
|
# Skip other vae.* keys that are not encoder weights
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle Conv3d weight shape conversion
|
||||||
|
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||||
|
# MLX: (out_channels, D, H, W, in_channels)
|
||||||
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
# Handle Conv2d weight shape conversion
|
||||||
|
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 1))
|
||||||
|
|
||||||
|
sanitized[new_key] = value
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
def sanitize_audio_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
"""Sanitize audio VAE weight names from PyTorch format to MLX format.
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -22,10 +23,13 @@ class Colors:
|
|||||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||||
from mlx_video.models.ltx.ltx import LTXModel
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
from mlx_video.models.ltx.transformer import Modality
|
from mlx_video.models.ltx.transformer import Modality
|
||||||
from mlx_video.convert import sanitize_transformer_weights
|
from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder_weights
|
||||||
from mlx_video.utils import to_denoised
|
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
|
||||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||||
|
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||||
|
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
|
||||||
|
from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state
|
||||||
|
|
||||||
from mlx_video.utils import get_model_path
|
from mlx_video.utils import get_model_path
|
||||||
|
|
||||||
@@ -115,17 +119,49 @@ def denoise(
|
|||||||
transformer: LTXModel,
|
transformer: LTXModel,
|
||||||
sigmas: list,
|
sigmas: list,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
|
state: Optional[LatentState] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Run denoising loop."""
|
"""Run denoising loop with optional conditioning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latents: Noisy latent tensor (B, C, F, H, W)
|
||||||
|
positions: Position embeddings
|
||||||
|
text_embeddings: Text conditioning embeddings
|
||||||
|
transformer: LTX model
|
||||||
|
sigmas: List of sigma values for denoising schedule
|
||||||
|
verbose: Whether to show progress bar
|
||||||
|
state: Optional LatentState for I2V conditioning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Denoised latent tensor
|
||||||
|
"""
|
||||||
|
# If state is provided, use its latent (which may have conditioning applied)
|
||||||
|
if state is not None:
|
||||||
|
latents = state.latent
|
||||||
|
|
||||||
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
|
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
|
||||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||||
|
|
||||||
b, c, f, h, w = latents.shape
|
b, c, f, h, w = latents.shape
|
||||||
|
num_tokens = f * h * w
|
||||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||||
|
|
||||||
|
# Compute per-token timesteps
|
||||||
|
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
|
||||||
|
if state is not None:
|
||||||
|
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
|
||||||
|
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
||||||
|
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||||
|
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
|
||||||
|
# Per-token timesteps: sigma * mask
|
||||||
|
timesteps = sigma * denoise_mask_flat
|
||||||
|
else:
|
||||||
|
# All tokens get the same timestep
|
||||||
|
timesteps = mx.full((b, num_tokens), sigma)
|
||||||
|
|
||||||
video_modality = Modality(
|
video_modality = Modality(
|
||||||
latent=latents_flat,
|
latent=latents_flat,
|
||||||
timesteps=mx.full((1,), sigma),
|
timesteps=timesteps,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
context=text_embeddings,
|
context=text_embeddings,
|
||||||
context_mask=None,
|
context_mask=None,
|
||||||
@@ -137,6 +173,11 @@ def denoise(
|
|||||||
|
|
||||||
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||||
denoised = to_denoised(latents, velocity, sigma)
|
denoised = to_denoised(latents, velocity, sigma)
|
||||||
|
|
||||||
|
# Apply conditioning mask if state is provided
|
||||||
|
if state is not None:
|
||||||
|
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
|
||||||
|
|
||||||
mx.eval(denoised)
|
mx.eval(denoised)
|
||||||
|
|
||||||
if sigma_next > 0:
|
if sigma_next > 0:
|
||||||
@@ -163,10 +204,15 @@ def generate_video(
|
|||||||
enhance_prompt: bool = False,
|
enhance_prompt: bool = False,
|
||||||
max_tokens: int = 512,
|
max_tokens: int = 512,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
|
image: Optional[str] = None,
|
||||||
|
image_strength: float = 1.0,
|
||||||
|
image_frame_idx: int = 0,
|
||||||
):
|
):
|
||||||
"""Generate video from text prompt.
|
"""Generate video from text prompt, optionally conditioned on an image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
model_repo: Model repository ID
|
||||||
|
text_encoder_repo: Text encoder repository ID
|
||||||
prompt: Text description of the video to generate
|
prompt: Text description of the video to generate
|
||||||
height: Output video height (must be divisible by 64)
|
height: Output video height (must be divisible by 64)
|
||||||
width: Output video width (must be divisible by 64)
|
width: Output video width (must be divisible by 64)
|
||||||
@@ -175,6 +221,13 @@ def generate_video(
|
|||||||
fps: Frames per second for output video
|
fps: Frames per second for output video
|
||||||
output_path: Path to save the output video
|
output_path: Path to save the output video
|
||||||
save_frames: Whether to save individual frames as images
|
save_frames: Whether to save individual frames as images
|
||||||
|
verbose: Whether to print progress
|
||||||
|
enhance_prompt: Whether to enhance prompt using Gemma
|
||||||
|
max_tokens: Max tokens for prompt enhancement
|
||||||
|
temperature: Temperature for prompt enhancement
|
||||||
|
image: Path to conditioning image for I2V (Image-to-Video)
|
||||||
|
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
|
||||||
|
image_frame_idx: Frame index to condition (0 = first frame)
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -188,8 +241,12 @@ def generate_video(
|
|||||||
num_frames = adjusted_num_frames
|
num_frames = adjusted_num_frames
|
||||||
|
|
||||||
|
|
||||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
is_i2v = image is not None
|
||||||
|
mode_str = "I2V" if is_i2v else "T2V"
|
||||||
|
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||||
|
if is_i2v:
|
||||||
|
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
|
||||||
|
|
||||||
# Get model path
|
# Get model path
|
||||||
model_path = get_model_path(model_repo)
|
model_path = get_model_path(model_repo)
|
||||||
@@ -247,6 +304,31 @@ def generate_video(
|
|||||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||||
mx.eval(transformer.parameters())
|
mx.eval(transformer.parameters())
|
||||||
|
|
||||||
|
# Load VAE encoder and encode image for I2V conditioning
|
||||||
|
stage1_image_latent = None
|
||||||
|
stage2_image_latent = None
|
||||||
|
if is_i2v:
|
||||||
|
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
|
||||||
|
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||||
|
mx.eval(vae_encoder.parameters())
|
||||||
|
|
||||||
|
# Load and prepare image for stage 1 (half resolution)
|
||||||
|
input_image = load_image(image, height=height // 2, width=width // 2)
|
||||||
|
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
|
||||||
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||||
|
mx.eval(stage1_image_latent)
|
||||||
|
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
||||||
|
|
||||||
|
# Load and prepare image for stage 2 (full resolution)
|
||||||
|
input_image = load_image(image, height=height, width=width)
|
||||||
|
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
|
||||||
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||||
|
mx.eval(stage2_image_latent)
|
||||||
|
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
||||||
|
|
||||||
|
del vae_encoder
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
# Stage 1: Generate at half resolution
|
# Stage 1: Generate at half resolution
|
||||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
@@ -256,7 +338,25 @@ def generate_video(
|
|||||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
|
# Apply I2V conditioning if provided
|
||||||
|
state1 = None
|
||||||
|
if is_i2v and stage1_image_latent is not None:
|
||||||
|
# Create state with conditioning
|
||||||
|
state1 = LatentState(
|
||||||
|
latent=latents,
|
||||||
|
clean_latent=mx.zeros_like(latents),
|
||||||
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||||
|
)
|
||||||
|
conditioning = VideoConditionByLatentIndex(
|
||||||
|
latent=stage1_image_latent,
|
||||||
|
frame_idx=image_frame_idx,
|
||||||
|
strength=image_strength,
|
||||||
|
)
|
||||||
|
state1 = apply_conditioning(state1, [conditioning])
|
||||||
|
latents = state1.latent
|
||||||
|
mx.eval(latents)
|
||||||
|
|
||||||
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
|
||||||
|
|
||||||
# Upsample latents
|
# Upsample latents
|
||||||
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
||||||
@@ -285,7 +385,24 @@ def generate_video(
|
|||||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
|
# Apply I2V conditioning for stage 2 if provided
|
||||||
|
state2 = None
|
||||||
|
if is_i2v and stage2_image_latent is not None:
|
||||||
|
state2 = LatentState(
|
||||||
|
latent=latents,
|
||||||
|
clean_latent=mx.zeros_like(latents),
|
||||||
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||||
|
)
|
||||||
|
conditioning = VideoConditionByLatentIndex(
|
||||||
|
latent=stage2_image_latent,
|
||||||
|
frame_idx=image_frame_idx,
|
||||||
|
strength=image_strength,
|
||||||
|
)
|
||||||
|
state2 = apply_conditioning(state2, [conditioning])
|
||||||
|
latents = state2.latent
|
||||||
|
mx.eval(latents)
|
||||||
|
|
||||||
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
|
||||||
|
|
||||||
del transformer
|
del transformer
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
@@ -335,13 +452,18 @@ def generate_video(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Generate videos with MLX LTX-2",
|
description="Generate videos with MLX LTX-2 (T2V and I2V)",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
Examples:
|
Examples:
|
||||||
|
# Text-to-Video (T2V)
|
||||||
python -m mlx_video.generate --prompt "A cat walking on grass"
|
python -m mlx_video.generate --prompt "A cat walking on grass"
|
||||||
python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768
|
python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768
|
||||||
python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
|
python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
|
||||||
|
|
||||||
|
# Image-to-Video (I2V)
|
||||||
|
python -m mlx_video.generate --prompt "A person dancing" --image photo.jpg
|
||||||
|
python -m mlx_video.generate --prompt "Waves crashing" --image beach.png --image-strength 0.8
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -426,6 +548,24 @@ Examples:
|
|||||||
default=0.7,
|
default=0.7,
|
||||||
help="Temperature for prompt enhancement (default: 0.7)"
|
help="Temperature for prompt enhancement (default: 0.7)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image", "-i",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to conditioning image for I2V (Image-to-Video) generation"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image-strength",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--image-frame-idx",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Frame index to condition for I2V (0 = first frame, default: 0)"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
generate_video(
|
generate_video(
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -27,9 +27,12 @@ from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeTyp
|
|||||||
from mlx_video.models.ltx.ltx import LTXModel
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
from mlx_video.models.ltx.transformer import Modality
|
from mlx_video.models.ltx.transformer import Modality
|
||||||
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
|
from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_weights, sanitize_vocoder_weights
|
||||||
from mlx_video.utils import to_denoised, get_model_path
|
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
|
||||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||||
|
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||||
|
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
|
||||||
|
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
|
||||||
|
|
||||||
|
|
||||||
# Distilled sigma schedules
|
# Distilled sigma schedules
|
||||||
@@ -141,13 +144,35 @@ def denoise_av(
|
|||||||
transformer: LTXModel,
|
transformer: LTXModel,
|
||||||
sigmas: list,
|
sigmas: list,
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
|
video_state: Optional[LatentState] = None,
|
||||||
) -> tuple[mx.array, mx.array]:
|
) -> tuple[mx.array, mx.array]:
|
||||||
"""Run denoising loop for audio-video generation."""
|
"""Run denoising loop for audio-video generation with optional I2V conditioning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_latents: Video latent tensor (B, C, F, H, W)
|
||||||
|
audio_latents: Audio latent tensor (B, C, T, F)
|
||||||
|
video_positions: Video position embeddings
|
||||||
|
audio_positions: Audio position embeddings
|
||||||
|
video_embeddings: Video text embeddings
|
||||||
|
audio_embeddings: Audio text embeddings
|
||||||
|
transformer: LTX model
|
||||||
|
sigmas: List of sigma values
|
||||||
|
verbose: Whether to show progress bar
|
||||||
|
video_state: Optional LatentState for I2V conditioning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (video_latents, audio_latents)
|
||||||
|
"""
|
||||||
|
# If video state is provided, use its latent
|
||||||
|
if video_state is not None:
|
||||||
|
video_latents = video_state.latent
|
||||||
|
|
||||||
for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose):
|
for i in tqdm(range(len(sigmas) - 1), desc="Denoising A/V", disable=not verbose):
|
||||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||||
|
|
||||||
# Flatten video latents
|
# Flatten video latents
|
||||||
b, c, f, h, w = video_latents.shape
|
b, c, f, h, w = video_latents.shape
|
||||||
|
num_video_tokens = f * h * w
|
||||||
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
|
video_flat = mx.transpose(mx.reshape(video_latents, (b, c, -1)), (0, 2, 1))
|
||||||
|
|
||||||
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
|
# Flatten audio latents: (B, C, T, F) -> (B, T, C*F)
|
||||||
@@ -155,9 +180,22 @@ def denoise_av(
|
|||||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
|
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) # (B, T, C, F)
|
||||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
||||||
|
|
||||||
|
# Compute per-token timesteps for video
|
||||||
|
# For I2V: conditioned tokens get timestep=0 (mask=0), unconditioned get timestep=sigma (mask=1)
|
||||||
|
if video_state is not None:
|
||||||
|
# Reshape denoise_mask from (B, 1, F, 1, 1) to (B, num_tokens)
|
||||||
|
denoise_mask_flat = mx.reshape(video_state.denoise_mask, (b, 1, f, 1, 1))
|
||||||
|
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||||
|
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
||||||
|
# Per-token timesteps: sigma * mask
|
||||||
|
video_timesteps = sigma * denoise_mask_flat
|
||||||
|
else:
|
||||||
|
# All tokens get the same timestep
|
||||||
|
video_timesteps = mx.full((b, num_video_tokens), sigma)
|
||||||
|
|
||||||
video_modality = Modality(
|
video_modality = Modality(
|
||||||
latent=video_flat,
|
latent=video_flat,
|
||||||
timesteps=mx.full((1,), sigma),
|
timesteps=video_timesteps,
|
||||||
positions=video_positions,
|
positions=video_positions,
|
||||||
context=video_embeddings,
|
context=video_embeddings,
|
||||||
context_mask=None,
|
context_mask=None,
|
||||||
@@ -166,7 +204,7 @@ def denoise_av(
|
|||||||
|
|
||||||
audio_modality = Modality(
|
audio_modality = Modality(
|
||||||
latent=audio_flat,
|
latent=audio_flat,
|
||||||
timesteps=mx.full((1,), sigma),
|
timesteps=mx.full((ab, at), sigma),
|
||||||
positions=audio_positions,
|
positions=audio_positions,
|
||||||
context=audio_embeddings,
|
context=audio_embeddings,
|
||||||
context_mask=None,
|
context_mask=None,
|
||||||
@@ -184,6 +222,11 @@ def denoise_av(
|
|||||||
# Compute denoised
|
# Compute denoised
|
||||||
video_denoised = to_denoised(video_latents, video_velocity, sigma)
|
video_denoised = to_denoised(video_latents, video_velocity, sigma)
|
||||||
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
||||||
|
|
||||||
|
# Apply conditioning mask for video if state is provided
|
||||||
|
if video_state is not None:
|
||||||
|
video_denoised = apply_denoise_mask(video_denoised, video_state.clean_latent, video_state.denoise_mask)
|
||||||
|
|
||||||
mx.eval(video_denoised, audio_denoised)
|
mx.eval(video_denoised, audio_denoised)
|
||||||
|
|
||||||
# Euler step
|
# Euler step
|
||||||
@@ -317,8 +360,31 @@ def generate_video_with_audio(
|
|||||||
enhance_prompt: bool = False,
|
enhance_prompt: bool = False,
|
||||||
max_tokens: int = 512,
|
max_tokens: int = 512,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
|
image: Optional[str] = None,
|
||||||
|
image_strength: float = 1.0,
|
||||||
|
image_frame_idx: int = 0,
|
||||||
):
|
):
|
||||||
"""Generate video with synchronized audio from text prompt."""
|
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_repo: Model repository ID
|
||||||
|
text_encoder_repo: Text encoder repository ID
|
||||||
|
prompt: Text description of the video to generate
|
||||||
|
height: Output video height (must be divisible by 64)
|
||||||
|
width: Output video width (must be divisible by 64)
|
||||||
|
num_frames: Number of frames
|
||||||
|
seed: Random seed
|
||||||
|
fps: Frames per second
|
||||||
|
output_path: Output video path
|
||||||
|
output_audio_path: Output audio path
|
||||||
|
verbose: Whether to print progress
|
||||||
|
enhance_prompt: Whether to enhance prompt using Gemma
|
||||||
|
max_tokens: Max tokens for prompt enhancement
|
||||||
|
temperature: Temperature for prompt enhancement
|
||||||
|
image: Path to conditioning image for I2V
|
||||||
|
image_strength: Conditioning strength (1.0 = full denoise)
|
||||||
|
image_frame_idx: Frame index to condition (0 = first frame)
|
||||||
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Validate dimensions
|
# Validate dimensions
|
||||||
@@ -333,9 +399,13 @@ def generate_video_with_audio(
|
|||||||
# Calculate audio frames
|
# Calculate audio frames
|
||||||
audio_frames = compute_audio_frames(num_frames, fps)
|
audio_frames = compute_audio_frames(num_frames, fps)
|
||||||
|
|
||||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
|
is_i2v = image is not None
|
||||||
|
mode_str = "I2V+Audio" if is_i2v else "T2V+Audio"
|
||||||
|
print(f"{Colors.BOLD}{Colors.CYAN}🎬 [{mode_str}] Generating {width}x{height} video with {num_frames} frames + audio{Colors.RESET}")
|
||||||
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
|
print(f"{Colors.DIM}Audio: {audio_frames} latent frames @ {AUDIO_SAMPLE_RATE}Hz{Colors.RESET}")
|
||||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||||
|
if is_i2v:
|
||||||
|
print(f"{Colors.DIM}Image: {image} (strength={image_strength}, frame={image_frame_idx}){Colors.RESET}")
|
||||||
|
|
||||||
model_path = get_model_path(model_repo)
|
model_path = get_model_path(model_repo)
|
||||||
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
text_encoder_path = model_path if text_encoder_repo is None else get_model_path(text_encoder_repo)
|
||||||
@@ -400,6 +470,31 @@ def generate_video_with_audio(
|
|||||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||||
mx.eval(transformer.parameters())
|
mx.eval(transformer.parameters())
|
||||||
|
|
||||||
|
# Load VAE encoder and encode image for I2V conditioning
|
||||||
|
stage1_image_latent = None
|
||||||
|
stage2_image_latent = None
|
||||||
|
if is_i2v:
|
||||||
|
print(f"{Colors.BLUE}🖼️ Loading VAE encoder and encoding image...{Colors.RESET}")
|
||||||
|
vae_encoder = load_vae_encoder(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||||
|
mx.eval(vae_encoder.parameters())
|
||||||
|
|
||||||
|
# Load and prepare image for stage 1 (half resolution)
|
||||||
|
input_image = load_image(image, height=height // 2, width=width // 2)
|
||||||
|
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
|
||||||
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||||
|
mx.eval(stage1_image_latent)
|
||||||
|
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
||||||
|
|
||||||
|
# Load and prepare image for stage 2 (full resolution)
|
||||||
|
input_image = load_image(image, height=height, width=width)
|
||||||
|
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
|
||||||
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||||
|
mx.eval(stage2_image_latent)
|
||||||
|
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
||||||
|
|
||||||
|
del vae_encoder
|
||||||
|
mx.clear_cache()
|
||||||
|
|
||||||
# Initialize latents
|
# Initialize latents
|
||||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
@@ -412,12 +507,30 @@ def generate_video_with_audio(
|
|||||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
audio_positions = create_audio_position_grid(1, audio_frames)
|
||||||
mx.eval(video_positions, audio_positions)
|
mx.eval(video_positions, audio_positions)
|
||||||
|
|
||||||
|
# Apply I2V conditioning for stage 1 if provided
|
||||||
|
video_state1 = None
|
||||||
|
if is_i2v and stage1_image_latent is not None:
|
||||||
|
video_state1 = LatentState(
|
||||||
|
latent=video_latents,
|
||||||
|
clean_latent=mx.zeros_like(video_latents),
|
||||||
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||||
|
)
|
||||||
|
conditioning = VideoConditionByLatentIndex(
|
||||||
|
latent=stage1_image_latent,
|
||||||
|
frame_idx=image_frame_idx,
|
||||||
|
strength=image_strength,
|
||||||
|
)
|
||||||
|
video_state1 = apply_conditioning(video_state1, [conditioning])
|
||||||
|
video_latents = video_state1.latent
|
||||||
|
mx.eval(video_latents)
|
||||||
|
|
||||||
# Stage 1 denoising
|
# Stage 1 denoising
|
||||||
video_latents, audio_latents = denoise_av(
|
video_latents, audio_latents = denoise_av(
|
||||||
video_latents, audio_latents,
|
video_latents, audio_latents,
|
||||||
video_positions, audio_positions,
|
video_positions, audio_positions,
|
||||||
video_embeddings, audio_embeddings,
|
video_embeddings, audio_embeddings,
|
||||||
transformer, STAGE_1_SIGMAS, verbose=verbose
|
transformer, STAGE_1_SIGMAS, verbose=verbose,
|
||||||
|
video_state=video_state1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Upsample video latents
|
# Upsample video latents
|
||||||
@@ -449,11 +562,29 @@ def generate_video_with_audio(
|
|||||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||||
mx.eval(video_latents, audio_latents)
|
mx.eval(video_latents, audio_latents)
|
||||||
|
|
||||||
|
# Apply I2V conditioning for stage 2 if provided
|
||||||
|
video_state2 = None
|
||||||
|
if is_i2v and stage2_image_latent is not None:
|
||||||
|
video_state2 = LatentState(
|
||||||
|
latent=video_latents,
|
||||||
|
clean_latent=mx.zeros_like(video_latents),
|
||||||
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
||||||
|
)
|
||||||
|
conditioning = VideoConditionByLatentIndex(
|
||||||
|
latent=stage2_image_latent,
|
||||||
|
frame_idx=image_frame_idx,
|
||||||
|
strength=image_strength,
|
||||||
|
)
|
||||||
|
video_state2 = apply_conditioning(video_state2, [conditioning])
|
||||||
|
video_latents = video_state2.latent
|
||||||
|
mx.eval(video_latents)
|
||||||
|
|
||||||
video_latents, audio_latents = denoise_av(
|
video_latents, audio_latents = denoise_av(
|
||||||
video_latents, audio_latents,
|
video_latents, audio_latents,
|
||||||
video_positions, audio_positions,
|
video_positions, audio_positions,
|
||||||
video_embeddings, audio_embeddings,
|
video_embeddings, audio_embeddings,
|
||||||
transformer, STAGE_2_SIGMAS, verbose=verbose
|
transformer, STAGE_2_SIGMAS, verbose=verbose,
|
||||||
|
video_state=video_state2
|
||||||
)
|
)
|
||||||
|
|
||||||
del transformer
|
del transformer
|
||||||
@@ -549,13 +680,18 @@ def generate_video_with_audio(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Generate videos with synchronized audio using MLX LTX-2",
|
description="Generate videos with synchronized audio using MLX LTX-2 (T2V and I2V)",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
Examples:
|
Examples:
|
||||||
|
# Text-to-Video with Audio (T2V+Audio)
|
||||||
python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach"
|
python -m mlx_video.generate_av --prompt "Ocean waves crashing on a beach"
|
||||||
python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt
|
python -m mlx_video.generate_av --prompt "A jazz band playing" --enhance-prompt
|
||||||
python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav
|
python -m mlx_video.generate_av --prompt "..." --output my_video.mp4 --output-audio my_audio.wav
|
||||||
|
|
||||||
|
# Image-to-Video with Audio (I2V+Audio)
|
||||||
|
python -m mlx_video.generate_av --prompt "A person dancing" --image photo.jpg
|
||||||
|
python -m mlx_video.generate_av --prompt "Waves crashing" --image beach.png --image-strength 0.8
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -587,6 +723,12 @@ Examples:
|
|||||||
help="Max tokens for prompt enhancement")
|
help="Max tokens for prompt enhancement")
|
||||||
parser.add_argument("--temperature", type=float, default=0.7,
|
parser.add_argument("--temperature", type=float, default=0.7,
|
||||||
help="Temperature for prompt enhancement")
|
help="Temperature for prompt enhancement")
|
||||||
|
parser.add_argument("--image", "-i", type=str, default=None,
|
||||||
|
help="Path to conditioning image for I2V (Image-to-Video) generation")
|
||||||
|
parser.add_argument("--image-strength", type=float, default=1.0,
|
||||||
|
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)")
|
||||||
|
parser.add_argument("--image-frame-idx", type=int, default=0,
|
||||||
|
help="Frame index to condition for I2V (0 = first frame, default: 0)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -605,6 +747,9 @@ Examples:
|
|||||||
enhance_prompt=args.enhance_prompt,
|
enhance_prompt=args.enhance_prompt,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
temperature=args.temperature,
|
temperature=args.temperature,
|
||||||
|
image=args.image,
|
||||||
|
image_strength=args.image_strength,
|
||||||
|
image_frame_idx=args.image_frame_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||||
|
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
|
||||||
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
|
||||||
|
|||||||
187
mlx_video/models/ltx/video_vae/encoder.py
Normal file
187
mlx_video/models/ltx/video_vae/encoder.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
"""Video VAE Encoder for LTX-2 Image-to-Video.
|
||||||
|
|
||||||
|
The encoder compresses input images/videos to latent representations.
|
||||||
|
Used for I2V (image-to-video) conditioning by encoding the input image
|
||||||
|
to latent space, which can then be used to condition video generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple, Any, Optional
|
||||||
|
import json
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae_encoder(model_path: str) -> VideoEncoder:
|
||||||
|
"""Load VAE encoder from safetensors file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the model weights (safetensors file or directory)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded VideoEncoder instance
|
||||||
|
"""
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
model_path = Path(model_path)
|
||||||
|
|
||||||
|
# Try to find the weights file
|
||||||
|
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||||
|
weights_path = model_path
|
||||||
|
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||||
|
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||||
|
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||||
|
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||||
|
|
||||||
|
print(f"Loading VAE encoder from {weights_path}...")
|
||||||
|
|
||||||
|
# Read config from safetensors metadata
|
||||||
|
encoder_blocks = []
|
||||||
|
norm_layer = NormLayerType.PIXEL_NORM
|
||||||
|
latent_log_var = LogVarianceType.UNIFORM
|
||||||
|
patch_size = 4
|
||||||
|
|
||||||
|
try:
|
||||||
|
with safe_open(str(weights_path), framework="numpy") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
if metadata and "config" in metadata:
|
||||||
|
configs = json.loads(metadata["config"])
|
||||||
|
vae_config = configs.get("vae", {})
|
||||||
|
|
||||||
|
# Parse encoder blocks
|
||||||
|
raw_blocks = vae_config.get("encoder_blocks", [])
|
||||||
|
for block in raw_blocks:
|
||||||
|
if isinstance(block, list) and len(block) == 2:
|
||||||
|
name, params = block
|
||||||
|
encoder_blocks.append((name, params))
|
||||||
|
|
||||||
|
# Parse other config
|
||||||
|
norm_str = vae_config.get("norm_layer", "pixel_norm")
|
||||||
|
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
|
||||||
|
|
||||||
|
var_str = vae_config.get("latent_log_var", "uniform")
|
||||||
|
if var_str == "uniform":
|
||||||
|
latent_log_var = LogVarianceType.UNIFORM
|
||||||
|
elif var_str == "per_channel":
|
||||||
|
latent_log_var = LogVarianceType.PER_CHANNEL
|
||||||
|
elif var_str == "constant":
|
||||||
|
latent_log_var = LogVarianceType.CONSTANT
|
||||||
|
else:
|
||||||
|
latent_log_var = LogVarianceType.NONE
|
||||||
|
|
||||||
|
patch_size = vae_config.get("patch_size", 4)
|
||||||
|
|
||||||
|
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Could not read config from metadata: {e}")
|
||||||
|
# Use default config
|
||||||
|
encoder_blocks = [
|
||||||
|
("res_x", {"num_layers": 4}),
|
||||||
|
("compress_space_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 6}),
|
||||||
|
("compress_time_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 6}),
|
||||||
|
("compress_all_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 2}),
|
||||||
|
("compress_all_res", {"multiplier": 2}),
|
||||||
|
("res_x", {"num_layers": 2}),
|
||||||
|
]
|
||||||
|
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
|
||||||
|
|
||||||
|
# Create encoder
|
||||||
|
encoder = VideoEncoder(
|
||||||
|
convolution_dimensions=3,
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=128,
|
||||||
|
encoder_blocks=encoder_blocks,
|
||||||
|
patch_size=patch_size,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
latent_log_var=latent_log_var,
|
||||||
|
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
weights = mx.load(str(weights_path))
|
||||||
|
|
||||||
|
# Determine prefix based on weight keys
|
||||||
|
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||||
|
|
||||||
|
if has_vae_prefix:
|
||||||
|
prefix = "vae.encoder."
|
||||||
|
stats_prefix = "vae.per_channel_statistics."
|
||||||
|
else:
|
||||||
|
prefix = "encoder."
|
||||||
|
stats_prefix = "per_channel_statistics."
|
||||||
|
|
||||||
|
# Load per-channel statistics for normalization
|
||||||
|
mean_key = f"{stats_prefix}mean-of-means"
|
||||||
|
std_key = f"{stats_prefix}std-of-means"
|
||||||
|
|
||||||
|
if mean_key in weights:
|
||||||
|
encoder.per_channel_statistics.mean = weights[mean_key]
|
||||||
|
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
|
||||||
|
if std_key in weights:
|
||||||
|
encoder.per_channel_statistics.std = weights[std_key]
|
||||||
|
print(f" Loaded latent std: shape {weights[std_key].shape}")
|
||||||
|
|
||||||
|
# Build encoder weights dict with key remapping
|
||||||
|
encoder_weights = {}
|
||||||
|
for key, value in weights.items():
|
||||||
|
if not key.startswith(prefix):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Remove prefix
|
||||||
|
new_key = key[len(prefix):]
|
||||||
|
|
||||||
|
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||||
|
if ".weight" in key and value.ndim == 5:
|
||||||
|
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
encoder_weights[new_key] = value
|
||||||
|
|
||||||
|
print(f" Found {len(encoder_weights)} encoder weights")
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
encoder.load_weights(list(encoder_weights.items()), strict=False)
|
||||||
|
|
||||||
|
print("VAE encoder loaded successfully")
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def encode_image(
|
||||||
|
image: mx.array,
|
||||||
|
encoder: VideoEncoder,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Encode a single image to latent space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image tensor of shape (H, W, 3) in range [0, 1] or (B, H, W, 3)
|
||||||
|
encoder: Loaded VAE encoder
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Latent tensor of shape (1, 128, 1, H//32, W//32)
|
||||||
|
"""
|
||||||
|
# Add batch dimension if needed
|
||||||
|
if image.ndim == 3:
|
||||||
|
image = mx.expand_dims(image, axis=0) # (1, H, W, 3)
|
||||||
|
|
||||||
|
# Convert from (B, H, W, C) to (B, C, H, W)
|
||||||
|
image = mx.transpose(image, (0, 3, 1, 2)) # (B, 3, H, W)
|
||||||
|
|
||||||
|
# Normalize to [-1, 1]
|
||||||
|
if image.max() > 1.0:
|
||||||
|
image = image / 255.0
|
||||||
|
image = image * 2.0 - 1.0
|
||||||
|
|
||||||
|
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
|
||||||
|
image = mx.expand_dims(image, axis=2) # (B, 3, 1, H, W)
|
||||||
|
|
||||||
|
# Encode
|
||||||
|
latent = encoder(image)
|
||||||
|
|
||||||
|
return latent
|
||||||
@@ -141,9 +141,10 @@ class UNetMidBlock3D(nn.Module):
|
|||||||
|
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
||||||
# Create ResNet blocks
|
# Create ResNet blocks - use dict for MLX parameter tracking
|
||||||
self.resnets = [
|
# Named res_blocks to match PyTorch weight keys
|
||||||
ResnetBlock3D(
|
self.res_blocks = {
|
||||||
|
i: ResnetBlock3D(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=in_channels,
|
out_channels=in_channels,
|
||||||
@@ -154,8 +155,8 @@ class UNetMidBlock3D(nn.Module):
|
|||||||
timestep_conditioning=timestep_conditioning,
|
timestep_conditioning=timestep_conditioning,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
for _ in range(num_layers)
|
for i in range(num_layers)
|
||||||
]
|
}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -165,7 +166,7 @@ class UNetMidBlock3D(nn.Module):
|
|||||||
generator: Optional[int] = None,
|
generator: Optional[int] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
|
|
||||||
for resnet in self.resnets:
|
for resnet in self.res_blocks.values():
|
||||||
x = resnet(x, causal=causal, generator=generator)
|
x = resnet(x, causal=causal, generator=generator)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -9,6 +9,15 @@ from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingMode
|
|||||||
|
|
||||||
|
|
||||||
class SpaceToDepthDownsample(nn.Module):
|
class SpaceToDepthDownsample(nn.Module):
|
||||||
|
"""Space-to-depth downsampling with 3x3 conv and skip connection.
|
||||||
|
|
||||||
|
PyTorch-compatible implementation:
|
||||||
|
1. Apply 3x3 conv: in_channels -> out_channels // prod(stride)
|
||||||
|
2. Space-to-depth on conv output: channels * prod(stride)
|
||||||
|
3. Space-to-depth on input with group averaging for skip connection
|
||||||
|
4. Add skip connection
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dims: int,
|
dims: int,
|
||||||
@@ -17,7 +26,6 @@ class SpaceToDepthDownsample(nn.Module):
|
|||||||
stride: Union[int, Tuple[int, int, int]],
|
stride: Union[int, Tuple[int, int, int]],
|
||||||
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if isinstance(stride, int):
|
if isinstance(stride, int):
|
||||||
@@ -25,61 +33,74 @@ class SpaceToDepthDownsample(nn.Module):
|
|||||||
|
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.dims = dims
|
self.dims = dims
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
# Calculate the multiplier for channels
|
# Calculate channels
|
||||||
multiplier = stride[0] * stride[1] * stride[2]
|
multiplier = stride[0] * stride[1] * stride[2]
|
||||||
intermediate_channels = in_channels * multiplier
|
self.group_size = in_channels * multiplier // out_channels
|
||||||
|
conv_out_channels = out_channels // multiplier
|
||||||
|
|
||||||
# 1x1x1 convolution to adjust channels
|
# 3x3 convolution (not 1x1)
|
||||||
self.conv = CausalConv3d(
|
self.conv = CausalConv3d(
|
||||||
in_channels=intermediate_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=conv_out_channels,
|
||||||
kernel_size=1,
|
kernel_size=3,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
padding=1,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
def _space_to_depth(self, x: mx.array) -> mx.array:
|
||||||
|
"""Rearrange: b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w"""
|
||||||
b, c, d, h, w = x.shape
|
b, c, d, h, w = x.shape
|
||||||
st, sh, sw = self.stride
|
st, sh, sw = self.stride
|
||||||
|
|
||||||
|
# Reshape to group spatial elements
|
||||||
|
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
|
||||||
|
|
||||||
|
# Permute: (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
|
||||||
|
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
||||||
|
|
||||||
|
# Reshape to combine channels
|
||||||
|
new_c = c * st * sh * sw
|
||||||
|
new_d = d // st
|
||||||
|
new_h = h // sh
|
||||||
|
new_w = w // sw
|
||||||
|
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, causal: bool = True) -> mx.array:
|
||||||
|
b, c, d, h, w = x.shape
|
||||||
|
st, sh, sw = self.stride
|
||||||
|
|
||||||
|
# Temporal padding for causal mode
|
||||||
|
if st == 2:
|
||||||
|
# Duplicate first frame for padding
|
||||||
|
x = mx.concatenate([x[:, :, :1, :, :], x], axis=2)
|
||||||
|
d = d + 1
|
||||||
|
|
||||||
# Pad if necessary to make dimensions divisible by stride
|
# Pad if necessary to make dimensions divisible by stride
|
||||||
pad_d = (st - d % st) % st
|
pad_d = (st - d % st) % st
|
||||||
pad_h = (sh - h % sh) % sh
|
pad_h = (sh - h % sh) % sh
|
||||||
pad_w = (sw - w % sw) % sw
|
pad_w = (sw - w % sw) % sw
|
||||||
|
|
||||||
if pad_d > 0 or pad_h > 0 or pad_w > 0:
|
if pad_d > 0 or pad_h > 0 or pad_w > 0:
|
||||||
# For causal, pad at the end of temporal dimension
|
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
|
||||||
if causal:
|
|
||||||
x = mx.pad(x, [(0, 0), (0, 0), (0, pad_d), (0, pad_h), (0, pad_w)])
|
|
||||||
else:
|
|
||||||
x = mx.pad(x, [(0, 0), (0, 0), (pad_d // 2, pad_d - pad_d // 2),
|
|
||||||
(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)])
|
|
||||||
|
|
||||||
b, c, d, h, w = x.shape
|
# Skip connection: space-to-depth on input, then group mean
|
||||||
|
x_in = self._space_to_depth(x)
|
||||||
|
# Reshape for group mean: (b, c*prod(stride), d, h, w) -> (b, out_channels, group_size, d, h, w)
|
||||||
|
b2, c2, d2, h2, w2 = x_in.shape
|
||||||
|
x_in = mx.reshape(x_in, (b2, self.out_channels, self.group_size, d2, h2, w2))
|
||||||
|
x_in = mx.mean(x_in, axis=2) # (b, out_channels, d, h, w)
|
||||||
|
|
||||||
# Reshape to group spatial elements
|
# Conv branch: apply conv then space-to-depth
|
||||||
# (B, C, D, H, W) -> (B, C, D/st, st, H/sh, sh, W/sw, sw)
|
x_conv = self.conv(x, causal=causal)
|
||||||
x = mx.reshape(x, (b, c, d // st, st, h // sh, sh, w // sw, sw))
|
x_conv = self._space_to_depth(x_conv)
|
||||||
|
|
||||||
# Permute to move stride elements to channel dim
|
# Add skip connection
|
||||||
# (B, C, D', st, H', sh, W', sw) -> (B, C, st, sh, sw, D', H', W')
|
return x_conv + x_in
|
||||||
x = mx.transpose(x, (0, 1, 3, 5, 7, 2, 4, 6))
|
|
||||||
|
|
||||||
# Reshape to combine channels
|
|
||||||
# (B, C, st, sh, sw, D', H', W') -> (B, C*st*sh*sw, D', H', W')
|
|
||||||
new_c = c * st * sh * sw
|
|
||||||
new_d = d // st
|
|
||||||
new_h = h // sh
|
|
||||||
new_w = w // sw
|
|
||||||
x = mx.reshape(x, (b, new_c, new_d, new_h, new_w))
|
|
||||||
|
|
||||||
# Apply 1x1 conv to adjust channels
|
|
||||||
x = self.conv(x, causal=causal)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DepthToSpaceUpsample(nn.Module):
|
class DepthToSpaceUpsample(nn.Module):
|
||||||
|
|||||||
@@ -273,9 +273,9 @@ class VideoEncoder(nn.Module):
|
|||||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build encoder blocks
|
# Build encoder blocks - use dict with int keys for MLX parameter tracking
|
||||||
self.down_blocks = []
|
self.down_blocks = {}
|
||||||
for block_name, block_params in encoder_blocks:
|
for i, (block_name, block_params) in enumerate(encoder_blocks):
|
||||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||||
|
|
||||||
block, feature_channels = _make_encoder_block(
|
block, feature_channels = _make_encoder_block(
|
||||||
@@ -287,7 +287,7 @@ class VideoEncoder(nn.Module):
|
|||||||
norm_num_groups=self._norm_num_groups,
|
norm_num_groups=self._norm_num_groups,
|
||||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||||
)
|
)
|
||||||
self.down_blocks.append(block)
|
self.down_blocks[i] = block
|
||||||
|
|
||||||
# Output normalization and convolution
|
# Output normalization and convolution
|
||||||
if norm_layer == NormLayerType.GROUP_NORM:
|
if norm_layer == NormLayerType.GROUP_NORM:
|
||||||
@@ -341,7 +341,7 @@ class VideoEncoder(nn.Module):
|
|||||||
sample = self.conv_in(sample, causal=True)
|
sample = self.conv_in(sample, causal=True)
|
||||||
|
|
||||||
# Process through encoder blocks
|
# Process through encoder blocks
|
||||||
for down_block in self.down_blocks:
|
for down_block in self.down_blocks.values():
|
||||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||||
sample = down_block(sample, causal=True)
|
sample = down_block(sample, causal=True)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
import math
|
import math
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
def get_model_path(model_repo: str):
|
def get_model_path(model_repo: str):
|
||||||
"""Get or download LTX-2 model path."""
|
"""Get or download LTX-2 model path."""
|
||||||
@@ -160,3 +163,122 @@ def get_timestep_embedding(
|
|||||||
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
emb = mx.pad(emb, [(0, 0), (0, 1)])
|
||||||
|
|
||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(
|
||||||
|
image_path: Union[str, Path],
|
||||||
|
height: Optional[int] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Load and preprocess an image for I2V conditioning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: Path to the image file
|
||||||
|
height: Target height (must be divisible by 32). If None, uses original.
|
||||||
|
width: Target width (must be divisible by 32). If None, uses original.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image tensor of shape (H, W, 3) in range [0, 1]
|
||||||
|
"""
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
|
||||||
|
# Resize if dimensions specified
|
||||||
|
if height is not None and width is not None:
|
||||||
|
image = image.resize((width, height), Image.Resampling.LANCZOS)
|
||||||
|
elif height is not None or width is not None:
|
||||||
|
# If only one dimension specified, resize preserving aspect ratio
|
||||||
|
orig_w, orig_h = image.size
|
||||||
|
if height is not None:
|
||||||
|
scale = height / orig_h
|
||||||
|
new_w = int(orig_w * scale)
|
||||||
|
new_w = (new_w // 32) * 32 # Round to nearest 32
|
||||||
|
image = image.resize((new_w, height), Image.Resampling.LANCZOS)
|
||||||
|
else:
|
||||||
|
scale = width / orig_w
|
||||||
|
new_h = int(orig_h * scale)
|
||||||
|
new_h = (new_h // 32) * 32 # Round to nearest 32
|
||||||
|
image = image.resize((width, new_h), Image.Resampling.LANCZOS)
|
||||||
|
else:
|
||||||
|
# Round to nearest 32
|
||||||
|
orig_w, orig_h = image.size
|
||||||
|
new_w = (orig_w // 32) * 32
|
||||||
|
new_h = (orig_h // 32) * 32
|
||||||
|
if new_w != orig_w or new_h != orig_h:
|
||||||
|
image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
# Convert to numpy then MLX
|
||||||
|
image_np = np.array(image).astype(np.float32) / 255.0
|
||||||
|
return mx.array(image_np)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image_aspect_ratio(
|
||||||
|
image: mx.array,
|
||||||
|
long_side: int = 512,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Resize image preserving aspect ratio, making long side = long_side.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image tensor of shape (H, W, 3)
|
||||||
|
long_side: Target size for the longer dimension
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resized image tensor
|
||||||
|
"""
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
|
if h > w:
|
||||||
|
new_h = long_side
|
||||||
|
new_w = int(w * long_side / h)
|
||||||
|
else:
|
||||||
|
new_w = long_side
|
||||||
|
new_h = int(h * long_side / w)
|
||||||
|
|
||||||
|
# Round to nearest 32
|
||||||
|
new_h = (new_h // 32) * 32
|
||||||
|
new_w = (new_w // 32) * 32
|
||||||
|
|
||||||
|
# Use PIL for high-quality resize
|
||||||
|
image_np = np.array(image)
|
||||||
|
if image_np.max() <= 1.0:
|
||||||
|
image_np = (image_np * 255).astype(np.uint8)
|
||||||
|
pil_image = Image.fromarray(image_np)
|
||||||
|
pil_image = pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
return mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_image_for_encoding(
|
||||||
|
image: mx.array,
|
||||||
|
target_height: int,
|
||||||
|
target_width: int,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Prepare image for VAE encoding by resizing and normalizing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image tensor of shape (H, W, 3) in range [0, 1]
|
||||||
|
target_height: Target height for the video
|
||||||
|
target_width: Target width for the video
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image tensor ready for encoding, shape (1, 3, 1, H, W) in range [-1, 1]
|
||||||
|
"""
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
|
# Resize if needed
|
||||||
|
if h != target_height or w != target_width:
|
||||||
|
image_np = np.array(image)
|
||||||
|
if image_np.max() <= 1.0:
|
||||||
|
image_np = (image_np * 255).astype(np.uint8)
|
||||||
|
pil_image = Image.fromarray(image_np)
|
||||||
|
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)
|
||||||
|
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
||||||
|
|
||||||
|
# Normalize to [-1, 1]
|
||||||
|
image = image * 2.0 - 1.0
|
||||||
|
|
||||||
|
# Convert to (B, C, 1, H, W)
|
||||||
|
image = mx.transpose(image, (2, 0, 1)) # (3, H, W)
|
||||||
|
image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
|
||||||
|
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|||||||
Reference in New Issue
Block a user