Files
mlx-video/mlx_video/generate.py
2026-01-11 23:48:33 +01:00

587 lines
18 KiB
Python

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterator, List, Optional, Tuple, Union

import mlx.core as mx
import numpy as np

from mlx_video.models.ltx.ltx import LTXModel, X0Model
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder
from mlx_video.models.ltx.transformer import Modality
from mlx_video.models.ltx.video_vae import VideoDecoder, VideoEncoder
@dataclass
class GenerationConfig:
    """Settings for one video generation run.

    The three ``latent_*`` properties derive the latent grid size from the
    pixel-space dimensions: spatial axes are divided by 32 and the frame
    count maps as ``1 + (num_frames - 1) // 8``.
    """

    # --- Output video size (pixel space) ---
    height: int = 512
    width: int = 512
    # Expected to be of the form 1 + 8*k (e.g. 9, 17, 33).
    num_frames: int = 33

    # --- Sampling ---
    # Step count handed to get_sigmas(); it has no effect when
    # use_distilled is True because the fixed distilled schedule is used.
    num_inference_steps: int = 8
    # NOTE(review): not read by the pipeline code visible in this file.
    guidance_scale: float = 3.0
    # Select the hardcoded distilled sigma schedule.
    use_distilled: bool = True

    @property
    def latent_height(self) -> int:
        """Latent rows: pixel height floor-divided by 32."""
        return self.height // 32

    @property
    def latent_width(self) -> int:
        """Latent columns: pixel width floor-divided by 32."""
        return self.width // 32

    @property
    def latent_frames(self) -> int:
        """Latent frames: one leading frame plus one per 8 further frames."""
        return (self.num_frames - 1) // 8 + 1
# Fixed sigma schedule for the distilled model (taken from the LTX-2
# pipeline); the values were tuned to match the distillation process.
# Nine entries -> eight denoising steps, ending at 0.0.
DISTILLED_SIGMA_VALUES = [
    1.0,
    0.99375,
    0.9875,
    0.98125,
    0.975,
    0.909375,
    0.725,
    0.421875,
    0.0,
]

# Second-stage distilled schedule; numerically identical to the last four
# entries of DISTILLED_SIGMA_VALUES.
STAGE_2_DISTILLED_SIGMA_VALUES = [
    0.909375,
    0.725,
    0.421875,
    0.0,
]

# Token-count anchors used by get_sigmas() to interpolate the sigma shift
# between base_shift and max_shift for non-distilled models.
BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096
def get_sigmas(
    num_steps: int,
    num_tokens: int,
    max_shift: float = 2.05,
    base_shift: float = 0.95,
    stretch: bool = True,
    terminal: float = 0.1,
    use_distilled: bool = True,
) -> mx.array:
    """Get the sigma (noise level) schedule for diffusion.

    Args:
        num_steps: Number of diffusion steps. Only used when
            ``use_distilled`` is False; the result then has
            ``num_steps + 1`` entries.
        num_tokens: Number of latent tokens (T * H * W); drives the
            token-dependent shift of the schedule.
        max_shift: Shift applied at ``MAX_SHIFT_ANCHOR`` tokens.
        base_shift: Shift applied at ``BASE_SHIFT_ANCHOR`` tokens.
        stretch: Whether to rescale the non-zero sigmas so the last
            non-zero value equals ``terminal``.
        terminal: Target value of the last non-zero sigma when stretching.
        use_distilled: If True, return the hardcoded
            ``DISTILLED_SIGMA_VALUES`` and ignore every other argument.

    Returns:
        float32 array of sigma values.
    """
    import math

    # Distilled model: the schedule is fixed, nothing to compute.
    if use_distilled:
        return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32)

    # Non-distilled models: start from a linear ramp 1 -> 0 and move to
    # NumPy, which supports the masked in-place updates used below.
    sigmas_np = np.array(mx.linspace(1.0, 0.0, num_steps + 1))

    # Token-dependent shift: the straight line through
    # (BASE_SHIFT_ANCHOR, base_shift) and (MAX_SHIFT_ANCHOR, max_shift),
    # evaluated at num_tokens.
    slope = (max_shift - base_shift) / (MAX_SHIFT_ANCHOR - BASE_SHIFT_ANCHOR)
    intercept = base_shift - slope * BASE_SHIFT_ANCHOR
    sigma_shift = num_tokens * slope + intercept
    exp_shift = math.exp(sigma_shift)

    # sigma' = e^shift / (e^shift + (1/sigma - 1)); the exponent on the
    # last term is 1 and therefore omitted. Zero sigmas stay zero, which
    # also avoids dividing by zero.
    result = np.zeros_like(sigmas_np)
    non_zero = sigmas_np != 0
    result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0))

    # Stretch so that the last non-zero sigma lands on ``terminal``.
    if stretch:
        non_zero_mask = result != 0
        one_minus_z = 1.0 - result[non_zero_mask]
        # When the last non-zero sigma is exactly 1.0 (e.g. num_steps == 1)
        # the scale factor would be 0 and the stretch would evaluate 0/0,
        # filling the schedule with NaN. There is nothing to stretch in
        # that case, so the schedule is left as is.
        if one_minus_z.size > 0 and one_minus_z[-1] != 0:
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            result[non_zero_mask] = 1.0 - (one_minus_z / scale_factor)

    return mx.array(result, dtype=mx.float32)
