Cast dtype to bf16 in video and audio generation processes

This commit is contained in:
Prince Canuma
2026-01-17 17:20:22 +01:00
parent 883c6b0ad8
commit 78244a2d66
3 changed files with 86 additions and 76 deletions

View File

@@ -1,9 +1,10 @@
import argparse
import time
from pathlib import Path
from typing import Optional, List, Tuple
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
@@ -110,6 +111,7 @@ def create_position_grid(
# Convert temporal to time in seconds by dividing by fps
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
# Always return float32 for RoPE precision - bfloat16 causes quality degradation
return mx.array(pixel_coords, dtype=mx.float32)
@@ -137,6 +139,7 @@ def denoise(
Denoised latent tensor
"""
# If state is provided, use its latent (which may have conditioning applied)
dtype = latents.dtype
if state is not None:
latents = state.latent
@@ -154,11 +157,11 @@ def denoise(
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
# Per-token timesteps: sigma * mask
timesteps = sigma * denoise_mask_flat
# Per-token timesteps: sigma * mask (preserve dtype)
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else:
# All tokens get the same timestep
timesteps = mx.full((b, num_tokens), sigma)
# All tokens get the same timestep (use latent dtype)
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
video_modality = Modality(
latent=latents_flat,
@@ -181,8 +184,11 @@ def denoise(
mx.eval(denoised)
# Euler step (preserve dtype by converting Python floats to arrays)
if sigma_next > 0:
latents = denoised + sigma_next * (latents - denoised) / sigma
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
else:
latents = denoised
mx.eval(latents)
@@ -283,6 +289,7 @@ def generate_video(
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
model_dtype = text_embeddings.dtype # bfloat16 from text encoder
mx.eval(text_embeddings)
del text_encoder
@@ -292,6 +299,8 @@ def generate_video(
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig(
model_type=LTXModelType.VideoOnly,
@@ -310,7 +319,7 @@ def generate_video(
timestep_scale_multiplier=1000,
)
transformer = LTXModel(config)
transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters())
@@ -323,15 +332,15 @@ def generate_video(
mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
@@ -343,6 +352,7 @@ def generate_video(
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed)
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions)
@@ -353,24 +363,26 @@ def generate_video(
# Create initial state with zeros
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
state1 = LatentState(
latent=mx.zeros(latent_shape),
clean_latent=mx.zeros(latent_shape),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
latent=mx.zeros(latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent,
frame_idx=image_frame_idx,
strength=image_strength,
)
state1 = apply_conditioning(state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 1, noise_scale = 1.0 (first sigma)
noise = mx.random.normal(latent_shape)
noise_scale = STAGE_1_SIGMAS[0] # 1.0
noise = mx.random.normal(latent_shape, dtype=model_dtype)
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
scaled_mask = state1.denoise_mask * noise_scale
state1 = LatentState(
latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask),
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state1.clean_latent,
denoise_mask=state1.denoise_mask,
)
@@ -378,7 +390,7 @@ def generate_video(
mx.eval(latents)
else:
# T2V: just use random noise
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype)
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
@@ -401,6 +413,7 @@ def generate_video(
# Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions)
@@ -411,7 +424,7 @@ def generate_video(
state2 = LatentState(
latent=latents, # Start with upscaled latent
clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
)
conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent,
@@ -423,11 +436,11 @@ def generate_video(
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 2, noise_scale = stage_2_sigmas[0]
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
noise = mx.random.normal(latents.shape)
noise_scale = STAGE_2_SIGMAS[0]
noise = mx.random.normal(latents.shape).astype(model_dtype)
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = state2.denoise_mask * noise_scale
state2 = LatentState(
latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask),
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state2.clean_latent,
denoise_mask=state2.denoise_mask,
)
@@ -435,9 +448,10 @@ def generate_video(
mx.eval(latents)
else:
# T2V: add noise to all frames for refinement
noise_scale = STAGE_2_SIGMAS[0]
noise = mx.random.normal(latents.shape)
latents = noise * noise_scale + latents * (1 - noise_scale)
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
noise = mx.random.normal(latents.shape).astype(model_dtype)
latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)