Remove main.py and refactor video generation logic into generate.py.

This commit is contained in:
Prince Canuma
2026-01-12 14:23:02 +01:00
parent 7114b023bd
commit 75511a0b17
3 changed files with 246 additions and 764 deletions

321
main.py
View File

@@ -1,321 +0,0 @@
import argparse
import time
from pathlib import Path
import mlx.core as mx
import numpy as np
from PIL import Image
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.generate import create_position_grid
from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from huggingface_hub import snapshot_download
# Distilled sigma schedules: hard-coded noise levels that denoise() walks as
# consecutive (sigma, sigma_next) pairs.
# Stage 1 (half-resolution generation): 9 values -> 8 denoising steps, 1.0 down to 0.0.
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
# Stage 2 (full-resolution refinement): 4 values -> 3 denoising steps. The first
# value is also used as the noise scale when the upsampled latents are re-noised.
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
def get_model_path(model_repo: str):
    """Return the local snapshot directory for ``model_repo`` as a ``Path``.

    The Hugging Face cache is tried first (``local_files_only=True``). If that
    lookup raises for any reason, the repository's ``*.safetensors`` and
    ``*.json`` files are downloaded and the downloaded snapshot is used.
    """
    try:
        snapshot_dir = snapshot_download(repo_id=model_repo, local_files_only=True)
    except Exception:
        # Cache lookup failed (nothing cached, or any other error): download.
        print("Downloading LTX-2 model weights...")
        snapshot_dir = snapshot_download(
            repo_id=model_repo,
            local_files_only=False,
            resume_download=True,
            allow_patterns=["*.safetensors", "*.json"],
        )
    return Path(snapshot_dir)
def denoise(
    latents: mx.array,
    positions: mx.array,
    text_embeddings: mx.array,
    transformer: LTXModel,
    sigmas: list,
) -> mx.array:
    """Run the denoising loop over a sigma schedule.

    Args:
        latents: Noisy latents, unpacked as ``(b, c, f, h, w)``.
        positions: Position grid, forwarded unchanged to the transformer.
        text_embeddings: Prompt embeddings, passed as the modality context.
        transformer: Model called as ``transformer(video=..., audio=None)``;
            its first output is used as the predicted velocity.
        sigmas: Noise levels. Consecutive entries form the
            ``(sigma, sigma_next)`` pair of each step, so ``len(sigmas) - 1``
            steps are run.

    Returns:
        The latents after the final step.
    """
    for sigma, sigma_next in zip(sigmas, sigmas[1:]):
        b, c, f, h, w = latents.shape

        # (B, C, F, H, W) -> (B, F*H*W, C): one token per latent position.
        tokens = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))

        video_modality = Modality(
            latent=tokens,
            # One timestep per batch element instead of a hard-coded batch of 1;
            # identical to the previous behaviour when b == 1.
            timesteps=mx.full((b,), sigma),
            positions=positions,
            context=text_embeddings,
            context_mask=None,
            enabled=True,
        )

        velocity, _ = transformer(video=video_modality, audio=None)
        mx.eval(velocity)

        # (B, F*H*W, C) -> (B, C, F, H, W)
        velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
        denoised = to_denoised(latents, velocity, sigma)
        mx.eval(denoised)

        if sigma_next > 0:
            # Euler step: move from the current sample towards the denoised
            # estimate, keeping a sigma_next / sigma fraction of the residual.
            latents = denoised + sigma_next * (latents - denoised) / sigma
        else:
            # Final step (sigma_next == 0): the denoised estimate is the result.
            latents = denoised
        mx.eval(latents)

    return latents
def _encode_prompt(model_path, prompt):
    """Encode ``prompt`` with the LTX-2 text encoder, then release the encoder."""
    from mlx_video.models.ltx.text_encoder import LTX2TextEncoder

    text_encoder = LTX2TextEncoder(model_path=str(model_path))
    text_encoder.load(str(model_path))
    mx.eval(text_encoder.parameters())

    text_embeddings, _ = text_encoder(prompt)
    mx.eval(text_embeddings)

    # The encoder is only needed once; drop it before loading the transformer.
    del text_encoder
    mx.clear_cache()
    return text_embeddings