def create_position_grid(
    batch_size: int,
    num_frames: int,
    height: int,
    width: int,
    temporal_scale: int = 8,
    spatial_scale: int = 32,
    fps: float = 24.0,
    causal_fix: bool = True,
) -> mx.array:
    """Build the RoPE position grid, expressed in pixel space.

    Args:
        batch_size: Batch size
        num_frames: Number of latent frames
        height: Latent height
        width: Latent width
        temporal_scale: VAE temporal scale factor (default 8)
        spatial_scale: VAE spatial scale factor (default 32)
        fps: Frames per second used to turn frame indices into seconds
        causal_fix: Shift the temporal axis to account for the first frame

    Returns:
        float32 array of shape (B, 3, num_patches, 2). Axis 1 is
        (time, height, width) and the last axis holds the [start, end)
        bounds of each patch; the time axis is in seconds.
    """
    # LTX-2 uses a (1, 1, 1) patch size, so every latent cell is a patch.
    patch_shape = (1, 1, 1)
    extents = (num_frames, height, width)

    # Patch start coordinates per axis, expanded to (3, T, H, W) with
    # 'ij' indexing so the flattened order is frame-major.
    axis_starts = [
        np.arange(0, extent, step) for extent, step in zip(extents, patch_shape)
    ]
    starts = np.stack(np.meshgrid(*axis_starts, indexing='ij'), axis=0)
    ends = starts + np.array(patch_shape).reshape(3, 1, 1, 1)

    # Pair [start, end) on a trailing axis and flatten T*H*W into patches.
    bounds = np.stack([starts, ends], axis=-1)
    bounds = bounds.reshape(3, num_frames * height * width, 2)

    # Replicate across the batch, then scale latent -> pixel coordinates.
    batched = np.tile(bounds[np.newaxis, ...], (batch_size, 1, 1, 1))
    vae_strides = np.array([temporal_scale, spatial_scale, spatial_scale])
    pixel_coords = (batched * vae_strides.reshape(1, 3, 1, 1)).astype(np.float32)

    if causal_fix:
        # The VAE's temporal stride for the first frame is 1 rather than
        # temporal_scale: shift the time axis and clamp at zero so the
        # first-frame timestamps stay non-negative.
        shifted = pixel_coords[:, 0, :, :] + 1 - temporal_scale
        pixel_coords[:, 0, :, :] = np.clip(shifted, a_min=0, a_max=None)

    # Frame index -> seconds.
    pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps

    return mx.array(pixel_coords, dtype=mx.float32)
