Cast tensors and transformer weights to bfloat16 in the video and audio generation pipelines (RoPE position grids stay float32)

This commit is contained in:
Prince Canuma
2026-01-17 17:20:22 +01:00
parent 883c6b0ad8
commit 78244a2d66
3 changed files with 86 additions and 76 deletions

View File

@@ -1,9 +1,10 @@
import argparse import argparse
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional, List, Tuple from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
@@ -110,6 +111,7 @@ def create_position_grid(
# Convert temporal to time in seconds by dividing by fps # Convert temporal to time in seconds by dividing by fps
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
# Always return float32 for RoPE precision - bfloat16 causes quality degradation
return mx.array(pixel_coords, dtype=mx.float32) return mx.array(pixel_coords, dtype=mx.float32)
@@ -137,6 +139,7 @@ def denoise(
Denoised latent tensor Denoised latent tensor
""" """
# If state is provided, use its latent (which may have conditioning applied) # If state is provided, use its latent (which may have conditioning applied)
dtype = latents.dtype
if state is not None: if state is not None:
latents = state.latent latents = state.latent
@@ -154,11 +157,11 @@ def denoise(
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
# Per-token timesteps: sigma * mask # Per-token timesteps: sigma * mask (preserve dtype)
timesteps = sigma * denoise_mask_flat timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else: else:
# All tokens get the same timestep # All tokens get the same timestep (use latent dtype)
timesteps = mx.full((b, num_tokens), sigma) timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
video_modality = Modality( video_modality = Modality(
latent=latents_flat, latent=latents_flat,
@@ -181,8 +184,11 @@ def denoise(
mx.eval(denoised) mx.eval(denoised)
# Euler step (preserve dtype by converting Python floats to arrays)
if sigma_next > 0: if sigma_next > 0:
latents = denoised + sigma_next * (latents - denoised) / sigma sigma_next_arr = mx.array(sigma_next, dtype=dtype)
sigma_arr = mx.array(sigma, dtype=dtype)
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
else: else:
latents = denoised latents = denoised
mx.eval(latents) mx.eval(latents)
@@ -283,6 +289,7 @@ def generate_video(
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
model_dtype = text_embeddings.dtype # bfloat16 from text encoder
mx.eval(text_embeddings) mx.eval(text_embeddings)
del text_encoder del text_encoder
@@ -292,6 +299,8 @@ def generate_video(
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}") print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig( config = LTXModelConfig(
model_type=LTXModelType.VideoOnly, model_type=LTXModelType.VideoOnly,
@@ -310,7 +319,7 @@ def generate_video(
timestep_scale_multiplier=1000, timestep_scale_multiplier=1000,
) )
transformer = LTXModel(config) transformer = LTXModel(config)
transformer.load_weights(list(sanitized.items()), strict=False) transformer.load_weights(list(sanitized.items()), strict=False)
mx.eval(transformer.parameters()) mx.eval(transformer.parameters())
@@ -323,15 +332,15 @@ def generate_video(
mx.eval(vae_encoder.parameters()) mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution) # Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2) input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor) stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent) mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}") print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution) # Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width) input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor) stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent) mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}") print(f" Stage 2 image latent: {stage2_image_latent.shape}")
@@ -343,6 +352,7 @@ def generate_video(
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed) mx.random.seed(seed)
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
mx.eval(positions) mx.eval(positions)
@@ -353,24 +363,26 @@ def generate_video(
# Create initial state with zeros # Create initial state with zeros
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
state1 = LatentState( state1 = LatentState(
latent=mx.zeros(latent_shape), latent=mx.zeros(latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(latent_shape), clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent, latent=stage1_image_latent,
frame_idx=image_frame_idx, frame_idx=image_frame_idx,
strength=image_strength, strength=image_strength,
) )
state1 = apply_conditioning(state1, [conditioning]) state1 = apply_conditioning(state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 1, noise_scale = 1.0 (first sigma) # For Stage 1, noise_scale = 1.0 (first sigma)
noise = mx.random.normal(latent_shape) noise = mx.random.normal(latent_shape, dtype=model_dtype)
noise_scale = STAGE_1_SIGMAS[0] # 1.0 noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
scaled_mask = state1.denoise_mask * noise_scale scaled_mask = state1.denoise_mask * noise_scale
state1 = LatentState( state1 = LatentState(
latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask), latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state1.clean_latent, clean_latent=state1.clean_latent,
denoise_mask=state1.denoise_mask, denoise_mask=state1.denoise_mask,
) )
@@ -378,7 +390,7 @@ def generate_video(
mx.eval(latents) mx.eval(latents)
else: else:
# T2V: just use random noise # T2V: just use random noise
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype)
mx.eval(latents) mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
@@ -401,6 +413,7 @@ def generate_video(
# Stage 2: Refine at full resolution # Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
# Position grids stay float32 for RoPE precision
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
mx.eval(positions) mx.eval(positions)
@@ -411,7 +424,7 @@ def generate_video(
state2 = LatentState( state2 = LatentState(
latent=latents, # Start with upscaled latent latent=latents, # Start with upscaled latent
clean_latent=mx.zeros_like(latents), clean_latent=mx.zeros_like(latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent, latent=stage2_image_latent,
@@ -423,11 +436,11 @@ def generate_video(
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
# For Stage 2, noise_scale = stage_2_sigmas[0] # For Stage 2, noise_scale = stage_2_sigmas[0]
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise # Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
noise = mx.random.normal(latents.shape) noise = mx.random.normal(latents.shape).astype(model_dtype)
noise_scale = STAGE_2_SIGMAS[0] noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = state2.denoise_mask * noise_scale scaled_mask = state2.denoise_mask * noise_scale
state2 = LatentState( state2 = LatentState(
latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask), latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=state2.clean_latent, clean_latent=state2.clean_latent,
denoise_mask=state2.denoise_mask, denoise_mask=state2.denoise_mask,
) )
@@ -435,9 +448,10 @@ def generate_video(
mx.eval(latents) mx.eval(latents)
else: else:
# T2V: add noise to all frames for refinement # T2V: add noise to all frames for refinement
noise_scale = STAGE_2_SIGMAS[0] noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
noise = mx.random.normal(latents.shape) one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
latents = noise * noise_scale + latents * (1 - noise_scale) noise = mx.random.normal(latents.shape).astype(model_dtype)
latents = noise * noise_scale + latents * one_minus_scale
mx.eval(latents) mx.eval(latents)
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)

View File

@@ -3,7 +3,7 @@
import argparse import argparse
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional, List from typing import Optional
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
@@ -164,6 +164,7 @@ def denoise_av(
Returns: Returns:
Tuple of (video_latents, audio_latents) Tuple of (video_latents, audio_latents)
""" """
dtype = video_latents.dtype
# If video state is provided, use its latent # If video state is provided, use its latent
if video_state is not None: if video_state is not None:
video_latents = video_state.latent video_latents = video_state.latent
@@ -189,10 +190,10 @@ def denoise_av(
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
# Per-token timesteps: sigma * mask # Per-token timesteps: sigma * mask
video_timesteps = sigma * denoise_mask_flat video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
else: else:
# All tokens get the same timestep # All tokens get the same timestep
video_timesteps = mx.full((b, num_video_tokens), sigma) video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
video_modality = Modality( video_modality = Modality(
latent=video_flat, latent=video_flat,
@@ -205,7 +206,7 @@ def denoise_av(
audio_modality = Modality( audio_modality = Modality(
latent=audio_flat, latent=audio_flat,
timesteps=mx.full((ab, at), sigma), timesteps=mx.full((ab, at), sigma, dtype=dtype),
positions=audio_positions, positions=audio_positions,
context=audio_embeddings, context=audio_embeddings,
context_mask=None, context_mask=None,
@@ -230,10 +231,12 @@ def denoise_av(
mx.eval(video_denoised, audio_denoised) mx.eval(video_denoised, audio_denoised)
# Euler step # Euler step - use dtype-preserving arrays to avoid float32 promotion
if sigma_next > 0: if sigma_next > 0:
video_latents = video_denoised + sigma_next * (video_latents - video_denoised) / sigma sigma_next_arr = mx.array(sigma_next, dtype=dtype)
audio_latents = audio_denoised + sigma_next * (audio_latents - audio_denoised) / sigma sigma_arr = mx.array(sigma, dtype=dtype)
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
else: else:
video_latents = video_denoised video_latents = video_denoised
audio_latents = audio_denoised audio_latents = audio_denoised
@@ -435,6 +438,7 @@ def generate_video_with_audio(
# Get both video and audio embeddings # Get both video and audio embeddings
video_embeddings, audio_embeddings = text_encoder(prompt) video_embeddings, audio_embeddings = text_encoder(prompt)
model_dtype = video_embeddings.dtype # bfloat16 from text encoder
mx.eval(video_embeddings, audio_embeddings) mx.eval(video_embeddings, audio_embeddings)
del text_encoder del text_encoder
@@ -445,6 +449,9 @@ def generate_video_with_audio(
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
sanitized = sanitize_transformer_weights(raw_weights) sanitized = sanitize_transformer_weights(raw_weights)
# Convert transformer weights to bfloat16 for memory efficiency
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
config = LTXModelConfig( config = LTXModelConfig(
model_type=LTXModelType.AudioVideo, model_type=LTXModelType.AudioVideo,
num_attention_heads=32, num_attention_heads=32,
@@ -482,18 +489,16 @@ def generate_video_with_audio(
mx.eval(vae_encoder.parameters()) mx.eval(vae_encoder.parameters())
# Load and prepare image for stage 1 (half resolution) # Load and prepare image for stage 1 (half resolution)
input_image = load_image(image, height=height // 2, width=width // 2) input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
stage1_image_latent = vae_encoder(stage1_image_tensor) stage1_image_latent = vae_encoder(stage1_image_tensor)
mx.eval(stage1_image_latent) mx.eval(stage1_image_latent)
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
# Load and prepare image for stage 2 (full resolution) # Load and prepare image for stage 2 (full resolution)
input_image = load_image(image, height=height, width=width) input_image = load_image(image, height=height, width=width, dtype=model_dtype)
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
stage2_image_latent = vae_encoder(stage2_image_tensor) stage2_image_latent = vae_encoder(stage2_image_tensor)
mx.eval(stage2_image_latent) mx.eval(stage2_image_latent)
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
del vae_encoder del vae_encoder
mx.clear_cache() mx.clear_cache()
@@ -502,9 +507,10 @@ def generate_video_with_audio(
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
mx.random.seed(seed) mx.random.seed(seed)
# Create position grids # Create position grids - MUST stay float32 for RoPE precision
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations
audio_positions = create_audio_position_grid(1, audio_frames) video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32
audio_positions = create_audio_position_grid(1, audio_frames) # float32
mx.eval(video_positions, audio_positions) mx.eval(video_positions, audio_positions)
# Apply I2V conditioning for stage 1 if provided # Apply I2V conditioning for stage 1 if provided
@@ -513,9 +519,9 @@ def generate_video_with_audio(
if is_i2v and stage1_image_latent is not None: if is_i2v and stage1_image_latent is not None:
# PyTorch flow: create zeros -> apply conditioning -> apply noiser # PyTorch flow: create zeros -> apply conditioning -> apply noiser
video_state1 = LatentState( video_state1 = LatentState(
latent=mx.zeros(video_latent_shape), latent=mx.zeros(video_latent_shape, dtype=model_dtype),
clean_latent=mx.zeros(video_latent_shape), clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
latent=stage1_image_latent, latent=stage1_image_latent,
@@ -525,11 +531,11 @@ def generate_video_with_audio(
video_state1 = apply_conditioning(video_state1, [conditioning]) video_state1 = apply_conditioning(video_state1, [conditioning])
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
noise = mx.random.normal(video_latent_shape) noise = mx.random.normal(video_latent_shape).astype(model_dtype)
noise_scale = STAGE_1_SIGMAS[0] # 1.0 noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
scaled_mask = video_state1.denoise_mask * noise_scale scaled_mask = video_state1.denoise_mask * noise_scale
video_state1 = LatentState( video_state1 = LatentState(
latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask), latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=video_state1.clean_latent, clean_latent=video_state1.clean_latent,
denoise_mask=video_state1.denoise_mask, denoise_mask=video_state1.denoise_mask,
) )
@@ -537,11 +543,11 @@ def generate_video_with_audio(
mx.eval(video_latents) mx.eval(video_latents)
else: else:
# T2V: just use random noise # T2V: just use random noise
video_latents = mx.random.normal(video_latent_shape) video_latents = mx.random.normal(video_latent_shape).astype(model_dtype)
mx.eval(video_latents) mx.eval(video_latents)
# Audio always uses pure noise (no I2V for audio) # Audio always uses pure noise (no I2V for audio)
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)) audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
mx.eval(audio_latents) mx.eval(audio_latents)
# Stage 1 denoising # Stage 1 denoising
@@ -571,7 +577,8 @@ def generate_video_with_audio(
# Stage 2: Refine at full resolution # Stage 2: Refine at full resolution
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # Position grids stay float32 for RoPE precision
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32
mx.eval(video_positions) mx.eval(video_positions)
# Apply I2V conditioning for stage 2 if provided # Apply I2V conditioning for stage 2 if provided
@@ -581,7 +588,7 @@ def generate_video_with_audio(
video_state2 = LatentState( video_state2 = LatentState(
latent=video_latents, # Start with upscaled latent latent=video_latents, # Start with upscaled latent
clean_latent=mx.zeros_like(video_latents), clean_latent=mx.zeros_like(video_latents),
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
) )
conditioning = VideoConditionByLatentIndex( conditioning = VideoConditionByLatentIndex(
latent=stage2_image_latent, latent=stage2_image_latent,
@@ -591,11 +598,11 @@ def generate_video_with_audio(
video_state2 = apply_conditioning(video_state2, [conditioning]) video_state2 = apply_conditioning(video_state2, [conditioning])
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise # Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
video_noise = mx.random.normal(video_latents.shape) video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
noise_scale = STAGE_2_SIGMAS[0] noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
scaled_mask = video_state2.denoise_mask * noise_scale scaled_mask = video_state2.denoise_mask * noise_scale
video_state2 = LatentState( video_state2 = LatentState(
latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask), latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
clean_latent=video_state2.clean_latent, clean_latent=video_state2.clean_latent,
denoise_mask=video_state2.denoise_mask, denoise_mask=video_state2.denoise_mask,
) )
@@ -603,16 +610,18 @@ def generate_video_with_audio(
mx.eval(video_latents) mx.eval(video_latents)
# Audio still gets noise (no I2V for audio) # Audio still gets noise (no I2V for audio)
audio_noise = mx.random.normal(audio_latents.shape) audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(audio_latents) mx.eval(audio_latents)
else: else:
# T2V: add noise to all frames for refinement # T2V: add noise to all frames for refinement
noise_scale = STAGE_2_SIGMAS[0] noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
video_noise = mx.random.normal(video_latents.shape) one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
audio_noise = mx.random.normal(audio_latents.shape) video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale) audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) video_latents = video_noise * noise_scale + video_latents * one_minus_scale
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
mx.eval(video_latents, audio_latents) mx.eval(video_latents, audio_latents)
video_latents, audio_latents = denoise_av( video_latents, audio_latents = denoise_av(
@@ -671,27 +680,13 @@ def generate_video_with_audio(
vocoder = load_vocoder(model_path) vocoder = load_vocoder(model_path)
mx.eval(audio_decoder.parameters(), vocoder.parameters()) mx.eval(audio_decoder.parameters(), vocoder.parameters())
# Debug: check per-channel statistics are loaded
pcs = audio_decoder.per_channel_statistics
print(f"Per-channel stats: mean_of_means range=[{pcs._mean_of_means.min():.4f}, {pcs._mean_of_means.max():.4f}], std_of_means range=[{pcs._std_of_means.min():.4f}, {pcs._std_of_means.max():.4f}]")
# Debug: check audio latent statistics
print(f"Audio latents shape: {audio_latents.shape}")
print(f"Audio latents stats: min={audio_latents.min():.4f}, max={audio_latents.max():.4f}, mean={audio_latents.mean():.4f}, std={mx.std(audio_latents):.4f}")
mel_spectrogram = audio_decoder(audio_latents) mel_spectrogram = audio_decoder(audio_latents)
mx.eval(mel_spectrogram) mx.eval(mel_spectrogram)
print(f"Mel spectrogram shape: {mel_spectrogram.shape}")
print(f"Mel spectrogram stats: min={mel_spectrogram.min():.4f}, max={mel_spectrogram.max():.4f}, mean={mel_spectrogram.mean():.4f}")
# Audio decoder output is already in vocoder format (B, C, T, F) # Audio decoder output is already in vocoder format (B, C, T, F)
audio_waveform = vocoder(mel_spectrogram) audio_waveform = vocoder(mel_spectrogram)
mx.eval(audio_waveform) mx.eval(audio_waveform)
print(f"Audio waveform shape: {audio_waveform.shape}")
print(f"Audio waveform stats: min={audio_waveform.min():.4f}, max={audio_waveform.max():.4f}, mean={audio_waveform.mean():.4f}")
audio_np = np.array(audio_waveform) audio_np = np.array(audio_waveform)
if audio_np.ndim == 3: if audio_np.ndim == 3:
audio_np = audio_np[0] # Remove batch dim audio_np = audio_np[0] # Remove batch dim

View File

@@ -171,6 +171,7 @@ def load_image(
image_path: Union[str, Path], image_path: Union[str, Path],
height: Optional[int] = None, height: Optional[int] = None,
width: Optional[int] = None, width: Optional[int] = None,
dtype: mx.Dtype = mx.float32,
) -> mx.array: ) -> mx.array:
"""Load and preprocess an image for I2V conditioning. """Load and preprocess an image for I2V conditioning.
@@ -210,7 +211,7 @@ def load_image(
# Convert to numpy then MLX # Convert to numpy then MLX
image_np = np.array(image).astype(np.float32) / 255.0 image_np = np.array(image).astype(np.float32) / 255.0
return mx.array(image_np) return mx.array(image_np, dtype=dtype)
def resize_image_aspect_ratio( def resize_image_aspect_ratio(