Add frame number validation in video generation and update Gemma3 text encoder to use validated mlx-vlm implementation

This commit is contained in:
Prince Canuma
2026-01-13 17:12:11 +01:00
parent 61b003ff2c
commit 01d895bc77
4 changed files with 1546 additions and 197 deletions

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import mlx.core as mx
import numpy as np
from PIL import Image
from tqdm import tqdm
# ANSI color codes
class Colors:
@@ -113,9 +114,10 @@ def denoise(
text_embeddings: mx.array,
transformer: LTXModel,
sigmas: list,
verbose: bool = True,
) -> mx.array:
"""Run denoising loop."""
for i in range(len(sigmas) - 1):
for i in tqdm(range(len(sigmas) - 1), desc="Denoising", disable=not verbose):
sigma, sigma_next = sigmas[i], sigmas[i + 1]
b, c, f, h, w = latents.shape
@@ -156,6 +158,7 @@ def generate_video(
fps: int = 24,
output_path: str = "output.mp4",
save_frames: bool = False,
verbose: bool = True,
):
"""Generate video from text prompt.
@@ -174,6 +177,12 @@ def generate_video(
# Validate dimensions
assert height % 64 == 0, f"Height must be divisible by 64, got {height}"
assert width % 64 == 0, f"Width must be divisible by 64, got {width}"
if num_frames % 8 != 1:
adjusted_num_frames = round((num_frames - 1) / 8) * 8 + 1
print(f"{Colors.YELLOW}⚠️ Number of frames must be 1 + 8*k. Using nearest valid value: {adjusted_num_frames}{Colors.RESET}")
num_frames = adjusted_num_frames
print(f"{Colors.BOLD}{Colors.CYAN}🎬 Generating {width}x{height} video with {num_frames} frames{Colors.RESET}")
print(f"{Colors.DIM}Prompt: {prompt[:80]}{'...' if len(prompt) > 80 else ''}{Colors.RESET}")
@@ -236,7 +245,7 @@ def generate_video(
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose)
# Upsample latents
print(f"{Colors.MAGENTA}🔍 Upsampling latents 2x...{Colors.RESET}")
@@ -265,7 +274,7 @@ def generate_video(
latents = noise * noise_scale + latents * (1 - noise_scale)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose)
del transformer
mx.clear_cache()
@@ -295,7 +304,7 @@ def generate_video(
for frame in video_np:
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
out.release()
print(f"{Colors.GREEN}✅ Saved video to {output_path}{Colors.RESET}")
print(f"{Colors.GREEN}✅ Saved video to{Colors.RESET} {output_path}")
except Exception as e:
print(f"{Colors.RED}❌ Could not save video: {e}{Colors.RESET}")
@@ -308,6 +317,7 @@ def generate_video(
elapsed = time.time() - start_time
print(f"{Colors.BOLD}{Colors.GREEN}🎉 Done! Generated in {elapsed:.1f}s ({elapsed/num_frames:.2f}s/frame){Colors.RESET}")
print(f"{Colors.BOLD}{Colors.GREEN} ✨ Peak memory: {mx.get_peak_memory() / (1024 ** 3):.2f}GB{Colors.RESET}")
return video_np
@@ -377,6 +387,11 @@ Examples:
default="Lightricks/LTX-2",
help="Model repository to use (default: Lightricks/LTX-2)"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Verbose output"
)
args = parser.parse_args()
generate_video(
@@ -389,6 +404,7 @@ Examples:
fps=args.fps,
output_path=args.output,
save_frames=args.save_frames,
verbose=args.verbose,
)