class LTXVideoPipeline:
    """Text-to-video generation pipeline around an LTX transformer.

    Calling the pipeline encodes the prompt, samples initial noise, runs the
    denoising loop over a sigma schedule and, if a VAE decoder is attached,
    decodes the final latents.
    """

    def __init__(
        self,
        transformer: LTXModel,
        text_encoder: Optional[LTX2TextEncoder] = None,
        tokenizer: Optional[Any] = None,
        vae_encoder: Optional[VideoEncoder] = None,
        vae_decoder: Optional[VideoDecoder] = None,
    ):
        """Initialize pipeline.

        Args:
            transformer: LTX transformer model
            text_encoder: Optional LTX text encoder
            tokenizer: Optional tokenizer for text encoding; called as
                ``tokenizer(prompts, padding=..., max_length=..., ...)``
            vae_encoder: Optional VAE encoder (stored, not used by the
                methods of this class)
            vae_decoder: Optional VAE decoder
        """
        self.transformer = transformer
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.vae_encoder = vae_encoder
        self.vae_decoder = vae_decoder
        # Wrapper used by denoise_step() to obtain the denoised prediction.
        self.x0_model = X0Model(transformer)

    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
        dtype: mx.Dtype = mx.float16,
    ) -> mx.array:
        """Prepare initial noise latents.

        Args:
            batch_size: Batch size
            num_frames: Number of latent frames
            height: Latent height
            width: Latent width
            dtype: Data type of the returned latents

        Returns:
            Standard-normal noise of shape
            (batch_size, in_channels, num_frames, height, width).
        """
        # Channel count comes from the transformer config.
        in_channels = self.transformer.config.in_channels
        shape = (batch_size, in_channels, num_frames, height, width)
        return mx.random.normal(shape).astype(dtype)

    def prepare_text_embeddings(
        self,
        prompt: Union[str, List[str]],
        batch_size: int,
        max_length: int = 1024,
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """Prepare text embeddings.

        When both a text encoder and a tokenizer are attached the prompt is
        tokenized and encoded. Otherwise a warning is printed and random
        embeddings are returned, which is only useful for testing.

        Args:
            prompt: Text prompt or list of prompts. A single string is
                repeated ``batch_size`` times; a list is used as given.
            batch_size: Batch size
            max_length: Maximum sequence length for tokenization

        Returns:
            Tuple of (text_embeddings, attention_mask). The mask is None on
            the encoder path and all ones on the random fallback path.
        """
        if self.text_encoder is not None and self.tokenizer is not None:
            prompts = [prompt] * batch_size if isinstance(prompt, str) else prompt

            tokens = self.tokenizer(
                prompts,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )
            input_ids = mx.array(tokens["input_ids"])
            attention_mask = mx.array(tokens["attention_mask"])

            embeddings = self.text_encoder(input_ids, attention_mask)
            mx.eval(embeddings)
            # No mask is returned: the connector handles masking internally.
            return embeddings, None

        # Fallback: random embeddings (for testing without text encoder).
        print("Warning: No text encoder provided, using random embeddings")
        seq_len = max_length + 128  # Account for learnable registers
        embed_dim = self.transformer.config.caption_channels
        embeddings = mx.random.normal((batch_size, seq_len, embed_dim))
        mask = mx.ones((batch_size, seq_len))
        return embeddings, mask

    def denoise_step(
        self,
        latents: mx.array,
        sigma: float,
        sigma_next: float,
        text_embeddings: mx.array,
        positions: mx.array,
        text_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """Perform one denoising step.

        Args:
            latents: Current noisy latents of shape (B, C, F, H, W)
            sigma: Current noise level
            sigma_next: Next noise level
            text_embeddings: Text conditioning
            positions: Position grid for RoPE
            text_mask: Optional attention mask for text

        Returns:
            Latents at noise level ``sigma_next``, same shape as the input.
        """
        # Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C)
        b, c, f, h, w = latents.shape
        latents_flat = mx.reshape(latents, (b, c, -1))
        latents_flat = mx.transpose(latents_flat, (0, 2, 1))

        # One timestep per batch element, all equal to the current sigma.
        timesteps = mx.full((b,), sigma)

        video_modality = Modality(
            latent=latents_flat,
            timesteps=timesteps,
            positions=positions,
            context=text_embeddings,
            context_mask=text_mask,
            enabled=True,
        )

        # Video-only call; the second return value is discarded.
        denoised_video, _ = self.x0_model(video=video_modality, audio=None)

        # Reshape back: (B, F*H*W, C) -> (B, C, F, H, W)
        denoised_video = mx.transpose(denoised_video, (0, 2, 1))
        denoised_video = mx.reshape(denoised_video, (b, c, f, h, w))

        # Final step: the denoised prediction is the result.
        if sigma_next <= 0:
            return denoised_video

        # Euler step: x_next = x0 + sigma_next * (x - x0) / sigma
        noise = (latents - denoised_video) / sigma
        return denoised_video + sigma_next * noise

    def __call__(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None,
        seed: Optional[int] = None,
    ) -> mx.array:
        """Generate video from text prompt.

        Args:
            prompt: Text prompt
            config: Generation configuration; defaults to GenerationConfig()
            seed: Random seed; when given it seeds MLX's global RNG

        Returns:
            The VAE decoder's output when a decoder is attached, otherwise
            the final latents of shape (B, C, F, H, W).
        """
        if config is None:
            config = GenerationConfig()
        if seed is not None:
            mx.random.seed(seed)

        batch_size = 1

        text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size)

        latents = self.prepare_latents(
            batch_size=batch_size,
            num_frames=config.latent_frames,
            height=config.latent_height,
            width=config.latent_width,
        )

        positions = create_position_grid(
            batch_size=batch_size,
            num_frames=config.latent_frames,
            height=config.latent_height,
            width=config.latent_width,
        )

        num_tokens = config.latent_frames * config.latent_height * config.latent_width
        sigmas = get_sigmas(
            config.num_inference_steps,
            num_tokens,
            use_distilled=config.use_distilled,
        )

        # Pull the schedule to Python floats once instead of converting two
        # array elements on every iteration, then walk consecutive pairs.
        sigma_values = sigmas.tolist()
        for sigma, sigma_next in zip(sigma_values, sigma_values[1:]):
            latents = self.denoise_step(
                latents=latents,
                sigma=sigma,
                sigma_next=sigma_next,
                text_embeddings=text_embeddings,
                positions=positions,
                text_mask=text_mask,
            )
            # Evaluate every step so the lazy graph does not keep growing.
            mx.eval(latents)

        # Decode latents to video when a decoder is attached.
        if self.vae_decoder is not None:
            return self.vae_decoder(latents)
        return latents
def generate_video(
    prompt: str,
    transformer: LTXModel,
    text_encoder: Optional[LTX2TextEncoder] = None,
    tokenizer: Optional[Any] = None,
    vae_decoder: Optional[VideoDecoder] = None,
    config: Optional[GenerationConfig] = None,
    seed: Optional[int] = None,
) -> mx.array:
    """Generate video from text prompt.

    Convenience wrapper: builds an LTXVideoPipeline from the given
    components and calls it once.

    Args:
        prompt: Text prompt
        transformer: LTX transformer model
        text_encoder: Optional text encoder
        tokenizer: Optional tokenizer
        vae_decoder: Optional VAE decoder
        config: Generation configuration
        seed: Random seed

    Returns:
        Generated video tensor, exactly as returned by the pipeline call.
    """
    pipeline = LTXVideoPipeline(
        transformer=transformer,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae_decoder=vae_decoder,
    )
    return pipeline(prompt, config, seed)
def load_pipeline(
    model_path: str,
    text_encoder_path: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
    load_text_encoder_weights: bool = True,
) -> LTXVideoPipeline:
    """Assemble an LTX-2 video generation pipeline from files on disk.

    Progress is reported with ``print``. The transformer and the VAE
    decoder are always loaded from ``model_path``; the text encoder and the
    tokenizer are loaded only when their paths are supplied.

    Args:
        model_path: Path to LTX-2 model weights (safetensors)
        text_encoder_path: Path to text encoder weights directory
        tokenizer_path: Path to tokenizer directory
        load_text_encoder_weights: Set to False to skip the text encoder
            even if ``text_encoder_path`` is given

    Returns:
        Configured LTXVideoPipeline
    """
    # Imported lazily so importing this module does not require them.
    from transformers import AutoTokenizer
    from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
    from mlx_video.models.ltx.ltx import LTXModel
    from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
    from mlx_video.convert import sanitize_transformer_weights

    print("Loading LTX-2 pipeline...")

    # --- Transformer ---
    print(" Loading transformer...")
    checkpoint = mx.load(model_path)
    transformer_weights = sanitize_transformer_weights(checkpoint)
    model_config = LTXModelConfig(
        model_type=LTXModelType.VideoOnly,
        num_attention_heads=32,
        attention_head_dim=128,
        in_channels=128,
        out_channels=128,
        num_layers=48,
        cross_attention_dim=4096,
        caption_channels=3840,
    )
    transformer = LTXModel(model_config)
    # strict=False: mismatches between the checkpoint keys and the model
    # parameters are tolerated rather than raised.
    transformer.load_weights(list(transformer_weights.items()), strict=False)
    print(" Transformer loaded")

    # --- VAE decoder (read from the same checkpoint path) ---
    print(" Loading VAE decoder...")
    vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True)
    print(" VAE decoder loaded")

    # --- Optional text encoder ---
    text_encoder = None
    wants_text_encoder = load_text_encoder_weights and text_encoder_path is not None
    if wants_text_encoder:
        print(" Loading text encoder...")
        text_encoder = load_text_encoder(model_path, text_encoder_path)
        print(" Text encoder loaded")

    # --- Optional tokenizer ---
    tokenizer = None
    if tokenizer_path is not None:
        print(" Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        print(" Tokenizer loaded")

    print("Pipeline ready!")
    return LTXVideoPipeline(
        transformer=transformer,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae_decoder=vae_decoder,
    )
def video_to_numpy(video: mx.array) -> np.ndarray:
    """Turn a decoded video tensor into a channels-last uint8 NumPy array.

    Args:
        video: Video tensor of shape (B, C, F, H, W), nominally in [-1, 1];
            values outside that range are clipped.

    Returns:
        Numpy array of shape (B, F, H, W, C) with values in [0, 255].
    """
    clipped = mx.clip(video, -1.0, 1.0)
    # Map [-1, 1] -> [0, 255]; the uint8 cast drops the fractional part.
    pixels = ((clipped + 1.0) / 2.0 * 255.0).astype(mx.uint8)
    # (B, C, F, H, W) -> (B, F, H, W, C)
    channels_last = mx.transpose(pixels, (0, 2, 3, 4, 1))
    return np.array(channels_last)
if __name__ == "__main__":
    # Smoke test: build a deliberately tiny transformer and check that the
    # pipeline object can be constructed. No video is generated.
    from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType

    # Reduced depth and width so construction is quick.
    tiny_model_config = LTXModelConfig(
        model_type=LTXModelType.VideoOnly,
        num_layers=2,
        num_attention_heads=4,
        attention_head_dim=32,
    )
    tiny_model = LTXModel(tiny_model_config)

    # Small generation settings for the (currently disabled) sample call.
    gen_config = GenerationConfig(
        height=256,
        width=256,
        num_frames=9,
        num_inference_steps=4,
    )

    print("Testing generation pipeline...")
    pipeline = LTXVideoPipeline(transformer=tiny_model)

    # Generating for real would require proper text embeddings:
    #   video = pipeline("A cat walking", gen_config, seed=42)
    #   print(f"Generated video shape: {video.shape}")
    print("Pipeline initialized successfully!")