Add frame-count validation (num_frames must be 1 + 8*k) to video generation and update the Gemma3 text encoder to use the validated mlx-vlm implementation
This commit is contained in:
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# ANSI color codes
|
||||
class Colors:
|
||||
@@ -113,9 +114,10 @@ def denoise(
|
||||
text_embeddings: mx.array,
|
||||
transformer: LTXModel,
|
||||
sigmas: list,
|
||||
verbose: bool = True,
|
||||
) -> mx.array:
|
||||
"""Run denoising loop."""
|
||||
for i in range(len(sigmas) - 1):
|
||||
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
|
||||
sigma, sigma_next = sigmas[i], sigmas[i + 1]
|
||||
|
||||
b, c, f, h, w = latents.shape
|
||||
@@ -156,6 +158,7 @@ def generate_video(
|
||||
fps: int = 24,
|
||||
output_path: str = "output.mp4",
|
||||
save_frames: bool = False,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""Generate video from text prompt.
|
||||
|
||||
@@ -174,6 +177,12 @@ def generate_video(
|
||||
# Validate dimensions
|
||||
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
|
||||
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
|
||||
|
||||
if num_frames % 8 != 1:
|
||||
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
|
||||
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
|
||||
num_frames = adjusted_num_frames
|
||||
|
||||
|
||||
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
|
||||
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
|
||||
@@ -236,7 +245,7 @@ def generate_video(
|
||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||
mx.eval(positions)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
|
||||
|
||||
# Upsample latents
|
||||
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
|
||||
@@ -265,7 +274,7 @@ def generate_video(
|
||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
||||
mx.eval(latents)
|
||||
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
|
||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
|
||||
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
@@ -295,7 +304,7 @@ def generate_video(
|
||||
for frame in video_np:
|
||||
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||
out.release()
|
||||
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}")
|
||||
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
|
||||
except Exception as e:
|
||||
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
|
||||
|
||||
@@ -308,6 +317,7 @@ def generate_video(
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
|
||||
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
|
||||
|
||||
return video_np
|
||||
|
||||
@@ -377,6 +387,11 @@ Examples:
|
||||
default="Lightricks/LTX-2",
|
||||
help="Model repository to use (default: Lightricks/LTX-2)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Verbose output"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
generate_video(
|
||||
@@ -389,6 +404,7 @@ Examples:
|
||||
fps=args.fps,
|
||||
output_path=args.output,
|
||||
save_frames=args.save_frames,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user