Remove main.py and refactor video generation logic into generate.py.

This commit is contained in:
Prince Canuma
2026-01-12 14:23:02 +01:00
parent 7114b023bd
commit 75511a0b17
3 changed files with 246 additions and 764 deletions

321
main.py
View File

@@ -1,321 +0,0 @@
import argparse
import time
from pathlib import Path
import mlx.core as mx
import numpy as np
from PIL import Image
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.generate import create_position_grid
from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from huggingface_hub import snapshot_download
# Distilled sigma schedules
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
def denoise(
latents: mx.array,
positions: mx.array,
text_embeddings: mx.array,
transformer: LTXModel,
sigmas: list,
) -> mx.array:
"""Run denoising loop."""
for i in range(len(sigmas) - 1):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
video_modality = Modality(
latent=latents_flat,
timesteps=mx.full((1,), sigma),
positions=positions,
context=text_embeddings,
context_mask=None,
enabled=True,
)
velocity, _ = transformer(video=video_modality, audio=None)
mx.eval(velocity)
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
mx.eval(denoised)
if sigma_next > 0:
latents = denoised + sigma_next * (latents - denoised) / sigma
else:
latents = denoised
mx.eval(latents)
return latents
def generate_video(
model_repo: str,
prompt: str,
height: int = 512,
width: int = 512,
num_frames: int = 33,
seed: int = 42,
fps: int = 24,
output_path: str = "output.mp4",
save_frames: bool = False,
):
"""Generate video from text prompt.
Args:
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
seed: Random seed for reproducibility
fps: Frames per second for output video
output_path: Path to save the output video
save_frames: Whether to save individual frames as images
"""
start_time = time.time()
# Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
print(f"Generating {width}x{height} video with {num_frames} frames")
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
# Get model path
model_path = get_model_path(model_repo)
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8
mx.random.seed(seed)
# Load text encoder
print("Loading text encoder...")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path))
mx.eval(text_encoder.parameters())
text_embeddings, _ = text_encoder(prompt)
mx.eval(text_embeddings)
del text_encoder
mx.clear_cache()
# Load transformer
print("Loading transformer...")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_attention_heads=32,
attention_head_dim=128,
in_channels=128,
out_channels=128,
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
# Stage 1: Generate at half resolution
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents)
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
# Upsample latents
print("Upsampling latents 2x...")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'),
timestep_conditioning=True
)
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(latents)
del upsampler
mx.clear_cache()
# Stage 2: Refine at full resolution
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
# Add noise for refinement
noise_scale = STAGE_2_SIGMAS[0]
noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
del transformer
mx.clear_cache()
# Decode to video
print("Decoding video...")
video = vae_decoder(latents)
mx.eval(video)
mx.clear_cache()
# Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
try:
import imageio
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
print(f"Saved video to {output_path}")
except Exception as e:
print(f"Could not save video: {e}")
if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"Saved {len(video_np)} frames to {frames_dir}")
elapsed = time.time() - start_time
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
return video_np
def main():
parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python main.py --prompt "A cat walking on grass"
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
"""
)
parser.add_argument(
"--prompt", "-p",
type=str,
required=True,
help="Text description of the video to generate"
)
parser.add_argument(
"--height", "-H",
type=int,
default=512,
help="Output video height (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--width", "-W",
type=int,
default=512,
help="Output video width (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--num-frames", "-n",
type=int,
default=33,
help="Number of frames (default: 33, must be 1 + 8*k)"
)
parser.add_argument(
"--seed", "-s",
type=int,
default=42,
help="Random seed for reproducibility (default: 42)"
)
parser.add_argument(
"--fps",
type=int,
default=24,
help="Frames per second for output video (default: 24)"
)
parser.add_argument(
"--output", "-o",
type=str,
default="output.mp4",
help="Output video path (default: output.mp4)"
)
parser.add_argument(
"--save-frames",
action="store_true",
help="Save individual frames as images"
)
parser.add_argument(
"--model-repo",
type=str,
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
args = parser.parse_args()
generate_video(
model_repo=args.model_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
)
if __name__ == "__main__":
main()

View File

@@ -1,13 +1,9 @@
from mlx_video.models.ltx import LTXModel, LTXModelConfig
from mlx_video.generate import LTXVideoPipeline, GenerationConfig
from mlx_video.convert import load_transformer_weights, load_vae_weights
__all__ = [
"LTXModel",
"LTXModelConfig",
"LTXVideoPipeline",
"GenerationConfig",
"load_transformer_weights",
"load_vae_weights",
]

View File

@@ -1,114 +1,25 @@
from dataclasses import dataclass
import argparse
import time
from pathlib import Path
from typing import List, Optional, Tuple, Iterator, Union
import mlx.core as mx
import numpy as np
from PIL import Image
from mlx_video.models.ltx.ltx import LTXModel, X0Model
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.transformer import Modality
from mlx_video.models.ltx.video_vae import VideoEncoder, VideoDecoder
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder
from mlx_video.convert import sanitize_transformer_weights
from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from huggingface_hub import snapshot_download
@dataclass
class GenerationConfig:
"""Configuration for video generation."""
# Video dimensions
height: int = 512
width: int = 512
num_frames: int = 33 # Must be 1 + 8*k
# Diffusion parameters
num_inference_steps: int = 8 # For distilled model (ignored if use_distilled=True)
guidance_scale: float = 3.0
use_distilled: bool = True # Use hardcoded sigma values for distilled model
# Latent dimensions (computed from video dimensions)
@property
def latent_height(self) -> int:
return self.height // 32
@property
def latent_width(self) -> int:
return self.width // 32
@property
def latent_frames(self) -> int:
return 1 + (self.num_frames - 1) // 8
# Hardcoded sigma values for distilled model (from LTX-2 pipeline)
# These were tuned to match the distillation process
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
# Scheduler constants for dynamic sigma computation (non-distilled models)
BASE_SHIFT_ANCHOR = 1024
MAX_SHIFT_ANCHOR = 4096
def get_sigmas(
num_steps: int,
num_tokens: int,
max_shift: float = 2.05,
base_shift: float = 0.95,
stretch: bool = True,
terminal: float = 0.1,
use_distilled: bool = True,
) -> mx.array:
"""Get sigma schedule for diffusion.
Args:
num_steps: Number of diffusion steps
num_tokens: Number of latent tokens (T * H * W)
max_shift: Maximum shift for sigma schedule
base_shift: Base shift for sigma schedule
stretch: Whether to stretch sigmas to terminal value
terminal: Terminal value for stretching
use_distilled: If True, use hardcoded distilled sigma values
Returns:
Array of sigma values
"""
import math
# For distilled model, use hardcoded sigma values
if use_distilled:
return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32)
# For non-distilled models, compute dynamically using LTX2Scheduler logic
# Linear base schedule
sigmas = mx.linspace(1.0, 0.0, num_steps + 1)
# Compute token-dependent sigma shift
x1 = BASE_SHIFT_ANCHOR
x2 = MAX_SHIFT_ANCHOR
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = num_tokens * mm + b
# Apply exponential transformation
# sigmas = exp(sigma_shift) / (exp(sigma_shift) + (1/sigmas - 1)^1)
power = 1
exp_shift = math.exp(sigma_shift)
# Convert to numpy for computation then back to mx
sigmas_np = np.array(sigmas)
result = np.zeros_like(sigmas_np)
non_zero = sigmas_np != 0
result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0) ** power)
# Stretch sigmas so final value matches terminal
if stretch:
non_zero_mask = result != 0
non_zero_sigmas = result[non_zero_mask]
one_minus_z = 1.0 - non_zero_sigmas
scale_factor = one_minus_z[-1] / (1.0 - terminal)
stretched = 1.0 - (one_minus_z / scale_factor)
result[non_zero_mask] = stretched
return mx.array(result, dtype=mx.float32)
# Distilled sigma schedules
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
def create_position_grid(
@@ -141,7 +52,6 @@ def create_position_grid(
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
# Generate grid coordinates for each dimension (frame, height, width)
# These are the starting coordinates for each patch in latent space
t_coords = np.arange(0, num_frames, patch_size_t)
h_coords = np.arange(0, height, patch_size_h)
w_coords = np.arange(0, width, patch_size_w)
@@ -173,7 +83,6 @@ def create_position_grid(
# Apply causal fix for first frame temporal axis
if causal_fix:
# VAE temporal stride for first frame is 1 instead of temporal_scale
# Shift and clamp to keep first-frame timestamps non-negative
pixel_coords[:, 0, :, :] = np.clip(
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
a_min=0,
@@ -186,307 +95,117 @@ def create_position_grid(
return mx.array(pixel_coords, dtype=mx.float32)
class LTXVideoPipeline:
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
def __init__(
self,
transformer: LTXModel,
text_encoder: Optional[LTX2TextEncoder] = None,
tokenizer: Optional[any] = None,
vae_encoder: Optional[VideoEncoder] = None,
vae_decoder: Optional[VideoDecoder] = None,
):
"""Initialize pipeline.
Args:
transformer: LTX transformer model
text_encoder: Optional LTX text encoder
tokenizer: Optional tokenizer for text encoding
vae_encoder: Optional VAE encoder
vae_decoder: Optional VAE decoder
"""
self.transformer = transformer
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.vae_encoder = vae_encoder
self.vae_decoder = vae_decoder
self.x0_model = X0Model(transformer)
def prepare_latents(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
dtype: mx.Dtype = mx.float16,
) -> mx.array:
"""Prepare initial noise latents.
Args:
batch_size: Batch size
num_frames: Number of latent frames
height: Latent height
width: Latent width
dtype: Data type
Returns:
Random latent noise
"""
# Use in_channels from transformer config
in_channels = self.transformer.config.in_channels
shape = (batch_size, in_channels, num_frames, height, width)
latents = mx.random.normal(shape).astype(dtype)
return latents
def prepare_text_embeddings(
self,
prompt: Union[str, List[str]],
batch_size: int,
max_length: int = 1024,
) -> Tuple[mx.array, Optional[mx.array]]:
"""Prepare text embeddings.
Args:
prompt: Text prompt or list of prompts
batch_size: Batch size
max_length: Maximum sequence length for tokenization
Returns:
Tuple of (text_embeddings, attention_mask)
"""
# If text encoder is available, use it
if self.text_encoder is not None and self.tokenizer is not None:
# Handle single or multiple prompts
if isinstance(prompt, str):
prompts = [prompt] * batch_size
else:
prompts = prompt
# Tokenize
tokens = self.tokenizer(
prompts,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np",
)
input_ids = mx.array(tokens["input_ids"])
attention_mask = mx.array(tokens["attention_mask"])
# Encode
embeddings = self.text_encoder(input_ids, attention_mask)
mx.eval(embeddings)
return embeddings, None # Connector handles masking internally
# Fallback: random embeddings (for testing without text encoder)
print("Warning: No text encoder provided, using random embeddings")
seq_len = max_length + 128 # Account for learnable registers
embed_dim = self.transformer.config.caption_channels
embeddings = mx.random.normal((batch_size, seq_len, embed_dim))
mask = mx.ones((batch_size, seq_len))
return embeddings, mask
def denoise_step(
self,
def denoise(
latents: mx.array,
sigma: float,
sigma_next: float,
text_embeddings: mx.array,
positions: mx.array,
text_mask: Optional[mx.array] = None,
) -> mx.array:
"""Perform one denoising step.
text_embeddings: mx.array,
transformer: LTXModel,
sigmas: list,
) -> mx.array:
"""Run denoising loop."""
for i in range(len(sigmas) - 1):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
Args:
latents: Current noisy latents
sigma: Current noise level
sigma_next: Next noise level
text_embeddings: Text conditioning
positions: Position grid for RoPE
text_mask: Optional attention mask for text
Returns:
Denoised latents
"""
batch_size = latents.shape[0]
# Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C)
b, c, f, h, w = latents.shape
latents_flat = mx.reshape(latents, (b, c, -1))
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
# Create timestep tensor
timesteps = mx.full((batch_size,), sigma)
# Create video modality input
video_modality = Modality(
latent=latents_flat,
timesteps=timesteps,
timesteps=mx.full((1,), sigma),
positions=positions,
context=text_embeddings,
context_mask=text_mask,
context_mask=None,
enabled=True,
)
# Run denoising
denoised_video, _ = self.x0_model(video=video_modality, audio=None)
velocity, _ = transformer(video=video_modality, audio=None)
mx.eval(velocity)
# Reshape back: (B, F*H*W, C) -> (B, C, F, H, W)
denoised_video = mx.transpose(denoised_video, (0, 2, 1))
denoised_video = mx.reshape(denoised_video, (b, c, f, h, w))
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
mx.eval(denoised)
# Euler step
if sigma_next > 0:
# x_next = x0 + sigma_next * (x - x0) / sigma
noise = (latents - denoised_video) / sigma
latents = denoised_video + sigma_next * noise
latents = denoised + sigma_next * (latents - denoised) / sigma
else:
latents = denoised_video
latents = denoised
mx.eval(latents)
return latents
def __call__(
self,
prompt: str,
config: Optional[GenerationConfig] = None,
seed: Optional[int] = None,
) -> mx.array:
"""Generate video from text prompt.
Args:
prompt: Text prompt
config: Generation configuration
seed: Random seed
Returns:
Generated video tensor of shape (B, C, F, H, W)
"""
if config is None:
config = GenerationConfig()
if seed is not None:
mx.random.seed(seed)
batch_size = 1
# Prepare text embeddings
text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size)
# Prepare initial latents
latents = self.prepare_latents(
batch_size=batch_size,
num_frames=config.latent_frames,
height=config.latent_height,
width=config.latent_width,
)
# Prepare position grid
positions = create_position_grid(
batch_size=batch_size,
num_frames=config.latent_frames,
height=config.latent_height,
width=config.latent_width,
)
# Get sigma schedule
num_tokens = config.latent_frames * config.latent_height * config.latent_width
sigmas = get_sigmas(
config.num_inference_steps,
num_tokens,
use_distilled=config.use_distilled,
)
# Denoising loop
for i in range(len(sigmas) - 1):
sigma = float(sigmas[i])
sigma_next = float(sigmas[i + 1])
latents = self.denoise_step(
latents=latents,
sigma=sigma,
sigma_next=sigma_next,
text_embeddings=text_embeddings,
positions=positions,
text_mask=text_mask,
)
mx.eval(latents)
# Decode latents to video
if self.vae_decoder is not None:
video = self.vae_decoder(latents)
else:
video = latents
return video
def generate_video(
model_repo: str,
prompt: str,
transformer: LTXModel,
text_encoder: Optional[LTX2TextEncoder] = None,
tokenizer: Optional[any] = None,
vae_decoder: Optional[VideoDecoder] = None,
config: Optional[GenerationConfig] = None,
seed: Optional[int] = None,
) -> mx.array:
height: int = 512,
width: int = 512,
num_frames: int = 33,
seed: int = 42,
fps: int = 24,
output_path: str = "output.mp4",
save_frames: bool = False,
):
"""Generate video from text prompt.
Args:
prompt: Text prompt
transformer: LTX transformer model
text_encoder: Optional text encoder
tokenizer: Optional tokenizer
vae_decoder: Optional VAE decoder
config: Generation configuration
seed: Random seed
Returns:
Generated video tensor
prompt: Text description of the video to generate
height: Output video height (must be divisible by 64)
width: Output video width (must be divisible by 64)
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
seed: Random seed for reproducibility
fps: Frames per second for output video
output_path: Path to save the output video
save_frames: Whether to save individual frames as images
"""
pipeline = LTXVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae_decoder=vae_decoder,
)
start_time = time.time()
return pipeline(prompt, config, seed)
# Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
print(f"Generating {width}x{height} video with {num_frames} frames")
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
def load_pipeline(
model_path: str,
text_encoder_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
load_text_encoder_weights: bool = True,
) -> LTXVideoPipeline:
"""Load complete LTX-2 video generation pipeline.
# Get model path
model_path = get_model_path(model_repo)
Args:
model_path: Path to LTX-2 model weights (safetensors)
text_encoder_path: Path to text encoder weights directory
tokenizer_path: Path to tokenizer directory
load_text_encoder_weights: Whether to load text encoder weights
# Calculate latent dimensions
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
stage2_h, stage2_w = height // 32, width // 32
latent_frames = 1 + (num_frames - 1) // 8
Returns:
Configured LTXVideoPipeline
"""
from transformers import AutoTokenizer
mx.random.seed(seed)
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
from mlx_video.models.ltx.ltx import LTXModel
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.convert import sanitize_transformer_weights
# Load text encoder
print("Loading text encoder...")
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
text_encoder = LTX2TextEncoder(model_path=str(model_path))
text_encoder.load(str(model_path))
mx.eval(text_encoder.parameters())
print("Loading LTX-2 pipeline...")
text_embeddings, _ = text_encoder(prompt)
mx.eval(text_embeddings)
del text_encoder
mx.clear_cache()
# Load transformer
print(" Loading transformer...")
raw_weights = mx.load(model_path)
print("Loading transformer...")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
config = LTXModelConfig(
@@ -498,89 +217,177 @@ def load_pipeline(
num_layers=48,
cross_attention_dim=4096,
caption_channels=3840,
rope_type=LTXRopeType.SPLIT,
double_precision_rope=True,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
timestep_scale_multiplier=1000,
)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
print(" Transformer loaded")
mx.eval(transformer.parameters())
# Load VAE decoder
print(" Loading VAE decoder...")
vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True)
print(" VAE decoder loaded")
# Stage 1: Generate at half resolution
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
mx.random.seed(seed)
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
mx.eval(latents)
# Load text encoder if paths provided
text_encoder = None
tokenizer = None
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
if load_text_encoder_weights and text_encoder_path is not None:
print(" Loading text encoder...")
text_encoder = load_text_encoder(model_path, text_encoder_path)
print(" Text encoder loaded")
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
if tokenizer_path is not None:
print(" Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print(" Tokenizer loaded")
# Upsample latents
print("Upsampling latents 2x...")
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
mx.eval(upsampler.parameters())
print("Pipeline ready!")
return LTXVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae_decoder=vae_decoder,
vae_decoder = load_vae_decoder(
str(model_path / 'ltx-2-19b-distilled.safetensors'),
timestep_conditioning=True
)
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
mx.eval(latents)
def video_to_numpy(video: mx.array) -> np.ndarray:
"""Convert video tensor to numpy array.
del upsampler
mx.clear_cache()
Args:
video: Video tensor of shape (B, C, F, H, W) in range [-1, 1]
# Stage 2: Refine at full resolution
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
Returns:
Numpy array of shape (B, F, H, W, C) in range [0, 255]
# Add noise for refinement
noise_scale = STAGE_2_SIGMAS[0]
noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
del transformer
mx.clear_cache()
# Decode to video
print("Decoding video...")
video = vae_decoder(latents)
mx.eval(video)
mx.clear_cache()
# Convert to uint8 frames
video = mx.squeeze(video, axis=0) # (C, F, H, W)
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
video = (video * 255).astype(mx.uint8)
video_np = np.array(video)
# Save outputs
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
try:
import imageio
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
print(f"Saved video to {output_path}")
except Exception as e:
print(f"Could not save video: {e}")
if save_frames:
frames_dir = output_path.parent / f"{output_path.stem}_frames"
frames_dir.mkdir(exist_ok=True)
for i, frame in enumerate(video_np):
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
print(f"Saved {len(video_np)} frames to {frames_dir}")
elapsed = time.time() - start_time
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
return video_np
def main():
parser = argparse.ArgumentParser(
description="Generate videos with MLX LTX-2",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python -m mlx_video.generate --prompt "A cat walking on grass"
python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768
python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
"""
# Clamp to [-1, 1]
video = mx.clip(video, -1.0, 1.0)
)
# Scale to [0, 255]
video = ((video + 1.0) / 2.0 * 255.0).astype(mx.uint8)
parser.add_argument(
"--prompt", "-p",
type=str,
required=True,
help="Text description of the video to generate"
)
parser.add_argument(
"--height", "-H",
type=int,
default=512,
help="Output video height (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--width", "-W",
type=int,
default=512,
help="Output video width (default: 512, must be divisible by 32)"
)
parser.add_argument(
"--num-frames", "-n",
type=int,
default=33,
help="Number of frames (default: 33, must be 1 + 8*k)"
)
parser.add_argument(
"--seed", "-s",
type=int,
default=42,
help="Random seed for reproducibility (default: 42)"
)
parser.add_argument(
"--fps",
type=int,
default=24,
help="Frames per second for output video (default: 24)"
)
parser.add_argument(
"--output", "-o",
type=str,
default="output.mp4",
help="Output video path (default: output.mp4)"
)
parser.add_argument(
"--save-frames",
action="store_true",
help="Save individual frames as images"
)
parser.add_argument(
"--model-repo",
type=str,
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
args = parser.parse_args()
# Rearrange: (B, C, F, H, W) -> (B, F, H, W, C)
video = mx.transpose(video, (0, 2, 3, 4, 1))
return np.array(video)
generate_video(
model_repo=args.model_repo,
prompt=args.prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
seed=args.seed,
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
)
if __name__ == "__main__":
# Example usage
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
# Create a small test config
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
num_layers=2, # Reduced for testing
num_attention_heads=4,
attention_head_dim=32,
)
# Create model
model = LTXModel(config)
# Generate video
gen_config = GenerationConfig(
height=256,
width=256,
num_frames=9,
num_inference_steps=4,
)
print("Testing generation pipeline...")
pipeline = LTXVideoPipeline(transformer=model)
# This would require proper text embeddings in practice
# video = pipeline("A cat walking", gen_config, seed=42)
# print(f"Generated video shape: {video.shape}")
print("Pipeline initialized successfully!")
main()