Remove main.py and refactor video generation logic into generate.py.
This commit is contained in:
321
main.py
321
main.py
@@ -1,321 +0,0 @@
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.transformer import Modality
|
||||
from mlx_video.convert import sanitize_transformer_weights
|
||||
from mlx_video.generate import create_position_grid
|
||||
from mlx_video.utils import to_denoised
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
# Distilled sigma schedules
|
||||
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
|
||||
|
||||
def denoise(
|
||||
latents: mx.array,
|
||||
positions: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigmas: list,
|
||||
) -> mx.array:
|
||||
"""Run denoising loop."""
|
||||
for i in range(len(sigmas) - 1):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
|
||||
b, c, f, h, w = latents.shape
|
||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||
|
||||
video_modality = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=mx.full((1,), sigma),
|
||||
positions=positions,
|
||||
context=text_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
velocity, _ = transformer(video=video_modality, audio=None)
|
||||
mx.eval(velocity)
|
||||
|
||||
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||
denoised = to_denoised(latents, velocity, sigma)
|
||||
mx.eval(denoised)
|
||||
|
||||
if sigma_next > 0:
|
||||
latents = denoised + sigma_next * (latents - denoised) / sigma
|
||||
else:
|
||||
latents = denoised
|
||||
mx.eval(latents)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
def generate_video(
|
||||
model_repo: str,
|
||||
prompt: str,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_frames: int = 33,
|
||||
seed: int = 42,
|
||||
fps: int = 24,
|
||||
output_path: str = "output.mp4",
|
||||
save_frames: bool = False,
|
||||
):
|
||||
"""Generate video from text prompt.
|
||||
|
||||
Args:
|
||||
prompt: Text description of the video to generate
|
||||
height: Output video height (must be divisible by 64)
|
||||
width: Output video width (must be divisible by 64)
|
||||
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
|
||||
seed: Random seed for reproducibility
|
||||
fps: Frames per second for output video
|
||||
output_path: Path to save the output video
|
||||
save_frames: Whether to save individual frames as images
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Validate dimensions
|
||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||
|
||||
print(f"Generating {width}x{height} video with {num_frames} frames")
|
||||
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
|
||||
|
||||
# Get model path
|
||||
model_path = get_model_path(model_repo)
|
||||
|
||||
# Calculate latent dimensions
|
||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||
stage2_h, stage2_w = height // 32, width // 32
|
||||
latent_frames = 1 + (num_frames - 1) // 8
|
||||
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Load text encoder
|
||||
print("Loading text encoder...")
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
||||
text_encoder.load(str(model_path))
|
||||
mx.eval(text_encoder.parameters())
|
||||
|
||||
text_embeddings, _ = text_encoder(prompt)
|
||||
mx.eval(text_embeddings)
|
||||
|
||||
del text_encoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Load transformer
|
||||
print("Loading transformer...")
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.VideoOnly,
|
||||
num_attention_heads=32,
|
||||
attention_head_dim=128,
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision_rope=True,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
mx.eval(transformer.parameters())
|
||||
|
||||
# Stage 1: Generate at half resolution
|
||||
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
|
||||
mx.random.seed(seed)
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||
mx.eval(latents)
|
||||
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
||||
|
||||
# Upsample latents
|
||||
print("Upsampling latents 2x...")
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
vae_decoder = load_vae_decoder(
|
||||
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||
timestep_conditioning=True
|
||||
)
|
||||
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||
mx.eval(latents)
|
||||
|
||||
del upsampler
|
||||
mx.clear_cache()
|
||||
|
||||
# Stage 2: Refine at full resolution
|
||||
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
# Add noise for refinement
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
noise = mx.random.normal(latents.shape)
|
||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||
mx.eval(latents)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode to video
|
||||
print("Decoding video...")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
mx.clear_cache()
|
||||
|
||||
# Convert to uint8 frames
|
||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
import imageio
|
||||
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
|
||||
print(f"Saved video to {output_path}")
|
||||
except Exception as e:
|
||||
print(f"Could not save video: {e}")
|
||||
|
||||
if save_frames:
|
||||
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
||||
frames_dir.mkdir(exist_ok=True)
|
||||
for i, frame in enumerate(video_np):
|
||||
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
|
||||
print(f"Saved {len(video_np)} frames to {frames_dir}")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
|
||||
|
||||
return video_np
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate videos with MLX LTX-2",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python main.py --prompt "A cat walking on grass"
|
||||
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
|
||||
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt", "-p",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Text description of the video to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height", "-H",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output video height (default: 512, must be divisible by 32)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width", "-W",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output video width (default: 512, must be divisible by 32)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-frames", "-n",
|
||||
type=int,
|
||||
default=33,
|
||||
help="Number of frames (default: 33, must be 1 + 8*k)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", "-s",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Random seed for reproducibility (default: 42)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
default=24,
|
||||
help="Frames per second for output video (default: 24)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o",
|
||||
type=str,
|
||||
default="output.mp4",
|
||||
help="Output video path (default: output.mp4)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
action="store_true",
|
||||
help="Save individual frames as images"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-repo",
|
||||
type=str,
|
||||
default="Lightricks/LTX-2",
|
||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video(
|
||||
model_repo=args.model_repo,
|
||||
prompt=args.prompt,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
num_frames=args.num_frames,
|
||||
seed=args.seed,
|
||||
fps=args.fps,
|
||||
output_path=args.output,
|
||||
save_frames=args.save_frames,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,13 +1,9 @@
|
||||
|
||||
from mlx_video.models.ltx import LTXModel, LTXModelConfig
|
||||
from mlx_video.generate import LTXVideoPipeline, GenerationConfig
|
||||
from mlx_video.convert import load_transformer_weights, load_vae_weights
|
||||
|
||||
__all__ = [
|
||||
"LTXModel",
|
||||
"LTXModelConfig",
|
||||
"LTXVideoPipeline",
|
||||
"GenerationConfig",
|
||||
"load_transformer_weights",
|
||||
"load_vae_weights",
|
||||
]
|
||||
@@ -1,114 +1,25 @@
|
||||
from dataclasses import dataclass
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Iterator, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from mlx_video.models.ltx.ltx import LTXModel, X0Model
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.transformer import Modality
|
||||
from mlx_video.models.ltx.video_vae import VideoEncoder, VideoDecoder
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder, load_text_encoder
|
||||
from mlx_video.convert import sanitize_transformer_weights
|
||||
from mlx_video.utils import to_denoised
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationConfig:
|
||||
"""Configuration for video generation."""
|
||||
# Video dimensions
|
||||
height: int = 512
|
||||
width: int = 512
|
||||
num_frames: int = 33 # Must be 1 + 8*k
|
||||
|
||||
# Diffusion parameters
|
||||
num_inference_steps: int = 8 # For distilled model (ignored if use_distilled=True)
|
||||
guidance_scale: float = 3.0
|
||||
use_distilled: bool = True # Use hardcoded sigma values for distilled model
|
||||
|
||||
# Latent dimensions (computed from video dimensions)
|
||||
@property
|
||||
def latent_height(self) -> int:
|
||||
return self.height // 32
|
||||
|
||||
@property
|
||||
def latent_width(self) -> int:
|
||||
return self.width // 32
|
||||
|
||||
@property
|
||||
def latent_frames(self) -> int:
|
||||
return 1 + (self.num_frames - 1) // 8
|
||||
|
||||
|
||||
# Hardcoded sigma values for distilled model (from LTX-2 pipeline)
|
||||
# These were tuned to match the distillation process
|
||||
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
|
||||
|
||||
# Scheduler constants for dynamic sigma computation (non-distilled models)
|
||||
BASE_SHIFT_ANCHOR = 1024
|
||||
MAX_SHIFT_ANCHOR = 4096
|
||||
|
||||
|
||||
def get_sigmas(
|
||||
num_steps: int,
|
||||
num_tokens: int,
|
||||
max_shift: float = 2.05,
|
||||
base_shift: float = 0.95,
|
||||
stretch: bool = True,
|
||||
terminal: float = 0.1,
|
||||
use_distilled: bool = True,
|
||||
) -> mx.array:
|
||||
"""Get sigma schedule for diffusion.
|
||||
|
||||
Args:
|
||||
num_steps: Number of diffusion steps
|
||||
num_tokens: Number of latent tokens (T * H * W)
|
||||
max_shift: Maximum shift for sigma schedule
|
||||
base_shift: Base shift for sigma schedule
|
||||
stretch: Whether to stretch sigmas to terminal value
|
||||
terminal: Terminal value for stretching
|
||||
use_distilled: If True, use hardcoded distilled sigma values
|
||||
|
||||
Returns:
|
||||
Array of sigma values
|
||||
"""
|
||||
import math
|
||||
|
||||
# For distilled model, use hardcoded sigma values
|
||||
if use_distilled:
|
||||
return mx.array(DISTILLED_SIGMA_VALUES, dtype=mx.float32)
|
||||
|
||||
# For non-distilled models, compute dynamically using LTX2Scheduler logic
|
||||
# Linear base schedule
|
||||
sigmas = mx.linspace(1.0, 0.0, num_steps + 1)
|
||||
|
||||
# Compute token-dependent sigma shift
|
||||
x1 = BASE_SHIFT_ANCHOR
|
||||
x2 = MAX_SHIFT_ANCHOR
|
||||
mm = (max_shift - base_shift) / (x2 - x1)
|
||||
b = base_shift - mm * x1
|
||||
sigma_shift = num_tokens * mm + b
|
||||
|
||||
# Apply exponential transformation
|
||||
# sigmas = exp(sigma_shift) / (exp(sigma_shift) + (1/sigmas - 1)^1)
|
||||
power = 1
|
||||
exp_shift = math.exp(sigma_shift)
|
||||
|
||||
# Convert to numpy for computation then back to mx
|
||||
sigmas_np = np.array(sigmas)
|
||||
result = np.zeros_like(sigmas_np)
|
||||
non_zero = sigmas_np != 0
|
||||
result[non_zero] = exp_shift / (exp_shift + (1.0 / sigmas_np[non_zero] - 1.0) ** power)
|
||||
|
||||
# Stretch sigmas so final value matches terminal
|
||||
if stretch:
|
||||
non_zero_mask = result != 0
|
||||
non_zero_sigmas = result[non_zero_mask]
|
||||
one_minus_z = 1.0 - non_zero_sigmas
|
||||
scale_factor = one_minus_z[-1] / (1.0 - terminal)
|
||||
stretched = 1.0 - (one_minus_z / scale_factor)
|
||||
result[non_zero_mask] = stretched
|
||||
|
||||
return mx.array(result, dtype=mx.float32)
|
||||
# Distilled sigma schedules
|
||||
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||
|
||||
|
||||
def create_position_grid(
|
||||
@@ -141,7 +52,6 @@ def create_position_grid(
|
||||
patch_size_t, patch_size_h, patch_size_w = 1, 1, 1
|
||||
|
||||
# Generate grid coordinates for each dimension (frame, height, width)
|
||||
# These are the starting coordinates for each patch in latent space
|
||||
t_coords = np.arange(0, num_frames, patch_size_t)
|
||||
h_coords = np.arange(0, height, patch_size_h)
|
||||
w_coords = np.arange(0, width, patch_size_w)
|
||||
@@ -173,7 +83,6 @@ def create_position_grid(
|
||||
# Apply causal fix for first frame temporal axis
|
||||
if causal_fix:
|
||||
# VAE temporal stride for first frame is 1 instead of temporal_scale
|
||||
# Shift and clamp to keep first-frame timestamps non-negative
|
||||
pixel_coords[:, 0, :, :] = np.clip(
|
||||
pixel_coords[:, 0, :, :] + 1 - temporal_scale,
|
||||
a_min=0,
|
||||
@@ -186,307 +95,117 @@ def create_position_grid(
|
||||
return mx.array(pixel_coords, dtype=mx.float32)
|
||||
|
||||
|
||||
class LTXVideoPipeline:
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: LTXModel,
|
||||
text_encoder: Optional[LTX2TextEncoder] = None,
|
||||
tokenizer: Optional[any] = None,
|
||||
vae_encoder: Optional[VideoEncoder] = None,
|
||||
vae_decoder: Optional[VideoDecoder] = None,
|
||||
):
|
||||
"""Initialize pipeline.
|
||||
|
||||
Args:
|
||||
transformer: LTX transformer model
|
||||
text_encoder: Optional LTX text encoder
|
||||
tokenizer: Optional tokenizer for text encoding
|
||||
vae_encoder: Optional VAE encoder
|
||||
vae_decoder: Optional VAE decoder
|
||||
"""
|
||||
self.transformer = transformer
|
||||
self.text_encoder = text_encoder
|
||||
self.tokenizer = tokenizer
|
||||
self.vae_encoder = vae_encoder
|
||||
self.vae_decoder = vae_decoder
|
||||
self.x0_model = X0Model(transformer)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: mx.Dtype = mx.float16,
|
||||
) -> mx.array:
|
||||
"""Prepare initial noise latents.
|
||||
|
||||
Args:
|
||||
batch_size: Batch size
|
||||
num_frames: Number of latent frames
|
||||
height: Latent height
|
||||
width: Latent width
|
||||
dtype: Data type
|
||||
|
||||
Returns:
|
||||
Random latent noise
|
||||
"""
|
||||
# Use in_channels from transformer config
|
||||
in_channels = self.transformer.config.in_channels
|
||||
shape = (batch_size, in_channels, num_frames, height, width)
|
||||
latents = mx.random.normal(shape).astype(dtype)
|
||||
return latents
|
||||
|
||||
def prepare_text_embeddings(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
batch_size: int,
|
||||
max_length: int = 1024,
|
||||
) -> Tuple[mx.array, Optional[mx.array]]:
|
||||
"""Prepare text embeddings.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt or list of prompts
|
||||
batch_size: Batch size
|
||||
max_length: Maximum sequence length for tokenization
|
||||
|
||||
Returns:
|
||||
Tuple of (text_embeddings, attention_mask)
|
||||
"""
|
||||
# If text encoder is available, use it
|
||||
if self.text_encoder is not None and self.tokenizer is not None:
|
||||
# Handle single or multiple prompts
|
||||
if isinstance(prompt, str):
|
||||
prompts = [prompt] * batch_size
|
||||
else:
|
||||
prompts = prompt
|
||||
|
||||
# Tokenize
|
||||
tokens = self.tokenizer(
|
||||
prompts,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
input_ids = mx.array(tokens["input_ids"])
|
||||
attention_mask = mx.array(tokens["attention_mask"])
|
||||
|
||||
# Encode
|
||||
embeddings = self.text_encoder(input_ids, attention_mask)
|
||||
mx.eval(embeddings)
|
||||
|
||||
return embeddings, None # Connector handles masking internally
|
||||
|
||||
# Fallback: random embeddings (for testing without text encoder)
|
||||
print("Warning: No text encoder provided, using random embeddings")
|
||||
seq_len = max_length + 128 # Account for learnable registers
|
||||
embed_dim = self.transformer.config.caption_channels
|
||||
|
||||
embeddings = mx.random.normal((batch_size, seq_len, embed_dim))
|
||||
mask = mx.ones((batch_size, seq_len))
|
||||
|
||||
return embeddings, mask
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
def denoise(
|
||||
latents: mx.array,
|
||||
sigma: float,
|
||||
sigma_next: float,
|
||||
text_embeddings: mx.array,
|
||||
positions: mx.array,
|
||||
text_mask: Optional[mx.array] = None,
|
||||
) -> mx.array:
|
||||
"""Perform one denoising step.
|
||||
text_embeddings: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigmas: list,
|
||||
) -> mx.array:
|
||||
"""Run denoising loop."""
|
||||
for i in range(len(sigmas) - 1):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
|
||||
Args:
|
||||
latents: Current noisy latents
|
||||
sigma: Current noise level
|
||||
sigma_next: Next noise level
|
||||
text_embeddings: Text conditioning
|
||||
positions: Position grid for RoPE
|
||||
text_mask: Optional attention mask for text
|
||||
|
||||
Returns:
|
||||
Denoised latents
|
||||
"""
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
# Flatten latents for transformer: (B, C, F, H, W) -> (B, F*H*W, C)
|
||||
b, c, f, h, w = latents.shape
|
||||
latents_flat = mx.reshape(latents, (b, c, -1))
|
||||
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
|
||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||
|
||||
# Create timestep tensor
|
||||
timesteps = mx.full((batch_size,), sigma)
|
||||
|
||||
# Create video modality input
|
||||
video_modality = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
timesteps=mx.full((1,), sigma),
|
||||
positions=positions,
|
||||
context=text_embeddings,
|
||||
context_mask=text_mask,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Run denoising
|
||||
denoised_video, _ = self.x0_model(video=video_modality, audio=None)
|
||||
velocity, _ = transformer(video=video_modality, audio=None)
|
||||
mx.eval(velocity)
|
||||
|
||||
# Reshape back: (B, F*H*W, C) -> (B, C, F, H, W)
|
||||
denoised_video = mx.transpose(denoised_video, (0, 2, 1))
|
||||
denoised_video = mx.reshape(denoised_video, (b, c, f, h, w))
|
||||
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||
denoised = to_denoised(latents, velocity, sigma)
|
||||
mx.eval(denoised)
|
||||
|
||||
# Euler step
|
||||
if sigma_next > 0:
|
||||
# x_next = x0 + sigma_next * (x - x0) / sigma
|
||||
noise = (latents - denoised_video) / sigma
|
||||
latents = denoised_video + sigma_next * noise
|
||||
latents = denoised + sigma_next * (latents - denoised) / sigma
|
||||
else:
|
||||
latents = denoised_video
|
||||
latents = denoised
|
||||
mx.eval(latents)
|
||||
|
||||
return latents
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
config: Optional[GenerationConfig] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
"""Generate video from text prompt.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt
|
||||
config: Generation configuration
|
||||
seed: Random seed
|
||||
|
||||
Returns:
|
||||
Generated video tensor of shape (B, C, F, H, W)
|
||||
"""
|
||||
if config is None:
|
||||
config = GenerationConfig()
|
||||
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
batch_size = 1
|
||||
|
||||
# Prepare text embeddings
|
||||
text_embeddings, text_mask = self.prepare_text_embeddings(prompt, batch_size)
|
||||
|
||||
# Prepare initial latents
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_frames=config.latent_frames,
|
||||
height=config.latent_height,
|
||||
width=config.latent_width,
|
||||
)
|
||||
|
||||
# Prepare position grid
|
||||
positions = create_position_grid(
|
||||
batch_size=batch_size,
|
||||
num_frames=config.latent_frames,
|
||||
height=config.latent_height,
|
||||
width=config.latent_width,
|
||||
)
|
||||
|
||||
# Get sigma schedule
|
||||
num_tokens = config.latent_frames * config.latent_height * config.latent_width
|
||||
sigmas = get_sigmas(
|
||||
config.num_inference_steps,
|
||||
num_tokens,
|
||||
use_distilled=config.use_distilled,
|
||||
)
|
||||
|
||||
# Denoising loop
|
||||
for i in range(len(sigmas) - 1):
|
||||
sigma = float(sigmas[i])
|
||||
sigma_next = float(sigmas[i + 1])
|
||||
|
||||
latents = self.denoise_step(
|
||||
latents=latents,
|
||||
sigma=sigma,
|
||||
sigma_next=sigma_next,
|
||||
text_embeddings=text_embeddings,
|
||||
positions=positions,
|
||||
text_mask=text_mask,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
|
||||
# Decode latents to video
|
||||
if self.vae_decoder is not None:
|
||||
video = self.vae_decoder(latents)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
return video
|
||||
|
||||
|
||||
def generate_video(
|
||||
model_repo: str,
|
||||
prompt: str,
|
||||
transformer: LTXModel,
|
||||
text_encoder: Optional[LTX2TextEncoder] = None,
|
||||
tokenizer: Optional[any] = None,
|
||||
vae_decoder: Optional[VideoDecoder] = None,
|
||||
config: Optional[GenerationConfig] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> mx.array:
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_frames: int = 33,
|
||||
seed: int = 42,
|
||||
fps: int = 24,
|
||||
output_path: str = "output.mp4",
|
||||
save_frames: bool = False,
|
||||
):
|
||||
"""Generate video from text prompt.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt
|
||||
transformer: LTX transformer model
|
||||
text_encoder: Optional text encoder
|
||||
tokenizer: Optional tokenizer
|
||||
vae_decoder: Optional VAE decoder
|
||||
config: Generation configuration
|
||||
seed: Random seed
|
||||
|
||||
Returns:
|
||||
Generated video tensor
|
||||
prompt: Text description of the video to generate
|
||||
height: Output video height (must be divisible by 64)
|
||||
width: Output video width (must be divisible by 64)
|
||||
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
|
||||
seed: Random seed for reproducibility
|
||||
fps: Frames per second for output video
|
||||
output_path: Path to save the output video
|
||||
save_frames: Whether to save individual frames as images
|
||||
"""
|
||||
pipeline = LTXVideoPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae_decoder=vae_decoder,
|
||||
)
|
||||
start_time = time.time()
|
||||
|
||||
return pipeline(prompt, config, seed)
|
||||
# Validate dimensions
|
||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||
|
||||
print(f"Generating {width}x{height} video with {num_frames} frames")
|
||||
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
|
||||
|
||||
def load_pipeline(
|
||||
model_path: str,
|
||||
text_encoder_path: Optional[str] = None,
|
||||
tokenizer_path: Optional[str] = None,
|
||||
load_text_encoder_weights: bool = True,
|
||||
) -> LTXVideoPipeline:
|
||||
"""Load complete LTX-2 video generation pipeline.
|
||||
# Get model path
|
||||
model_path = get_model_path(model_repo)
|
||||
|
||||
Args:
|
||||
model_path: Path to LTX-2 model weights (safetensors)
|
||||
text_encoder_path: Path to text encoder weights directory
|
||||
tokenizer_path: Path to tokenizer directory
|
||||
load_text_encoder_weights: Whether to load text encoder weights
|
||||
# Calculate latent dimensions
|
||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||
stage2_h, stage2_w = height // 32, width // 32
|
||||
latent_frames = 1 + (num_frames - 1) // 8
|
||||
|
||||
Returns:
|
||||
Configured LTXVideoPipeline
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
mx.random.seed(seed)
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||
from mlx_video.models.ltx.ltx import LTXModel
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.convert import sanitize_transformer_weights
|
||||
# Load text encoder
|
||||
print("Loading text encoder...")
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
||||
text_encoder.load(str(model_path))
|
||||
mx.eval(text_encoder.parameters())
|
||||
|
||||
print("Loading LTX-2 pipeline...")
|
||||
text_embeddings, _ = text_encoder(prompt)
|
||||
mx.eval(text_embeddings)
|
||||
|
||||
del text_encoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Load transformer
|
||||
print(" Loading transformer...")
|
||||
raw_weights = mx.load(model_path)
|
||||
print("Loading transformer...")
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
|
||||
config = LTXModelConfig(
|
||||
@@ -498,89 +217,177 @@ def load_pipeline(
|
||||
num_layers=48,
|
||||
cross_attention_dim=4096,
|
||||
caption_channels=3840,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision_rope=True,
|
||||
positional_embedding_theta=10000.0,
|
||||
positional_embedding_max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
timestep_scale_multiplier=1000,
|
||||
)
|
||||
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
print(" Transformer loaded")
|
||||
mx.eval(transformer.parameters())
|
||||
|
||||
# Load VAE decoder
|
||||
print(" Loading VAE decoder...")
|
||||
vae_decoder = load_vae_decoder(model_path, timestep_conditioning=True)
|
||||
print(" VAE decoder loaded")
|
||||
# Stage 1: Generate at half resolution
|
||||
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
|
||||
mx.random.seed(seed)
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||
mx.eval(latents)
|
||||
|
||||
# Load text encoder if paths provided
|
||||
text_encoder = None
|
||||
tokenizer = None
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
if load_text_encoder_weights and text_encoder_path is not None:
|
||||
print(" Loading text encoder...")
|
||||
text_encoder = load_text_encoder(model_path, text_encoder_path)
|
||||
print(" Text encoder loaded")
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
||||
|
||||
if tokenizer_path is not None:
|
||||
print(" Loading tokenizer...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
print(" Tokenizer loaded")
|
||||
# Upsample latents
|
||||
print("Upsampling latents 2x...")
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
print("Pipeline ready!")
|
||||
|
||||
return LTXVideoPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae_decoder=vae_decoder,
|
||||
vae_decoder = load_vae_decoder(
|
||||
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||
timestep_conditioning=True
|
||||
)
|
||||
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||
mx.eval(latents)
|
||||
|
||||
def video_to_numpy(video: mx.array) -> np.ndarray:
|
||||
"""Convert video tensor to numpy array.
|
||||
del upsampler
|
||||
mx.clear_cache()
|
||||
|
||||
Args:
|
||||
video: Video tensor of shape (B, C, F, H, W) in range [-1, 1]
|
||||
# Stage 2: Refine at full resolution
|
||||
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
Returns:
|
||||
Numpy array of shape (B, F, H, W, C) in range [0, 255]
|
||||
# Add noise for refinement
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
noise = mx.random.normal(latents.shape)
|
||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||
mx.eval(latents)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode to video
|
||||
print("Decoding video...")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
mx.clear_cache()
|
||||
|
||||
# Convert to uint8 frames
|
||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
import imageio
|
||||
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
|
||||
print(f"Saved video to {output_path}")
|
||||
except Exception as e:
|
||||
print(f"Could not save video: {e}")
|
||||
|
||||
if save_frames:
|
||||
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
||||
frames_dir.mkdir(exist_ok=True)
|
||||
for i, frame in enumerate(video_np):
|
||||
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
|
||||
print(f"Saved {len(video_np)} frames to {frames_dir}")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
|
||||
|
||||
return video_np
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate videos with MLX LTX-2",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python -m mlx_video.generate --prompt "A cat walking on grass"
|
||||
python -m mlx_video.generate --prompt "Ocean waves at sunset" --height 768 --width 768
|
||||
python -m mlx_video.generate --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
|
||||
"""
|
||||
# Clamp to [-1, 1]
|
||||
video = mx.clip(video, -1.0, 1.0)
|
||||
)
|
||||
|
||||
# Scale to [0, 255]
|
||||
video = ((video + 1.0) / 2.0 * 255.0).astype(mx.uint8)
|
||||
parser.add_argument(
|
||||
"--prompt", "-p",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Text description of the video to generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height", "-H",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output video height (default: 512, must be divisible by 32)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width", "-W",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output video width (default: 512, must be divisible by 32)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-frames", "-n",
|
||||
type=int,
|
||||
default=33,
|
||||
help="Number of frames (default: 33, must be 1 + 8*k)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", "-s",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Random seed for reproducibility (default: 42)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
default=24,
|
||||
help="Frames per second for output video (default: 24)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o",
|
||||
type=str,
|
||||
default="output.mp4",
|
||||
help="Output video path (default: output.mp4)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-frames",
|
||||
action="store_true",
|
||||
help="Save individual frames as images"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-repo",
|
||||
type=str,
|
||||
default="Lightricks/LTX-2",
|
||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Rearrange: (B, C, F, H, W) -> (B, F, H, W, C)
|
||||
video = mx.transpose(video, (0, 2, 3, 4, 1))
|
||||
|
||||
return np.array(video)
|
||||
generate_video(
|
||||
model_repo=args.model_repo,
|
||||
prompt=args.prompt,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
num_frames=args.num_frames,
|
||||
seed=args.seed,
|
||||
fps=args.fps,
|
||||
output_path=args.output,
|
||||
save_frames=args.save_frames,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType
|
||||
|
||||
# Create a small test config
|
||||
config = LTXModelConfig(
|
||||
model_type=LTXModelType.VideoOnly,
|
||||
num_layers=2, # Reduced for testing
|
||||
num_attention_heads=4,
|
||||
attention_head_dim=32,
|
||||
)
|
||||
|
||||
# Create model
|
||||
model = LTXModel(config)
|
||||
|
||||
# Generate video
|
||||
gen_config = GenerationConfig(
|
||||
height=256,
|
||||
width=256,
|
||||
num_frames=9,
|
||||
num_inference_steps=4,
|
||||
)
|
||||
|
||||
print("Testing generation pipeline...")
|
||||
pipeline = LTXVideoPipeline(transformer=model)
|
||||
|
||||
# This would require proper text embeddings in practice
|
||||
# video = pipeline("A cat walking", gen_config, seed=42)
|
||||
# print(f"Generated video shape: {video.shape}")
|
||||
|
||||
print("Pipeline initialized successfully!")
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user