def _load_transformer(model_path):
    """Build the video-only LTX transformer and load the distilled weights.

    The raw and sanitized weight dicts are locals of this helper, so they go
    out of scope on return and no longer keep the weight arrays referenced
    when the caller later deletes the transformer.
    """
    raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
    sanitized = sanitize_transformer_weights(raw_weights)

    config = LTXModelConfig(
        model_type=LTXModelType.VideoOnly,
        num_attention_heads=32,
        attention_head_dim=128,
        in_channels=128,
        out_channels=128,
        num_layers=48,
        cross_attention_dim=4096,
        caption_channels=3840,
        rope_type=LTXRopeType.SPLIT,
        double_precision_rope=True,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        use_middle_indices_grid=True,
        timestep_scale_multiplier=1000,
    )

    transformer = LTXModel(config)
    transformer.load_weights(list(sanitized.items()), strict=False)
    mx.eval(transformer.parameters())
    return transformer


def _to_uint8_frames(video):
    """Convert the decoder output to a uint8 numpy array of frames."""
    video = mx.squeeze(video, axis=0)  # (C, F, H, W)
    video = mx.transpose(video, (1, 2, 3, 0))  # (F, H, W, C)
    # Map [-1, 1] to [0, 1], clamp, then scale to [0, 255].
    video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
    video = (video * 255).astype(mx.uint8)
    return np.array(video)


def _save_outputs(video_np, output_path, fps, save_frames):
    """Write the video file and, if requested, one PNG per frame."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Best effort: a failed video export is reported but does not abort the
    # run, so the frames can still be saved and returned.
    try:
        import imageio
        imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
        print(f"Saved video to {output_path}")
    except Exception as e:
        print(f"Could not save video: {e}")

    if save_frames:
        frames_dir = output_path.parent / f"{output_path.stem}_frames"
        frames_dir.mkdir(exist_ok=True)
        for i, frame in enumerate(video_np):
            Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
        print(f"Saved {len(video_np)} frames to {frames_dir}")


def generate_video(
    model_repo: str,
    prompt: str,
    height: int = 512,
    width: int = 512,
    num_frames: int = 33,
    seed: int = 42,
    fps: int = 24,
    output_path: str = "output.mp4",
    save_frames: bool = False,
):
    """Generate video from text prompt.

    Runs two stages: generation at half resolution, a 2x latent upsample,
    then refinement at full resolution, followed by VAE decoding.

    Args:
        model_repo: Hugging Face repository id holding the model weights
        prompt: Text description of the video to generate
        height: Output video height (must be divisible by 64)
        width: Output video width (must be divisible by 64)
        num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
        seed: Random seed for reproducibility
        fps: Frames per second for output video
        output_path: Path to save the output video
        save_frames: Whether to save individual frames as images

    Returns:
        The generated frames as a uint8 numpy array, laid out (F, H, W, C).

    Raises:
        AssertionError: If ``height`` or ``width`` is not divisible by 64.
    """
    start_time = time.time()

    # Validate dimensions
    assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
    assert width % 64 == 0, f"Width must be divisible by 64, got {width}"

    print(f"Generating {width}x{height} video with {num_frames} frames")
    print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")

    # Get model path
    model_path = get_model_path(model_repo)

    # Calculate latent dimensions: latents are 32x smaller spatially, and
    # stage 1 runs at half the output resolution.
    stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
    stage2_h, stage2_w = height // 32, width // 32
    latent_frames = 1 + (num_frames - 1) // 8

    mx.random.seed(seed)

    # Load text encoder
    print("Loading text encoder...")
    text_embeddings = _encode_prompt(model_path, prompt)

    # Load transformer
    print("Loading transformer...")
    transformer = _load_transformer(model_path)

    # Stage 1: Generate at half resolution
    print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
    # Re-seed so the initial noise depends only on `seed`.
    mx.random.seed(seed)
    latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
    mx.eval(latents)

    positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
    mx.eval(positions)

    latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)

    # Upsample latents
    print("Upsampling latents 2x...")
    upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
    mx.eval(upsampler.parameters())

    # The decoder is loaded here because the upsampler needs its latent statistics.
    vae_decoder = load_vae_decoder(
        str(model_path / 'ltx-2-19b-distilled.safetensors'),
        timestep_conditioning=True
    )

    latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
    mx.eval(latents)

    del upsampler
    mx.clear_cache()

    # Stage 2: Refine at full resolution
    print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
    positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
    mx.eval(positions)

    # Add noise for refinement: blend in fresh noise at the first stage-2 sigma.
    noise_scale = STAGE_2_SIGMAS[0]
    noise = mx.random.normal(latents.shape)
    latents = noise * noise_scale + latents * (1 - noise_scale)
    mx.eval(latents)

    latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)

    # Release the transformer before decoding.
    del transformer
    mx.clear_cache()

    # Decode to video
    print("Decoding video...")
    video = vae_decoder(latents)
    mx.eval(video)
    mx.clear_cache()

    # Convert to uint8 frames
    video_np = _to_uint8_frames(video)

    # Save outputs
    _save_outputs(video_np, output_path, fps, save_frames)

    elapsed = time.time() - start_time
    print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")

    return video_np
def main():
    """Parse command-line arguments and run ``generate_video`` with them."""
    parser = argparse.ArgumentParser(
        description="Generate videos with MLX LTX-2",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python main.py --prompt "A cat walking on grass"
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
"""
    )

    parser.add_argument(
        "--prompt", "-p",
        type=str,
        required=True,
        help="Text description of the video to generate"
    )
    # generate_video() rejects sizes that are not multiples of 64, so the help
    # text states 64 (it previously said 32).
    parser.add_argument(
        "--height", "-H",
        type=int,
        default=512,
        help="Output video height (default: 512, must be divisible by 64)"
    )
    parser.add_argument(
        "--width", "-W",
        type=int,
        default=512,
        help="Output video width (default: 512, must be divisible by 64)"
    )
    parser.add_argument(
        "--num-frames", "-n",
        type=int,
        default=33,
        help="Number of frames (default: 33, must be 1 + 8*k)"
    )
    parser.add_argument(
        "--seed", "-s",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)"
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=24,
        help="Frames per second for output video (default: 24)"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="output.mp4",
        help="Output video path (default: output.mp4)"
    )
    parser.add_argument(
        "--save-frames",
        action="store_true",
        help="Save individual frames as images"
    )
    parser.add_argument(
        "--model-repo",
        type=str,
        default="Lightricks/LTX-2",
        help="Model repository to use (default: Lightricks/LTX-2)"
    )

    args = parser.parse_args()

    generate_video(
        model_repo=args.model_repo,
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        seed=args.seed,
        fps=args.fps,
        output_path=args.output,
        save_frames=args.save_frames,
    )


