- Refactor video generation script
- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined the output video generation process.
This commit is contained in:
451
main.py
451
main.py
@@ -1,6 +1,9 @@
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||
@@ -10,205 +13,127 @@ from mlx_video.convert import sanitize_transformer_weights
|
||||
from mlx_video.generate import create_position_grid
|
||||
from mlx_video.utils import to_denoised
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.upsampler import LatentUpsampler, load_upsampler, upsample_latents
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
|
||||
# Paths
|
||||
from huggingface_hub import snapshot_download
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
LTX2_REPO = "Lightricks/LTX-2"
|
||||
|
||||
def get_ltx2_cache_dir():
|
||||
# Try to get local cache (local_only), will not download files
|
||||
# Distilled sigma schedules
|
||||
STAGE_1_SIGMAS = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_SIGMAS = [0.909375, 0.725, 0.421875, 0.0]
|
||||
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
ref_path = snapshot_download(
|
||||
repo_id=LTX2_REPO,
|
||||
local_files_only=True,
|
||||
allow_patterns=["*"],
|
||||
ignore_patterns=[],
|
||||
# leave as default revision and cache_dir, only local
|
||||
)
|
||||
return ref_path
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
# If not present locally, download from hub
|
||||
return snapshot_download(
|
||||
repo_id=LTX2_REPO,
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
ignore_patterns=[]
|
||||
)
|
||||
|
||||
LTX2_PATH = Path(get_ltx2_cache_dir())
|
||||
MODEL_PATH = str(LTX2_PATH / 'ltx-2-19b-distilled.safetensors')
|
||||
UPSAMPLER_PATH = str(LTX2_PATH / 'ltx-2-spatial-upscaler-x2-1.0.safetensors')
|
||||
TEXT_ENCODER_PATH = str(LTX2_PATH / 'text_encoder')
|
||||
TOKENIZER_PATH = str(LTX2_PATH / 'tokenizer')
|
||||
|
||||
# Distilled sigma schedules (from PyTorch)
|
||||
STAGE_1_SIGMA_SCHEDULE = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_SIGMA_SCHEDULE = [0.909375, 0.725, 0.421875, 0.0] # Refinement steps
|
||||
))
|
||||
|
||||
|
||||
def denoise_loop(
|
||||
def denoise(
|
||||
latents: mx.array,
|
||||
positions: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigma_schedule: list,
|
||||
stage_name: str = "Stage",
|
||||
negative_embeddings: mx.array = None,
|
||||
cfg_scale: float = 1.0,
|
||||
sigmas: list,
|
||||
) -> mx.array:
|
||||
"""Run denoising loop for given sigma schedule.
|
||||
|
||||
Args:
|
||||
latents: Noisy latent tensor
|
||||
positions: Position embeddings
|
||||
text_embeddings: Positive prompt embeddings
|
||||
transformer: The transformer model
|
||||
sigma_schedule: List of sigma values for each step
|
||||
stage_name: Name for logging
|
||||
negative_embeddings: Negative prompt embeddings for CFG (optional)
|
||||
cfg_scale: Classifier-free guidance scale (1.0 = no guidance)
|
||||
"""
|
||||
use_cfg = negative_embeddings is not None and cfg_scale > 1.0
|
||||
|
||||
for i in range(len(sigma_schedule) - 1):
|
||||
sigma = sigma_schedule[i]
|
||||
sigma_next = sigma_schedule[i + 1]
|
||||
|
||||
print(f" {stage_name} step {i+1}/{len(sigma_schedule)-1}: sigma={sigma:.4f} -> {sigma_next:.4f}")
|
||||
"""Run denoising loop."""
|
||||
for i in range(len(sigmas) - 1):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
|
||||
b, c, f, h, w = latents.shape
|
||||
latents_flat = mx.reshape(latents, (b, c, -1))
|
||||
latents_flat = mx.transpose(latents_flat, (0, 2, 1))
|
||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||
|
||||
timesteps = mx.full((1,), sigma)
|
||||
|
||||
# Positive (conditioned) prediction
|
||||
video_modality = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
timesteps=mx.full((1,), sigma),
|
||||
positions=positions,
|
||||
context=text_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
vx_cond, _ = transformer(video=video_modality, audio=None)
|
||||
mx.eval(vx_cond)
|
||||
velocity, _ = transformer(video=video_modality, audio=None)
|
||||
mx.eval(velocity)
|
||||
|
||||
if use_cfg:
|
||||
# Negative (unconditioned) prediction
|
||||
video_modality_neg = Modality(
|
||||
latent=latents_flat,
|
||||
timesteps=timesteps,
|
||||
positions=positions,
|
||||
context=negative_embeddings,
|
||||
context_mask=None,
|
||||
enabled=True,
|
||||
)
|
||||
vx_uncond, _ = transformer(video=video_modality_neg, audio=None)
|
||||
mx.eval(vx_uncond)
|
||||
|
||||
# CFG: output = uncond + cfg_scale * (cond - uncond)
|
||||
vx = vx_uncond + cfg_scale * (vx_cond - vx_uncond)
|
||||
else:
|
||||
vx = vx_cond
|
||||
|
||||
vx_reshaped = mx.transpose(vx, (0, 2, 1))
|
||||
vx_reshaped = mx.reshape(vx_reshaped, (b, c, f, h, w))
|
||||
|
||||
# Debug: Print velocity stats
|
||||
vx_np = np.array(vx_reshaped)
|
||||
print(f" Velocity: min={vx_np.min():.4f}, max={vx_np.max():.4f}, mean={vx_np.mean():.4f}")
|
||||
|
||||
# Get denoised prediction: x_0 = x_t - sigma * velocity
|
||||
denoised = to_denoised(latents, vx_reshaped, sigma)
|
||||
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
||||
denoised = to_denoised(latents, velocity, sigma)
|
||||
mx.eval(denoised)
|
||||
|
||||
# Debug: Print denoised stats
|
||||
denoised_np = np.array(denoised)
|
||||
print(f" Denoised: min={denoised_np.min():.4f}, max={denoised_np.max():.4f}, mean={denoised_np.mean():.4f}")
|
||||
|
||||
# Euler step: x_next = x_0 + sigma_next * (x_t - x_0) / sigma
|
||||
if sigma_next > 0:
|
||||
velocity = (latents - denoised) / sigma
|
||||
latents = denoised + sigma_next * velocity
|
||||
latents = denoised + sigma_next * (latents - denoised) / sigma
|
||||
else:
|
||||
latents = denoised
|
||||
mx.eval(latents)
|
||||
|
||||
# Debug: Print latents after step
|
||||
latents_np = np.array(latents)
|
||||
print(f" Latents after step: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
def main():
|
||||
print("="*60)
|
||||
print("MLX LTX-2 Video Generation (Two-Stage)")
|
||||
print("="*60)
|
||||
def generate_video(
|
||||
model_repo: str,
|
||||
prompt: str,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_frames: int = 33,
|
||||
seed: int = 42,
|
||||
fps: int = 24,
|
||||
output_path: str = "output.mp4",
|
||||
save_frames: bool = False,
|
||||
):
|
||||
"""Generate video from text prompt.
|
||||
|
||||
# Config - same as PyTorch reference
|
||||
prompt = "A beautiful woman with flowing dark hair stands on a tropical beach at golden hour, gentle waves lapping at her feet, she turns and smiles at the camera, warm sunlight illuminating her face, palm trees swaying in the background, cinematic lighting, photorealistic"
|
||||
negative_prompt = "" # PyTorch script doesn't use negative prompt
|
||||
cfg_scale = 1.0 # No CFG in the distilled pipeline
|
||||
height, width, num_frames = 512, 512, 500 # Must be divisible by 64 for two-stage
|
||||
seed = 123
|
||||
Args:
|
||||
prompt: Text description of the video to generate
|
||||
height: Output video height (must be divisible by 64)
|
||||
width: Output video width (must be divisible by 64)
|
||||
num_frames: Number of frames (must be 1 + 8*k, e.g., 33, 65, 97)
|
||||
seed: Random seed for reproducibility
|
||||
fps: Frames per second for output video
|
||||
output_path: Path to save the output video
|
||||
save_frames: Whether to save individual frames as images
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Stage 1: Half resolution
|
||||
stage1_height = height // 2
|
||||
stage1_width = width // 2
|
||||
stage1_latent_height = stage1_height // 32
|
||||
stage1_latent_width = stage1_width // 32
|
||||
# Validate dimensions
|
||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||
|
||||
print(f"Generating {width}x{height} video with {num_frames} frames")
|
||||
print(f"Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
|
||||
|
||||
# Get model path
|
||||
model_path = get_model_path(model_repo)
|
||||
|
||||
# Calculate latent dimensions
|
||||
stage1_h, stage1_w = height // 2 // 32, width // 2 // 32
|
||||
stage2_h, stage2_w = height // 32, width // 32
|
||||
latent_frames = 1 + (num_frames - 1) // 8
|
||||
|
||||
# Stage 2: Full resolution
|
||||
latent_height = height // 32
|
||||
latent_width = width // 32
|
||||
|
||||
print(f"\nConfig:")
|
||||
print(f" Prompt: {prompt}")
|
||||
print(f" Negative prompt: '{negative_prompt}'")
|
||||
print(f" CFG scale: {cfg_scale}")
|
||||
print(f" Final resolution: {width}x{height}, {num_frames} frames")
|
||||
print(f" Stage 1: {stage1_width}x{stage1_height} -> latent {stage1_latent_width}x{stage1_latent_height}")
|
||||
print(f" Stage 2: {width}x{height} -> latent {latent_width}x{latent_height}")
|
||||
print(f" Seed: {seed}")
|
||||
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Load text encoder
|
||||
print("\nLoading text encoder...")
|
||||
print("Loading text encoder...")
|
||||
from mlx_video.models.ltx.text_encoder import LTX2TextEncoder
|
||||
|
||||
text_encoder = LTX2TextEncoder(model_path=str(LTX2_PATH))
|
||||
text_encoder.load(str(LTX2_PATH))
|
||||
text_encoder = LTX2TextEncoder(model_path=str(model_path))
|
||||
text_encoder.load(str(model_path))
|
||||
mx.eval(text_encoder.parameters())
|
||||
|
||||
# Encode positive prompt
|
||||
print("Encoding text...")
|
||||
text_embeddings, attention_mask = text_encoder(prompt)
|
||||
text_embeddings, _ = text_encoder(prompt)
|
||||
mx.eval(text_embeddings)
|
||||
print(f" Positive embeddings: {text_embeddings.shape}")
|
||||
|
||||
# Encode negative prompt for CFG
|
||||
negative_embeddings, _ = text_encoder(negative_prompt)
|
||||
mx.eval(negative_embeddings)
|
||||
print(f" Negative embeddings: {negative_embeddings.shape}")
|
||||
|
||||
# Free text encoder memory
|
||||
del text_encoder
|
||||
mx.clear_cache()
|
||||
|
||||
# Load transformer
|
||||
print("\nLoading transformer...")
|
||||
raw_weights = mx.load(MODEL_PATH)
|
||||
print("Loading transformer...")
|
||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||
sanitized = sanitize_transformer_weights(raw_weights)
|
||||
|
||||
config = LTXModelConfig(
|
||||
@@ -231,164 +156,166 @@ def main():
|
||||
transformer = LTXModel(config)
|
||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||
mx.eval(transformer.parameters())
|
||||
print(" Transformer loaded!")
|
||||
|
||||
# ========================================
|
||||
# Stage 1: Generate at half resolution
|
||||
# ========================================
|
||||
print("\n" + "="*60)
|
||||
print("Stage 1: Generating at half resolution")
|
||||
print("="*60)
|
||||
|
||||
print(f"Stage 1: Generating at {width//2}x{height//2} (8 steps)...")
|
||||
mx.random.seed(seed)
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_latent_height, stage1_latent_width))
|
||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
||||
mx.eval(latents)
|
||||
print(f" Initial latents: {latents.shape}")
|
||||
|
||||
positions = create_position_grid(1, latent_frames, stage1_latent_height, stage1_latent_width)
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
latents = denoise_loop(
|
||||
latents=latents,
|
||||
positions=positions,
|
||||
text_embeddings=text_embeddings,
|
||||
transformer=transformer,
|
||||
sigma_schedule=STAGE_1_SIGMA_SCHEDULE,
|
||||
stage_name="Stage 1",
|
||||
negative_embeddings=negative_embeddings,
|
||||
cfg_scale=cfg_scale,
|
||||
)
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
||||
|
||||
print(f"\nStage 1 latents: {latents.shape}")
|
||||
latents_np = np.array(latents)
|
||||
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
|
||||
# ========================================
|
||||
# Upsample latents 2x
|
||||
# ========================================
|
||||
print("\n" + "="*60)
|
||||
print("Upsampling latents 2x")
|
||||
print("="*60)
|
||||
|
||||
# Load upsampler
|
||||
print(" Loading spatial upsampler...")
|
||||
upsampler = load_upsampler(UPSAMPLER_PATH)
|
||||
# Upsample latents
|
||||
print("Upsampling latents 2x...")
|
||||
upsampler = load_upsampler(str(model_path / 'ltx-2-spatial-upscaler-x2-1.0.safetensors'))
|
||||
mx.eval(upsampler.parameters())
|
||||
|
||||
# Load latent statistics for normalization
|
||||
vae_decoder = load_vae_decoder(MODEL_PATH, timestep_conditioning=True)
|
||||
# EXPERIMENT: Disable VAE decode noise for sharper output
|
||||
# vae_decoder.decode_noise_scale = 0.0
|
||||
# print(f" VAE decode_noise_scale set to {vae_decoder.decode_noise_scale}")
|
||||
latent_mean = vae_decoder.latents_mean
|
||||
latent_std = vae_decoder.latents_std
|
||||
vae_decoder = load_vae_decoder(
|
||||
str(model_path / 'ltx-2-19b-distilled.safetensors'),
|
||||
timestep_conditioning=True
|
||||
)
|
||||
|
||||
# Upsample
|
||||
print(" Upsampling...")
|
||||
latents = upsample_latents(latents, upsampler, latent_mean, latent_std, debug=False)
|
||||
latents = upsample_latents(latents, upsampler, vae_decoder.latents_mean, vae_decoder.latents_std)
|
||||
mx.eval(latents)
|
||||
print(f" Upsampled latents: {latents.shape}")
|
||||
|
||||
# Free upsampler memory
|
||||
del upsampler
|
||||
mx.clear_cache()
|
||||
|
||||
# ========================================
|
||||
# Stage 2: Refine at full resolution
|
||||
# ========================================
|
||||
print("\n" + "="*60)
|
||||
print("Stage 2: Refining at full resolution")
|
||||
print("="*60)
|
||||
|
||||
# Debug: Print upsampled latent stats before adding noise
|
||||
latents_np = np.array(latents)
|
||||
print(f" Upsampled latents (before noise): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
|
||||
# Create new position grid for full resolution
|
||||
positions = create_position_grid(1, latent_frames, latent_height, latent_width)
|
||||
print(f"Stage 2: Refining at {width}x{height} (3 steps)...")
|
||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||
mx.eval(positions)
|
||||
|
||||
# Add noise at initial sigma for stage 2
|
||||
# PyTorch uses interpolation: noisy = noise * scale + clean * (1 - scale)
|
||||
# NOT addition: noisy = clean + scale * noise
|
||||
noise_scale = STAGE_2_SIGMA_SCHEDULE[0]
|
||||
# Add noise for refinement
|
||||
noise_scale = STAGE_2_SIGMAS[0]
|
||||
noise = mx.random.normal(latents.shape)
|
||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||
mx.eval(latents)
|
||||
|
||||
# Debug: Print latents after adding noise
|
||||
latents_np = np.array(latents)
|
||||
print(f" After adding noise (sigma={noise_scale}): min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
||||
|
||||
latents = denoise_loop(
|
||||
latents=latents,
|
||||
positions=positions,
|
||||
text_embeddings=text_embeddings,
|
||||
transformer=transformer,
|
||||
sigma_schedule=STAGE_2_SIGMA_SCHEDULE,
|
||||
stage_name="Stage 2",
|
||||
)
|
||||
|
||||
print(f"\nFinal latents: {latents.shape}")
|
||||
latents_np = np.array(latents)
|
||||
print(f" Stats: min={latents_np.min():.4f}, max={latents_np.max():.4f}, mean={latents_np.mean():.4f}")
|
||||
|
||||
# Save latents for PyTorch comparison
|
||||
np.save("mlx_final_latents.npy", latents_np)
|
||||
print(" Saved latents to mlx_final_latents.npy")
|
||||
|
||||
# Free transformer memory
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
# ========================================
|
||||
# Decode to video
|
||||
# ========================================
|
||||
print("\n" + "="*60)
|
||||
print("Decoding with VAE")
|
||||
print("="*60)
|
||||
|
||||
# Decode latents to video
|
||||
video = vae_decoder(latents, debug=True)
|
||||
print("Decoding video...")
|
||||
video = vae_decoder(latents)
|
||||
mx.eval(video)
|
||||
print(f" Video shape: {video.shape}")
|
||||
mx.clear_cache()
|
||||
|
||||
# Convert to frames
|
||||
# Convert to uint8 frames
|
||||
video = mx.squeeze(video, axis=0) # (C, F, H, W)
|
||||
|
||||
# Debug: check raw RGB values before conversion
|
||||
video_raw = np.array(video)
|
||||
print(f" Raw video per-channel means: R={video_raw[0].mean():.4f}, G={video_raw[1].mean():.4f}, B={video_raw[2].mean():.4f}")
|
||||
print(f" Raw video range: [{video_raw.min():.4f}, {video_raw.max():.4f}]")
|
||||
|
||||
video = mx.transpose(video, (1, 2, 3, 0)) # (F, H, W, C)
|
||||
video = (video + 1.0) / 2.0 # [-1, 1] -> [0, 1]
|
||||
video = mx.clip(video, 0.0, 1.0)
|
||||
video = mx.clip((video + 1.0) / 2.0, 0.0, 1.0)
|
||||
video = (video * 255).astype(mx.uint8)
|
||||
video_np = np.array(video)
|
||||
|
||||
print(f" Converted video RGB means: R={video_np[:,:,:,0].mean():.1f}, G={video_np[:,:,:,1].mean():.1f}, B={video_np[:,:,:,2].mean():.1f}")
|
||||
# Save outputs
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save first frame
|
||||
output_path = Path("mlx_output_frame0_2.png")
|
||||
Image.fromarray(video_np[0]).save(output_path)
|
||||
print(f"\nSaved first frame to {output_path}")
|
||||
|
||||
# Save video
|
||||
try:
|
||||
import imageio
|
||||
video_path = "mlx_output_video_2.mp4"
|
||||
imageio.mimwrite(video_path, video_np, fps=24, codec='libx264')
|
||||
print(f"Saved video to {video_path}")
|
||||
imageio.mimwrite(str(output_path), video_np, fps=fps, codec='libx264')
|
||||
print(f"Saved video to {output_path}")
|
||||
except Exception as e:
|
||||
print(f"Could not save video: {e}")
|
||||
|
||||
print("\nDone!")
|
||||
if save_frames:
|
||||
frames_dir = output_path.parent / f"{output_path.stem}_frames"
|
||||
frames_dir.mkdir(exist_ok=True)
|
||||
for i, frame in enumerate(video_np):
|
||||
Image.fromarray(frame).save(frames_dir / f"frame_{i:04d}.png")
|
||||
print(f"Saved {len(video_np)} frames to {frames_dir}")
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame)")
|
||||
|
||||
return video_np
|
||||
|
||||
|
||||
def main():
    """Parse command-line arguments and run ``generate_video`` with them.

    Each option maps one-to-one onto a keyword argument of
    ``generate_video``. Only ``--prompt`` is required; every other option
    has a default.
    """
    parser = argparse.ArgumentParser(
        description="Generate videos with MLX LTX-2",
        # Raw formatter keeps the line breaks of the examples in the epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python main.py --prompt "A cat walking on grass"
python main.py --prompt "Ocean waves at sunset" --height 768 --width 768
python main.py --prompt "..." --num-frames 65 --seed 123 --output my_video.mp4
"""
    )

    parser.add_argument(
        "--prompt", "-p",
        type=str,
        required=True,
        help="Text description of the video to generate"
    )
    # generate_video asserts that height and width are divisible by 64,
    # so the help text states 64 rather than 32.
    parser.add_argument(
        "--height", "-H",
        type=int,
        default=512,
        help="Output video height (default: 512, must be divisible by 64)"
    )
    parser.add_argument(
        "--width", "-W",
        type=int,
        default=512,
        help="Output video width (default: 512, must be divisible by 64)"
    )
    parser.add_argument(
        "--num-frames", "-n",
        type=int,
        default=33,
        help="Number of frames (default: 33, must be 1 + 8*k)"
    )
    parser.add_argument(
        "--seed", "-s",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)"
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=24,
        help="Frames per second for output video (default: 24)"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="output.mp4",
        help="Output video path (default: output.mp4)"
    )
    parser.add_argument(
        "--save-frames",
        action="store_true",
        help="Save individual frames as images"
    )
    parser.add_argument(
        "--model-repo",
        type=str,
        default="Lightricks/LTX-2",
        help="Model repository to use (default: Lightricks/LTX-2)"
    )
    args = parser.parse_args()

    generate_video(
        model_repo=args.model_repo,
        prompt=args.prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        seed=args.seed,
        fps=args.fps,
        output_path=args.output,
        save_frames=args.save_frames,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
start_time = time.time()
|
||||
main()
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
|
||||
Reference in New Issue
Block a user