initial commit (LTX-2)
This commit is contained in:
586
mlx_video/generate.py
Normal file
586
mlx_video/generate.py
Normal file
@@ -0,0 +1,586 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Iterator, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx.transformer import Modality
|
||||
from mlx_video.models.ltx.video_vae import VideoEncoder, VideoDecoder
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationConfig:
|
||||
"""Configuration for video generation."""
|
||||
# Video dimensions
|
||||
height: int = 512
|
||||
width: int = 512
|
||||
num_frames: int = 33 # Must be 1 + 8*k
|
||||
|
||||
# Diffusion parameters
|
||||
num_inference_steps: int = 8 # For distilled model (ignored if use_distilled=True)
|
||||
guidance_scale: float = 3.0
|
||||
use_distilled: bool = True # Use hardcoded sigma values for distilled model
|
||||
|
||||
# Latent dimensions (computed from video dimensions)
|
||||
@property
|
||||
def latent_height(self) -> int:
|
||||
return self.height // 32
|
||||
|
||||
@property
|
||||
def latent_width(self) -> int:
|
||||
return self.width // 32
|
||||
|
||||
@property
|
||||
def latent_frames(self) -> int:
|
||||
return 1 + (self.num_frames - 1) // 8
|
||||
|
||||
|
||||
# Hardcoded sigma values for distilled model (from LTX-2 pipeline)
|
||||
# These were tuned to match the distillation process
|
||||
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
|
||||
|
||||
# Scheduler constants for dynamic sigma computation (non-distilled models)
|
||||
BASE_SHIFT_ANCHOR = 1024
|
||||
MAX_SHIFT_ANCHOR = 4096
|
||||
|
||||
|
||||
def get_sigmas(
|
||||
num_steps: int,
|
||||
num_tokens: int,
|
||||
max_shift: float = 2.05,
|
||||
base_shift: float = 0.95,
|
||||
stretch: bool = True,
|
||||
terminal: float = 0.1,
|
||||
use_distilled: bool = True,
|
||||
) -> mx.array:
|
||||
"""Get sigma schedule for diffusion.
|
||||
|
||||
Args:
|
||||
num_steps: Number of diffusion steps
|
||||
num_tokens: Number of latent tokens (T * H * W)
|
||||
max_shift: Maximum shift for sigma schedule
|
||||
base_shift: Base shift for sigma schedule
|
||||
stretch: Whether to stretch sigmas to terminal value
|
||||
terminal: Terminal value for stretching
|
||||
use_distilled: If True, use hardcoded distilled sigma values
|
||||
|
||||
Returns:
|
||||
Array of sigma values
|
||||
"""
|
||||
import math
|
||||
|
||||
# For distilled model, use hardcoded sigma values
|
||||
if use_distilled:
|
||||
return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32)
|
||||
|
||||
# For non-distilled models, compute dynamically using LTX2Scheduler logic
|
||||
# Linear base schedule
|
||||
sigmas = mx.linspace(1.0, 0.0, num_steps + 1)
|
||||
|
||||
# Compute token-dependent sigma shift
|
||||
x1 = BASE_SHIFT_ANCHOR
|
||||
x2 = MAX_SHIFT_ANCHOR
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
sigma_shift = num_tokens * mm + b
|
||||
|
||||
# Apply exponential transformation
|
||||
# sigmas = exp(sigma_shift) / (exp(sigma_shift) + (1/sigmas - 1)^1)
|
||||
power = 1
|
||||
exp_shift = math.exp(sigma_shift)
|
||||
|
||||
# Convert to numpy for computation then back to mx
|
||||
sigmas_np = np.array(sigmas)
|
||||
result = np.zeros_like(sigmas_np)
|
||||
non_zero = sigmas_np != 0
|
||||
result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0) ** power)
|
||||
|
||||
# Stretch sigmas so final value matches terminal
|
||||
if stretch:
|
||||
non_zero_mask = result != 0
|
||||
non_zero_sigmas = result[non_zero_mask]
|
||||
one_minus_z = 1.0 - non_zero_sigmas
|
||||
scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
||||
stretched = 1.0 - (one_minus_z / scale_factor)
|
||||
result[non_zero_mask] = stretched
|
||||
|
||||
return mx.array(result, dtype=mx.float32)
|
||||
|
||||
|
||||
def create_position_grid(
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
temporal_scale: int = 8,
|
||||
spatial_scale: int = 32,
|
||||
fps: float = 24.0,
|
||||
causal_fix: bool = True,
|
||||
) -> mx.array:
|
||||
"""Create position grid for RoPE in pixel space.
|
||||
|
||||
Args:
|
||||
batch_size: Batch size
|
||||
num_frames: Number of frames (latent)
|
||||
height: Height (latent)
|
||||
width: Width (latent)
|
||||
temporal_scale: VAE temporal scale factor (default 8)
|
||||
spatial_scale: VAE spatial scale factor (default 32)
|
||||
fps: Frames per second (default 24.0)
|
||||
causal_fix: Apply causal fix for first frame (default True)
|
||||
|
||||
Returns:
|
||||
Position grid of shape (B, 3, num_patches, 2) in pixel space
|
||||
where dim 2 is [start, end) bounds for each patch
|
||||
"""
|
||||
# Patch size is (1, 1, 1) for LTX-2 - no spatial patching
|
||||
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
|
||||
|
||||
# Generate grid coordinates for each dimension (frame, height, width)
|
||||
# These are the starting coordinates for each patch in latent space
|
||||
t_coords = np.arange(0, num_frames, patch_size_t)
|
||||
h_coords = np.arange(0, height, patch_size_h)
|
||||
w_coords = np.arange(0, width, patch_size_w)
|
||||
|
||||
# Create meshgrid with indexing='ij' for (frame, height, width) order
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||
|
||||
# Stack to get shape (3, grid_t, grid_h, grid_w)
|
||||
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||
|
||||
# Calculate end coordinates (start + patch_size)
|
||||
patch_size_delta = np.array([patch_size_t, patch_size_h, patch_size_w]).reshape(3, 1, 1, 1)
|
||||
patch_ends = patch_starts + patch_size_delta
|
||||
|
||||
# Stack start and end: shape (3, grid_t, grid_h, grid_w, 2)
|
||||
latent_coords = np.stack([patch_starts, patch_ends], axis=-1)
|
||||
|
||||
# Flatten spatial/temporal dims: (3, num_patches, 2)
|
||||
num_patches = num_frames * height * width
|
||||
latent_coords = latent_coords.reshape(3, num_patches, 2)
|
||||
|
||||
# Broadcast to batch: (batch, 3, num_patches, 2)
|
||||
latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))
|
||||
|
||||
# Convert latent coords to pixel coords by scaling with VAE factors
|
||||
scale_factors = np.array([temporal_scale, spatial_scale, spatial_scale]).reshape(1, 3, 1, 1)
|
||||
pixel_coords = (latent_coords * scale_factors).astype(np.float32)
|
||||
|
||||
# Apply causal fix for first frame temporal axis
|
||||
if causal_fix:
|
||||
# VAE temporal stride for first frame is 1 instead of temporal_scale
|
||||
# Shift and clamp to keep first-frame timestamps non-negative
|
||||
pixel_coords[:, 0, :, :] = np.clip(
|
||||
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
|
||||
a_min=0,
|
||||
a_max=None
|
||||
)
|
||||
|
||||
# Convert temporal to time in seconds by dividing by fps
|
||||
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
||||
|
||||
return mx.array(pixel_coords, dtype=mx.float32)
|
||||
|
||||
|
||||
class LTXVideoPipeline:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: LTXModel,
|
||||
text_encoder: Optional[LTX2TextEncoder] = None,
|
||||
tokenizer: Optional[any] = None,
|
||||
vae_encoder: Optional[VideoEncoder] = None,
|
||||
vae_decoder: Optional[VideoDecoder] = None,
|
||||
):
|
||||
"""Initialize pipeline.
|
||||
|
||||
Args:
|
||||
transformer: LTX transformer model
|
||||
text_encoder: Optional LTX text encoder
|
||||
tokenizer: Optional tokenizer for text encoding
|
||||
vae_encoder: Optional VAE encoder
|
||||
vae_decoder: Optional VAE decoder
|
||||
"""
|
||||
self.transformer = transformer
|
||||
self.text_encoder = text_encoder
|
||||
self.tokenizer = tokenizer
|
||||
self.vae_encoder = vae_encoder
|
||||
self.vae_decoder = vae_decoder
|
||||
self.x0_model = X0Model(transformer)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: mx.Dtype = mx.float16,
|
||||
) -> mx.array:
|
||||
"""Prepare initial noise latents.
|
||||
|
||||
Args:
|
||||
batch_size: Batch size
|
||||
num_frames: Number of latent frames
|
||||
height: Latent height
|
||||
width: Latent width
|
||||
dtype: Data type
|
||||
|
||||
Returns:
|
||||
Random latent noise
|
||||
"""
|
||||
# Use in_channels from transformer config
|
||||
in_channels = self.transformer.config.in_channels
|
||||
shape = (batch_size, in_channels, num_frames, height, width)
|
||||
latents = mx.random.normal(shape).astype(dtype)
|
||||
return latents
|
||||
|
||||
def prepare_text_embeddings(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
batch_size: int,
|
||||
max_length: int = 1024,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
"""Prepare text embeddings.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt or list of prompts
|
||||
batch_size: Batch size
|
||||
max_length: Maximum sequence length for tokenization
|
||||
|
||||
Returns:
|
||||
Tuple of (text_embeddings, attention_mask)
|
||||
"""
|
||||
# If text encoder is available, use it
|
||||
if self.text_encoder is not None and self.tokenizer is not None:
|
||||
# Handle single or multiple prompts
|
||||
if isinstance(prompt, str):
|
||||
prompts = [prompt] * batch_size
|
||||
else:
|
||||
prompts = prompt
|
||||
|
||||
# Tokenize
|
||||
tokens = self.tokenizer(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
input_ids = mx.array(tokens["input_ids"])
|
||||
attention_mask = mx.array(tokens["attention_mask"])
|
||||
|
||||
# Encode
|
||||
embeddings = self.text_encoder(input_ids, attention_mask)
|
||||
mx.eval(embeddings)
|
||||
|
||||
return embeddings, None # Connector handles masking internally
|
||||
|
||||
# Fallback: random embeddings (for testing without text encoder)
|
||||
print("Warning: No text encoder provided, using random embeddings")
|
||||
seq_len = max_length + 128 # Account for learnable registers
|
||||
embed_dim = self.transformer.config.caption_channels
|
||||
|
||||
embeddings = mx.random.normal((batch_size, seq_len, embed_dim))
|
||||
mask = mx.ones((batch_size, seq_len))
|
||||
|
||||
return embeddings, mask
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
latents: mx.array,
|
||||
sigma: float,
|
||||
sigma_next: float,
|
||||
text_embeddings: mx.array,
|
||||
positions: mx.array,
|
||||
text_mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
"""Perform one denoising step.
|
||||
|
||||
Args:
|
||||
latents: Current noisy latents
|
||||
sigma: Current noise level
|
||||
sigma_next: Next noise level
|
||||
text_embeddings: Text conditioning
|
||||
positions: Position grid for RoPE
|
||||
text_mask: Optional attention mask for text
|
||||
|
||||
Returns:
|
||||
Denoised latents
|
||||
"""
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C)
|
||||
b, c, f, h, w = latents.shape
|
||||
latents_flat = mx.reshape(latents, (b, c, -1))
|
||||
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
|
||||
|
||||
# Create timestep tensor
|
||||
timesteps = mx.full((batch_size,), sigma)
|
||||
|
||||
# Create video modality input
|
||||
video_modality = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
positions=positions,
|
||||
context=text_embeddings,
|
||||
context_mask=text_mask,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Run denoising
|
||||
denoised_video, _ = self.x0_model(video=video_modality, audio=None)
|
||||
|
||||
# Reshape back: (B, F*H*W, C) -> (B, C, F, H, W)
|
||||
denoised_video = mx.transpose(denoised_video, (0, 2, 1))
|
||||
denoised_video = mx.reshape(denoised_video, (b, c, f, h, w))
|
||||
|
||||
# Euler step
|
||||
if sigma_next > 0:
|
||||
# x_next = x0 + sigma_next * (x - x0) / sigma
|
||||
noise = (latents - denoised_video) / sigma
|
||||
latents = denoised_video + sigma_next * noise
|
||||
else:
|
||||
latents = denoised_video
|
||||
|
||||
return latents
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Optional[GenerationConfig] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
"""Generate video from text prompt.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt
|
||||
config: Generation configuration
|
||||
seed: Random seed
|
||||
|
||||
Returns:
|
||||
Generated video tensor of shape (B, C, F, H, W)
|
||||
"""
|
||||
if config is None:
|
||||
config = GenerationConfig()
|
||||
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
batch_size = 1
|
||||
|
||||
# Prepare text embeddings
|
||||
text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size)
|
||||
|
||||
# Prepare initial latents
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_frames=config.latent_frames,
|
||||
height=config.latent_height,
|
||||
width=config.latent_width,
|
||||
)
|
||||
|
||||
# Prepare position grid
|
||||
positions = create_position_grid(
|
||||
batch_size=batch_size,
|
||||
num_frames=config.latent_frames,
|
||||
height=config.latent_height,
|
||||
width=config.latent_width,
|
||||
)
|
||||
|
||||
# Get sigma schedule
|
||||
num_tokens = config.latent_frames * config.latent_height * config.latent_width
|
||||
sigmas = get_sigmas(
|
||||
config.num_inference_steps,
|
||||
num_tokens,
|
||||
use_distilled=config.use_distilled,
|
||||
)
|
||||
|
||||
# Denoising loop
|
||||
for i in range(len(sigmas) - 1):
|
||||
sigma = float(sigmas[i])
|
||||
sigma_next = float(sigmas[i + 1])
|
||||
|
||||
latents = self.denoise_step(
|
||||
latents=latents,
|
||||
sigma=sigma,
|
||||
sigma_next=sigma_next,
|
||||
text_embeddings=text_embeddings,
|
||||
positions=positions,
|
||||
text_mask=text_mask,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
# Decode latents to video
|
||||
if self.vae_decoder is not None:
|
||||
video = self.vae_decoder(latents)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
return video
|
||||
|
||||
|
||||
def generate_video(
|
||||
prompt: str,
|
||||
transformer: LTXModel,
|
||||
text_encoder: Optional[LTX2TextEncoder] = None,
|
||||
tokenizer: Optional[any] = None,
|
||||
vae_decoder: Optional[VideoDecoder] = None,
|
||||
config: Optional[GenerationConfig] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
"""Generate video from text prompt.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt
|
||||
transformer: LTX transformer model
|
||||
text_encoder: Optional text encoder
|
||||
tokenizer: Optional tokenizer
|
||||
vae_decoder: Optional VAE decoder
|
||||
config: Generation configuration
|
||||
seed: Random seed
|
||||
|
||||
Returns:
|
||||
Generated video tensor
|
||||
"""
|
||||
pipeline = LTXVideoPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae_decoder=vae_decoder,
|
||||
)
|
||||
|
||||
return pipeline(prompt, config, seed)
|
||||
|
||||
|
||||
def load_pipeline(
|
||||
model_path: str,
|
||||
text_encoder_path: Optional[str] = None,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
load_text_encoder_weights: bool = True,
|
||||
) -> LTXVideoPipeline:
|
||||
"""Load complete LTX-2 video generation pipeline.
|
||||
|
||||
Args:
|
||||
model_path: Path to LTX-2 model weights (safetensors)
|
||||
text_encoder_path: Path to text encoder weights directory
|
||||
tokenizer_path: Path to tokenizer directory
|
||||
load_text_encoder_weights: Whether to load text encoder weights
|
||||
|
||||
Returns:
|
||||
Configured LTXVideoPipeline
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.convert import sanitize_transformer_weights
|
||||
|
||||
print("Loading LTX-2 pipeline...")
|
||||
|
||||
# Load transformer
|
||||
print(" Loading transformer...")
|
||||
raw_weights = mx.load(model_path)
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.VideoOnly,
|
||||
num_attention_heads=32,
|
||||
attention_head_dim=128,
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
)
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
print(" Transformer loaded")
|
||||
|
||||
# Load VAE decoder
|
||||
print(" Loading VAE decoder...")
|
||||
vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True)
|
||||
print(" VAE decoder loaded")
|
||||
|
||||
# Load text encoder if paths provided
|
||||
text_encoder = None
|
||||
tokenizer = None
|
||||
|
||||
if load_text_encoder_weights and text_encoder_path is not None:
|
||||
print(" Loading text encoder...")
|
||||
text_encoder = load_text_encoder(model_path, text_encoder_path)
|
||||
print(" Text encoder loaded")
|
||||
|
||||
if tokenizer_path is not None:
|
||||
print(" Loading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
print(" Tokenizer loaded")
|
||||
|
||||
print("Pipeline ready!")
|
||||
|
||||
return LTXVideoPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae_decoder=vae_decoder,
|
||||
)
|
||||
|
||||
|
||||
def video_to_numpy(video: mx.array) -> np.ndarray:
|
||||
"""Convert video tensor to numpy array.
|
||||
|
||||
Args:
|
||||
video: Video tensor of shape (B, C, F, H, W) in range [-1, 1]
|
||||
|
||||
Returns:
|
||||
Numpy array of shape (B, F, H, W, C) in range [0, 255]
|
||||
"""
|
||||
# Clamp to [-1, 1]
|
||||
video = mx.clip(video, -1.0, 1.0)
|
||||
|
||||
# Scale to [0, 255]
|
||||
video = ((video + 1.0) / 2.0 * 255.0).astype(mx.uint8)
|
||||
|
||||
# Rearrange: (B, C, F, H, W) -> (B, F, H, W, C)
|
||||
video = mx.transpose(video, (0, 2, 3, 4, 1))
|
||||
|
||||
return np.array(video)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||
|
||||
# Create a small test config
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.VideoOnly,
|
||||
num_layers=2, # Reduced for testing
|
||||
num_attention_heads=4,
|
||||
attention_head_dim=32,
|
||||
)
|
||||
|
||||
# Create model
|
||||
model = LTXModel(config)
|
||||
|
||||
# Generate video
|
||||
gen_config = GenerationConfig(
|
||||
height=256,
|
||||
width=256,
|
||||
num_frames=9,
|
||||
num_inference_steps=4,
|
||||
)
|
||||
|
||||
print("Testing generation pipeline...")
|
||||
pipeline = LTXVideoPipeline(transformer=model)
|
||||
|
||||
# This would require proper text embeddings in practice
|
||||
# video = pipeline("A cat walking", gen_config, seed=42)
|
||||
# print(f"Generated video shape: {video.shape}")
|
||||
|
||||
print("Pipeline initialized successfully!")
|
||||
Reference in New Issue
Block a user