if __name__ == "__main__":
    main()

View File

@@ -1,13 +1,9 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.generate import LTXVideoPipeline, GenerationConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights from mlx_video.convert import load_transformer_weights, load_vae_weights
__all__ = [ __all__ = [
"LTXModel", "LTXModel",
"LTXModelConfig", "LTXModelConfig",
"LTXVideoPipeline",
"GenerationConfig",
"load_transformer_weights", "load_transformer_weights",
"load_vae_weights", "load_vae_weights",
] ]

View File

@@ -1,114 +1,25 @@
from dataclasses import dataclass import argparse
import time
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Iterator, Union
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
from PIL import Image
from mlx_video.models.ltx.ltx import LTXModel, X0Model from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality from mlx_video.models.ltx.transformer import Modality
from mlx_video.models.ltx.video_vae import VideoEncoder, VideoDecoder from mlx_video.convert import sanitize_transformer_weights
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from huggingface_hub import snapshot_download
@dataclass # Distilled sigma schedules
class GenerationConfig: STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
"""Configuration for video generation.""" STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
# Video dimensions
height: int = 512
width: int = 512
num_frames: int = 33 # Must be 1 + 8*k
# Diffusion parameters
num_inference_steps: int = 8 # For distilled model (ignored if use_distilled=True)
guidance_scale: float = 3.0
use_distilled: bool = True # Use hardcoded sigma values for distilled model
# Latent dimensions (computed from video dimensions)
@property
def latent_height(self) -> int:
return self.height // 32
@property
def latent_width(self) -> int:
return self.width // 32
@property
def latent_frames(self) -> int:
return 1 + (self.num_frames - 1) // 8
# Hardcoded sigma values for distilled model (from LTX-2 pipeline)
# These were tuned to match the distillation process
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
# Scheduler constants for dynamic sigma computation (non-distilled models)
BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096
def get_sigmas(
num_steps: int,
num_tokens: int,
max_shift: float = 2.05,
base_shift: float = 0.95,
stretch: bool = True,
terminal: float = 0.1,
use_distilled: bool = True,
) -> mx.array:
"""Get sigma schedule for diffusion.
Args:
num_steps: Number of diffusion steps
num_tokens: Number of latent tokens (T * H * W)
max_shift: Maximum shift for sigma schedule
base_shift: Base shift for sigma schedule
stretch: Whether to stretch sigmas to terminal value
terminal: Terminal value for stretching
use_distilled: If True, use hardcoded distilled sigma values
Returns:
Array of sigma values
"""
import math
# For distilled model, use hardcoded sigma values
if use_distilled:
return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32)
# For non-distilled models, compute dynamically using LTX2Scheduler logic
# Linear base schedule
sigmas = mx.linspace(1.0, 0.0, num_steps + 1)
# Compute token-dependent sigma shift
x1 = BASE_SHIFT_ANCHOR
x2 = MAX_SHIFT_ANCHOR
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = num_tokens * mm + b
# Apply exponential transformation
# sigmas = exp(sigma_shift) / (exp(sigma_shift) + (1/sigmas - 1)^1)
power = 1
exp_shift = math.exp(sigma_shift)
# Convert to numpy for computation then back to mx
sigmas_np = np.array(sigmas)
result = np.zeros_like(sigmas_np)
non_zero = sigmas_np != 0
result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0) ** power)
# Stretch sigmas so final value matches terminal
if stretch:
non_zero_mask = result != 0
non_zero_sigmas = result[non_zero_mask]
one_minus_z = 1.0 - non_zero_sigmas
scale_factor = one_minus_z[-1] / (1.0 - terminal)
stretched = 1.0 - (one_minus_z / scale_factor)
result[non_zero_mask] = stretched
return mx.array(result, dtype=mx.float32)
def create_position_grid( def create_position_grid(
@@ -141,7 +52,6 @@ def create_position_grid(
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1 patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
# Generate grid coordinates for each dimension (frame, height, width) # Generate grid coordinates for each dimension (frame, height, width)
# These are the starting coordinates for each patch in latent space
t_coords = np.arange(0, num_frames, patch_size_t) t_coords = np.arange(0, num_frames, patch_size_t)
h_coords = np.arange(0, height, patch_size_h) h_coords = np.arange(0, height, patch_size_h)
w_coords = np.arange(0, width, patch_size_w) w_coords = np.arange(0, width, patch_size_w)
@@ -173,7 +83,6 @@ def create_position_grid(
# Apply causal fix for first frame temporal axis # Apply causal fix for first frame temporal axis
if causal_fix: if causal_fix:
# VAE temporal stride for first frame is 1 instead of temporal_scale # VAE temporal stride for first frame is 1 instead of temporal_scale
# Shift and clamp to keep first-frame timestamps non-negative
pixel_coords[:, 0, :, :] = np.clip( pixel_coords[:, 0, :, :] = np.clip(
pixel_coords[:, 0, :, :] + 1 - temporal_scale, pixel_coords[:, 0, :, :] + 1 - temporal_scale,
a_min=0, a_min=0,
@@ -186,307 +95,117 @@ def create_position_grid(
return mx.array(pixel_coords, dtype=mx.float32) return mx.array(pixel_coords, dtype=mx.float32)
class LTXVideoPipeline: def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
def __init__(
self,
transformer: LTXModel,
text_encoder: Optional[LTX2TextEncoder] = None,
tokenizer: Optional[any] = None,
vae_encoder: Optional[VideoEncoder] = None,
vae_decoder: Optional[VideoDecoder] = None,
):
"""Initialize pipeline.
Args: def denoise(
transformer: LTX transformer model latents: mx.array,
text_encoder: Optional LTX text encoder positions: mx.array,
tokenizer: Optional tokenizer for text encoding text_embeddings: mx.array,
vae_encoder: Optional VAE encoder transformer: LTXModel,
vae_decoder: Optional VAE decoder sigmas: list,
""" ) -> mx.array:
self.transformer = transformer """Run denoising loop."""
self.text_encoder = text_encoder for i in range(len(sigmas) - 1):
self.tokenizer = tokenizer sigma, sigma_next = sigmas[i], sigmas[i + 1]
self.vae_encoder = vae_encoder
self.vae_decoder = vae_decoder
self.x0_model = X0Model(transformer)
def prepare_latents(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
dtype: mx.Dtype = mx.float16,
) -> mx.array:
"""Prepare initial noise latents.
Args:
batch_size: Batch size
num_frames: Number of latent frames
height: Latent height
width: Latent width
dtype: Data type
Returns:
Random latent noise
"""
# Use in_channels from transformer config
in_channels = self.transformer.config.in_channels
shape = (batch_size, in_channels, num_frames, height, width)
latents = mx.random.normal(shape).astype(dtype)
return latents
def prepare_text_embeddings(
self,
prompt: Union[str, List[str]],
batch_size: int,
max_length: int = 1024,
) -> Tuple[mx.array, Optional[mx.array]]:
"""Prepare text embeddings.
Args:
prompt: Text prompt or list of prompts
batch_size: Batch size
max_length: Maximum sequence length for tokenization
Returns:
Tuple of (text_embeddings, attention_mask)
"""
# If text encoder is available, use it
if self.text_encoder is not None and self.tokenizer is not None:
# Handle single or multiple prompts
if isinstance(prompt, str):
prompts = [prompt] * batch_size
else:
prompts = prompt
# Tokenize
tokens = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np",
)
input_ids = mx.array(tokens["input_ids"])
attention_mask = mx.array(tokens["attention_mask"])
# Encode
embeddings = self.text_encoder(input_ids, attention_mask)
mx.eval(embeddings)
return embeddings, None # Connector handles masking internally
# Fallback: random embeddings (for testing without text encoder)
print("Warning: No text encoder provided, using random embeddings")
seq_len = max_length + 128 # Account for learnable registers
embed_dim = self.transformer.config.caption_channels
embeddings = mx.random.normal((batch_size, seq_len, embed_dim))
mask = mx.ones((batch_size, seq_len))
return embeddings, mask
def denoise_step(
self,
latents: mx.array,
sigma: float,
sigma_next: float,
text_embeddings: mx.array,
positions: mx.array,
text_mask: Optional[mx.array] = None,
) -> mx.array:
"""Perform one denoising step.
Args:
latents: Current noisy latents
sigma: Current noise level
sigma_next: Next noise level
text_embeddings: Text conditioning
positions: Position grid for RoPE
text_mask: Optional attention mask for text
Returns:
Denoised latents
"""
batch_size = latents.shape[0]
# Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C)
b, c, f, h, w = latents.shape b, c, f, h, w = latents.shape
latents_flat = mx.reshape(latents, (b, c, -1)) latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
# Create timestep tensor
timesteps = mx.full((batch_size,), sigma)
# Create video modality input
video_modality = Modality( video_modality = Modality(
latent=latents_flat, latent=latents_flat,
timesteps=timesteps, timesteps=mx.full((1,), sigma),
positions=positions, positions=positions,
context=text_embeddings, context=text_embeddings,
context_mask=text_mask, context_mask=None,
enabled=True, enabled=True,
) )
# Run denoising velocity, _ = transformer(video=video_modality, audio=None)
denoised_video, _ = self.x0_model(video=video_modality, audio=None) mx.eval(velocity)
# Reshape back: (B, F*H*W, C) -> (B, C, F, H, W) velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised_video = mx.transpose(denoised_video, (0, 2, 1)) denoised = to_denoised(latents, velocity, sigma)
denoised_video = mx.reshape(denoised_video, (b, c, f, h, w)) mx.eval(denoised)
# Euler step
if sigma_next > 0: if sigma_next > 0:
# x_next = x0 + sigma_next * (x - x0) / sigma latents = denoised + sigma_next * (latents - denoised) / sigma
noise = (latents - denoised_video) / sigma
latents = denoised_video + sigma_next * noise
else: else:
latents = denoised_video latents = denoised
mx.eval(latents)
return latents return latents
def __call__(
self,
prompt: str,
config: Optional[GenerationConfig] = None,
seed: Optional[int] = None,
) -> mx.array:
"""Generate video from text prompt.
Args:
prompt: Text prompt
config: Generation configuration
seed: Random seed
Returns:
Generated video tensor of shape (B, C, F, H, W)
"""
if config is None:
config = GenerationConfig()
if seed is not None:
mx.random.seed(seed)
batch_size = 1
# Prepare text embeddings
text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size)
# Prepare initial latents
latents = self.prepare_latents(
batch_size=batch_size,
num_frames=config.latent_frames,
height=config.latent_height,
width=config.latent_width,
)
# Prepare position grid
positions = create_position_grid(
batch_size=batch_size,
num_frames=config.latent_frames,
height=config.latent_height,
width=config.latent_width,
)
# Get sigma schedule
num_tokens = config.latent_frames * config.latent_height * config.latent_width
sigmas = get_sigmas(
config.num_inference_steps,
num_tokens,
use_distilled=config.use_distilled,
)
# Denoising loop
for i in range(len(sigmas) - 1):
sigma = float(sigmas[i])
sigma_next = float(sigmas[i + 1])
latents = self.denoise_step(
latents=latents,
sigma=sigma,
sigma_next=sigma_next,
text_embeddings=text_embeddings,
positions=positions,
text_mask=text_mask,
)
mx.eval(latents)
# Decode latents to video
if self.vae_decoder is not None:
video = self.vae_decoder(latents)
else:
video = latents
return video
def generate_video( def generate_video(
model_repo: str,
prompt: str, prompt: str,
transformer: LTXModel, height: int = 512,
text_encoder: Optional[LTX2TextEncoder] = None, width: int = 512,
tokenizer: Optional[any] = None, num_frames: int = 33,
vae_decoder: Optional[VideoDecoder] = None, seed: int = 42,
config: Optional[GenerationConfig] = None, fps: int = 24,
seed: Optional[int] = None, output_path: str = "output.mp4",
) -> mx.array: save_frames: bool = False,
):
"""Generate video from text prompt. """Generate video from text prompt.
Args: Args:
prompt: Text prompt prompt: Text description of the video to generate
transformer: LTX transformer model height: Output video height (must be divisible by 64)
text_encoder: Optional text encoder width: Output video width (must be divisible by 64)
tokenizer: Optional tokenizer num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
vae_decoder: Optional VAE decoder seed: Random seed for reproducibility
config: Generation configuration fps: Frames per second for output video
seed: Random seed output_path: Path to save the output video
save_frames: Whether to save individual frames as images
Returns:
Generated video tensor
""" """
pipeline = LTXVideoPipeline( start_time = time.time()
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae_decoder=vae_decoder,
)
return pipeline(prompt, config, seed) # Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
print(f"Generating {width}x{height} video with {num_frames} frames")
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
def load_pipeline( # Get model path
model_path: str, model_path = get_model_path(model_repo)
text_encoder_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
load_text_encoder_weights: bool = True,
) -> LTXVideoPipeline:
"""Load complete LTX-2 video generation pipeline.
Args: # Calculate latent dimensions
model_path: Path to LTX-2 model weights (safetensors) stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
text_encoder_path: Path to text encoder weights directory stage2_h, stage2_w = height // 32, width // 32
tokenizer_path: Path to tokenizer directory latent_frames = 1 + (num_frames - 1) // 8
load_text_encoder_weights: Whether to load text encoder weights
Returns: mx.random.seed(seed)
Configured LTXVideoPipeline
"""
from transformers import AutoTokenizer
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType # Load text encoder
from mlx_video.models.ltx.ltx import LTXModel print("Loading text encoder...")
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
from mlx_video.convert import sanitize_transformer_weights text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path))
mx.eval(text_encoder.parameters())
print("Loading LTX-2 pipeline...") text_embeddings, _ = text_encoder(prompt)
mx.eval(text_embeddings)
del text_encoder
mx.clear_cache()
# Load transformer # Load transformer
print(" Loading transformer...") print("Loading transformer...")
raw_weights = mx.load(model_path) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig( config = LTXModelConfig(
@@ -498,89 +217,177 @@ def load_pipeline(
num_layers=48, num_layers=48,
cross_attention_dim=4096, cross_attention_dim=4096,
caption_channels=3840, caption_channels=3840,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
) )
transformer = LTXModel(config) transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False) transformer.load_weights(list(sanitized.items()), strict=False)
print(" Transformer loaded") mx.eval(transformer.parameters())
# Load VAE decoder # Stage 1: Generate at half resolution
print(" Loading VAE decoder...") print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True) mx.random.seed(seed)
print(" VAE decoder loaded") latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents)
# Load text encoder if paths provided positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
text_encoder = None mx.eval(positions)
tokenizer = None
if load_text_encoder_weights and text_encoder_path is not None: latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
print(" Loading text encoder...")
text_encoder = load_text_encoder(model_path, text_encoder_path)
print(" Text encoder loaded")
if tokenizer_path is not None: # Upsample latents
print(" Loading tokenizer...") print("Upsampling latents 2x...")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
print(" Tokenizer loaded") mx.eval(upsampler.parameters())
print("Pipeline ready!") vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'),
return LTXVideoPipeline( timestep_conditioning=True
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae_decoder=vae_decoder,
) )
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(latents)
def video_to_numpy(video: mx.array) -> np.ndarray: del upsampler
"""Convert video tensor to numpy array. mx.clear_cache()
Args: # Stage 2: Refine at full resolution
video: Video tensor of shape (B, C, F, H, W) in range [-1, 1] print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
Returns: # Add noise for refinement
Numpy array of shape (B, F, H, W, C) in range [0, 255] noise_scale = STAGE_2_SIGMAS[0]
""" noise = mx.random.normal(latents.shape)
# Clamp to [-1, 1] latents = noise * noise_scale + latents * (1 - noise_scale)
video = mx.clip(video, -1.0, 1.0) mx.eval(latents)
# Scale to [0, 255] latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
video = ((video + 1.0) / 2.0 * 255.0).astype(mx.uint8)
# Rearrange: (B, C, F, H, W) -> (B, F, H, W, C) del transformer
video = mx.transpose(video, (0, 2, 3, 4, 1)) mx.clear_cache()
return np.array(video) # Decode to video
print("Decoding video...")
video = vae_decoder(latents)
mx.eval(video)
mx.clear_cache()
# Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
try:
import imageio
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
print(f"Saved video to {output_path}")
except Exception as e:
print(f"Could not save video: {e}")
if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"Saved {len(video_np)} frames to {frames_dir}")
elapsed = time.time() - start_time
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
return video_np
def main():
parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python -m mlx_video.generate --prompt "A cat walking on grass"
python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768
python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
"""
)
parser.add_argument(
"--prompt", "-p",
type=str,
required=True,
help="Text description of the video to generate"
)
parser.add_argument(
"--height", "-H",
type=int,
default=512,
help="Output video height (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--width", "-W",
type=int,
default=512,
help="Output video width (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--num-frames", "-n",
type=int,
default=33,
help="Number of frames (default: 33, must be 1 + 8*k)"
)
parser.add_argument(
"--seed", "-s",
type=int,
default=42,
help="Random seed for reproducibility (default: 42)"
)
parser.add_argument(
"--fps",
type=int,
default=24,
help="Frames per second for output video (default: 24)"
)
parser.add_argument(
"--output", "-o",
type=str,
default="output.mp4",
help="Output video path (default: output.mp4)"
)
parser.add_argument(
"--save-frames",
action="store_true",
help="Save individual frames as images"
)
parser.add_argument(
"--model-repo",
type=str,
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
args = parser.parse_args()
generate_video(
model_repo=args.model_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
)
if __name__ == "__main__": if __name__ == "__main__":
# Example usage main()
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
# Create a small test config
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_layers=2, # Reduced for testing
num_attention_heads=4,
attention_head_dim=32,
)
# Create model
model = LTXModel(config)
# Generate video
gen_config = GenerationConfig(
height=256,
width=256,
num_frames=9,
num_inference_steps=4,
)
print("Testing generation pipeline...")
pipeline = LTXVideoPipeline(transformer=model)
# This would require proper text embeddings in practice
# video = pipeline("A cat walking", gen_config, seed=42)
# print(f"Generated video shape: {video.shape}")
print("Pipeline initialized successfully!")