Cast dtype to bf16 in video and audio generation processes
This commit is contained in:
@@ -1,9 +1,10 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -110,6 +111,7 @@ def create_position_grid(
|
|||||||
# Convert temporal to time in seconds by dividing by fps
|
# Convert temporal to time in seconds by dividing by fps
|
||||||
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
||||||
|
|
||||||
|
# Always return float32 for RoPE precision - bfloat16 causes quality degradation
|
||||||
return mx.array(pixel_coords, dtype=mx.float32)
|
return mx.array(pixel_coords, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
@@ -137,6 +139,7 @@ def denoise(
|
|||||||
Denoised latent tensor
|
Denoised latent tensor
|
||||||
"""
|
"""
|
||||||
# If state is provided, use its latent (which may have conditioning applied)
|
# If state is provided, use its latent (which may have conditioning applied)
|
||||||
|
dtype = latents.dtype
|
||||||
if state is not None:
|
if state is not None:
|
||||||
latents = state.latent
|
latents = state.latent
|
||||||
|
|
||||||
@@ -154,11 +157,11 @@ def denoise(
|
|||||||
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
||||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
|
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens))
|
||||||
# Per-token timesteps: sigma * mask
|
# Per-token timesteps: sigma * mask (preserve dtype)
|
||||||
timesteps = sigma * denoise_mask_flat
|
timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||||
else:
|
else:
|
||||||
# All tokens get the same timestep
|
# All tokens get the same timestep (use latent dtype)
|
||||||
timesteps = mx.full((b, num_tokens), sigma)
|
timesteps = mx.full((b, num_tokens), sigma, dtype=dtype)
|
||||||
|
|
||||||
video_modality = Modality(
|
video_modality = Modality(
|
||||||
latent=latents_flat,
|
latent=latents_flat,
|
||||||
@@ -181,8 +184,11 @@ def denoise(
|
|||||||
|
|
||||||
mx.eval(denoised)
|
mx.eval(denoised)
|
||||||
|
|
||||||
|
# Euler step (preserve dtype by converting Python floats to arrays)
|
||||||
if sigma_next > 0:
|
if sigma_next > 0:
|
||||||
latents = denoised + sigma_next * (latents - denoised) / sigma
|
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||||
|
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||||
|
latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr
|
||||||
else:
|
else:
|
||||||
latents = denoised
|
latents = denoised
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
@@ -283,6 +289,7 @@ def generate_video(
|
|||||||
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}")
|
||||||
|
|
||||||
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
|
text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False)
|
||||||
|
model_dtype = text_embeddings.dtype # bfloat16 from text encoder
|
||||||
mx.eval(text_embeddings)
|
mx.eval(text_embeddings)
|
||||||
|
|
||||||
del text_encoder
|
del text_encoder
|
||||||
@@ -292,6 +299,8 @@ def generate_video(
|
|||||||
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
|
print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}")
|
||||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||||
sanitized = sanitize_transformer_weights(raw_weights)
|
sanitized = sanitize_transformer_weights(raw_weights)
|
||||||
|
# Convert transformer weights to bfloat16 for memory efficiency
|
||||||
|
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||||
|
|
||||||
config = LTXModelConfig(
|
config = LTXModelConfig(
|
||||||
model_type=LTXModelType.VideoOnly,
|
model_type=LTXModelType.VideoOnly,
|
||||||
@@ -310,7 +319,7 @@ def generate_video(
|
|||||||
timestep_scale_multiplier=1000,
|
timestep_scale_multiplier=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
transformer = LTXModel(config)
|
transformer = LTXModel(config)
|
||||||
transformer.load_weights(list(sanitized.items()), strict=False)
|
transformer.load_weights(list(sanitized.items()), strict=False)
|
||||||
mx.eval(transformer.parameters())
|
mx.eval(transformer.parameters())
|
||||||
|
|
||||||
@@ -323,15 +332,15 @@ def generate_video(
|
|||||||
mx.eval(vae_encoder.parameters())
|
mx.eval(vae_encoder.parameters())
|
||||||
|
|
||||||
# Load and prepare image for stage 1 (half resolution)
|
# Load and prepare image for stage 1 (half resolution)
|
||||||
input_image = load_image(image, height=height // 2, width=width // 2)
|
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
|
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||||
mx.eval(stage1_image_latent)
|
mx.eval(stage1_image_latent)
|
||||||
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
||||||
|
|
||||||
# Load and prepare image for stage 2 (full resolution)
|
# Load and prepare image for stage 2 (full resolution)
|
||||||
input_image = load_image(image, height=height, width=width)
|
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
|
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||||
mx.eval(stage2_image_latent)
|
mx.eval(stage2_image_latent)
|
||||||
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
||||||
@@ -343,6 +352,7 @@ def generate_video(
|
|||||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
|
# Position grids stay float32 for RoPE precision
|
||||||
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
positions = create_position_grid(1, latent_frames, stage1_h, stage1_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
@@ -353,24 +363,26 @@ def generate_video(
|
|||||||
# Create initial state with zeros
|
# Create initial state with zeros
|
||||||
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
latent_shape = (1, 128, latent_frames, stage1_h, stage1_w)
|
||||||
state1 = LatentState(
|
state1 = LatentState(
|
||||||
latent=mx.zeros(latent_shape),
|
latent=mx.zeros(latent_shape, dtype=model_dtype),
|
||||||
clean_latent=mx.zeros(latent_shape),
|
clean_latent=mx.zeros(latent_shape, dtype=model_dtype),
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||||
)
|
)
|
||||||
conditioning = VideoConditionByLatentIndex(
|
conditioning = VideoConditionByLatentIndex(
|
||||||
latent=stage1_image_latent,
|
latent=stage1_image_latent,
|
||||||
frame_idx=image_frame_idx,
|
frame_idx=image_frame_idx,
|
||||||
strength=image_strength,
|
strength=image_strength,
|
||||||
)
|
)
|
||||||
|
|
||||||
state1 = apply_conditioning(state1, [conditioning])
|
state1 = apply_conditioning(state1, [conditioning])
|
||||||
|
|
||||||
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||||
# For Stage 1, noise_scale = 1.0 (first sigma)
|
# For Stage 1, noise_scale = 1.0 (first sigma)
|
||||||
noise = mx.random.normal(latent_shape)
|
noise = mx.random.normal(latent_shape, dtype=model_dtype)
|
||||||
noise_scale = STAGE_1_SIGMAS[0] # 1.0
|
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
|
||||||
scaled_mask = state1.denoise_mask * noise_scale
|
scaled_mask = state1.denoise_mask * noise_scale
|
||||||
|
|
||||||
state1 = LatentState(
|
state1 = LatentState(
|
||||||
latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask),
|
latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||||
clean_latent=state1.clean_latent,
|
clean_latent=state1.clean_latent,
|
||||||
denoise_mask=state1.denoise_mask,
|
denoise_mask=state1.denoise_mask,
|
||||||
)
|
)
|
||||||
@@ -378,7 +390,7 @@ def generate_video(
|
|||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
else:
|
else:
|
||||||
# T2V: just use random noise
|
# T2V: just use random noise
|
||||||
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w))
|
latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype)
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1)
|
||||||
@@ -401,6 +413,7 @@ def generate_video(
|
|||||||
|
|
||||||
# Stage 2: Refine at full resolution
|
# Stage 2: Refine at full resolution
|
||||||
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
||||||
|
# Position grids stay float32 for RoPE precision
|
||||||
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
positions = create_position_grid(1, latent_frames, stage2_h, stage2_w)
|
||||||
mx.eval(positions)
|
mx.eval(positions)
|
||||||
|
|
||||||
@@ -411,7 +424,7 @@ def generate_video(
|
|||||||
state2 = LatentState(
|
state2 = LatentState(
|
||||||
latent=latents, # Start with upscaled latent
|
latent=latents, # Start with upscaled latent
|
||||||
clean_latent=mx.zeros_like(latents),
|
clean_latent=mx.zeros_like(latents),
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||||
)
|
)
|
||||||
conditioning = VideoConditionByLatentIndex(
|
conditioning = VideoConditionByLatentIndex(
|
||||||
latent=stage2_image_latent,
|
latent=stage2_image_latent,
|
||||||
@@ -423,11 +436,11 @@ def generate_video(
|
|||||||
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||||
# For Stage 2, noise_scale = stage_2_sigmas[0]
|
# For Stage 2, noise_scale = stage_2_sigmas[0]
|
||||||
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
# Conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
||||||
noise = mx.random.normal(latents.shape)
|
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
||||||
noise_scale = STAGE_2_SIGMAS[0]
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
scaled_mask = state2.denoise_mask * noise_scale
|
scaled_mask = state2.denoise_mask * noise_scale
|
||||||
state2 = LatentState(
|
state2 = LatentState(
|
||||||
latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask),
|
latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||||
clean_latent=state2.clean_latent,
|
clean_latent=state2.clean_latent,
|
||||||
denoise_mask=state2.denoise_mask,
|
denoise_mask=state2.denoise_mask,
|
||||||
)
|
)
|
||||||
@@ -435,9 +448,10 @@ def generate_video(
|
|||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
else:
|
else:
|
||||||
# T2V: add noise to all frames for refinement
|
# T2V: add noise to all frames for refinement
|
||||||
noise_scale = STAGE_2_SIGMAS[0]
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
noise = mx.random.normal(latents.shape)
|
one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
latents = noise * noise_scale + latents * (1 - noise_scale)
|
noise = mx.random.normal(latents.shape).astype(model_dtype)
|
||||||
|
latents = noise * noise_scale + latents * one_minus_scale
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
|
|
||||||
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
|
latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List
|
from typing import Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -164,6 +164,7 @@ def denoise_av(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (video_latents, audio_latents)
|
Tuple of (video_latents, audio_latents)
|
||||||
"""
|
"""
|
||||||
|
dtype = video_latents.dtype
|
||||||
# If video state is provided, use its latent
|
# If video state is provided, use its latent
|
||||||
if video_state is not None:
|
if video_state is not None:
|
||||||
video_latents = video_state.latent
|
video_latents = video_state.latent
|
||||||
@@ -189,10 +190,10 @@ def denoise_av(
|
|||||||
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w))
|
||||||
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens))
|
||||||
# Per-token timesteps: sigma * mask
|
# Per-token timesteps: sigma * mask
|
||||||
video_timesteps = sigma * denoise_mask_flat
|
video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat
|
||||||
else:
|
else:
|
||||||
# All tokens get the same timestep
|
# All tokens get the same timestep
|
||||||
video_timesteps = mx.full((b, num_video_tokens), sigma)
|
video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype)
|
||||||
|
|
||||||
video_modality = Modality(
|
video_modality = Modality(
|
||||||
latent=video_flat,
|
latent=video_flat,
|
||||||
@@ -205,7 +206,7 @@ def denoise_av(
|
|||||||
|
|
||||||
audio_modality = Modality(
|
audio_modality = Modality(
|
||||||
latent=audio_flat,
|
latent=audio_flat,
|
||||||
timesteps=mx.full((ab, at), sigma),
|
timesteps=mx.full((ab, at), sigma, dtype=dtype),
|
||||||
positions=audio_positions,
|
positions=audio_positions,
|
||||||
context=audio_embeddings,
|
context=audio_embeddings,
|
||||||
context_mask=None,
|
context_mask=None,
|
||||||
@@ -230,10 +231,12 @@ def denoise_av(
|
|||||||
|
|
||||||
mx.eval(video_denoised, audio_denoised)
|
mx.eval(video_denoised, audio_denoised)
|
||||||
|
|
||||||
# Euler step
|
# Euler step - use dtype-preserving arrays to avoid float32 promotion
|
||||||
if sigma_next > 0:
|
if sigma_next > 0:
|
||||||
video_latents = video_denoised + sigma_next * (video_latents - video_denoised) / sigma
|
sigma_next_arr = mx.array(sigma_next, dtype=dtype)
|
||||||
audio_latents = audio_denoised + sigma_next * (audio_latents - audio_denoised) / sigma
|
sigma_arr = mx.array(sigma, dtype=dtype)
|
||||||
|
video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr
|
||||||
|
audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr
|
||||||
else:
|
else:
|
||||||
video_latents = video_denoised
|
video_latents = video_denoised
|
||||||
audio_latents = audio_denoised
|
audio_latents = audio_denoised
|
||||||
@@ -435,6 +438,7 @@ def generate_video_with_audio(
|
|||||||
|
|
||||||
# Get both video and audio embeddings
|
# Get both video and audio embeddings
|
||||||
video_embeddings, audio_embeddings = text_encoder(prompt)
|
video_embeddings, audio_embeddings = text_encoder(prompt)
|
||||||
|
model_dtype = video_embeddings.dtype # bfloat16 from text encoder
|
||||||
mx.eval(video_embeddings, audio_embeddings)
|
mx.eval(video_embeddings, audio_embeddings)
|
||||||
|
|
||||||
del text_encoder
|
del text_encoder
|
||||||
@@ -445,6 +449,9 @@ def generate_video_with_audio(
|
|||||||
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors'))
|
||||||
sanitized = sanitize_transformer_weights(raw_weights)
|
sanitized = sanitize_transformer_weights(raw_weights)
|
||||||
|
|
||||||
|
# Convert transformer weights to bfloat16 for memory efficiency
|
||||||
|
sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()}
|
||||||
|
|
||||||
config = LTXModelConfig(
|
config = LTXModelConfig(
|
||||||
model_type=LTXModelType.AudioVideo,
|
model_type=LTXModelType.AudioVideo,
|
||||||
num_attention_heads=32,
|
num_attention_heads=32,
|
||||||
@@ -482,18 +489,16 @@ def generate_video_with_audio(
|
|||||||
mx.eval(vae_encoder.parameters())
|
mx.eval(vae_encoder.parameters())
|
||||||
|
|
||||||
# Load and prepare image for stage 1 (half resolution)
|
# Load and prepare image for stage 1 (half resolution)
|
||||||
input_image = load_image(image, height=height // 2, width=width // 2)
|
input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype)
|
||||||
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2)
|
stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype)
|
||||||
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
stage1_image_latent = vae_encoder(stage1_image_tensor)
|
||||||
mx.eval(stage1_image_latent)
|
mx.eval(stage1_image_latent)
|
||||||
print(f" Stage 1 image latent: {stage1_image_latent.shape}")
|
|
||||||
|
|
||||||
# Load and prepare image for stage 2 (full resolution)
|
# Load and prepare image for stage 2 (full resolution)
|
||||||
input_image = load_image(image, height=height, width=width)
|
input_image = load_image(image, height=height, width=width, dtype=model_dtype)
|
||||||
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width)
|
stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype)
|
||||||
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
stage2_image_latent = vae_encoder(stage2_image_tensor)
|
||||||
mx.eval(stage2_image_latent)
|
mx.eval(stage2_image_latent)
|
||||||
print(f" Stage 2 image latent: {stage2_image_latent.shape}")
|
|
||||||
|
|
||||||
del vae_encoder
|
del vae_encoder
|
||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
@@ -502,9 +507,10 @@ def generate_video_with_audio(
|
|||||||
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}")
|
||||||
mx.random.seed(seed)
|
mx.random.seed(seed)
|
||||||
|
|
||||||
# Create position grids
|
# Create position grids - MUST stay float32 for RoPE precision
|
||||||
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w)
|
# bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations
|
||||||
audio_positions = create_audio_position_grid(1, audio_frames)
|
video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32
|
||||||
|
audio_positions = create_audio_position_grid(1, audio_frames) # float32
|
||||||
mx.eval(video_positions, audio_positions)
|
mx.eval(video_positions, audio_positions)
|
||||||
|
|
||||||
# Apply I2V conditioning for stage 1 if provided
|
# Apply I2V conditioning for stage 1 if provided
|
||||||
@@ -513,9 +519,9 @@ def generate_video_with_audio(
|
|||||||
if is_i2v and stage1_image_latent is not None:
|
if is_i2v and stage1_image_latent is not None:
|
||||||
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
|
# PyTorch flow: create zeros -> apply conditioning -> apply noiser
|
||||||
video_state1 = LatentState(
|
video_state1 = LatentState(
|
||||||
latent=mx.zeros(video_latent_shape),
|
latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
||||||
clean_latent=mx.zeros(video_latent_shape),
|
clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype),
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||||
)
|
)
|
||||||
conditioning = VideoConditionByLatentIndex(
|
conditioning = VideoConditionByLatentIndex(
|
||||||
latent=stage1_image_latent,
|
latent=stage1_image_latent,
|
||||||
@@ -525,11 +531,11 @@ def generate_video_with_audio(
|
|||||||
video_state1 = apply_conditioning(video_state1, [conditioning])
|
video_state1 = apply_conditioning(video_state1, [conditioning])
|
||||||
|
|
||||||
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
# Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale)
|
||||||
noise = mx.random.normal(video_latent_shape)
|
noise = mx.random.normal(video_latent_shape).astype(model_dtype)
|
||||||
noise_scale = STAGE_1_SIGMAS[0] # 1.0
|
noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0
|
||||||
scaled_mask = video_state1.denoise_mask * noise_scale
|
scaled_mask = video_state1.denoise_mask * noise_scale
|
||||||
video_state1 = LatentState(
|
video_state1 = LatentState(
|
||||||
latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask),
|
latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||||
clean_latent=video_state1.clean_latent,
|
clean_latent=video_state1.clean_latent,
|
||||||
denoise_mask=video_state1.denoise_mask,
|
denoise_mask=video_state1.denoise_mask,
|
||||||
)
|
)
|
||||||
@@ -537,11 +543,11 @@ def generate_video_with_audio(
|
|||||||
mx.eval(video_latents)
|
mx.eval(video_latents)
|
||||||
else:
|
else:
|
||||||
# T2V: just use random noise
|
# T2V: just use random noise
|
||||||
video_latents = mx.random.normal(video_latent_shape)
|
video_latents = mx.random.normal(video_latent_shape).astype(model_dtype)
|
||||||
mx.eval(video_latents)
|
mx.eval(video_latents)
|
||||||
|
|
||||||
# Audio always uses pure noise (no I2V for audio)
|
# Audio always uses pure noise (no I2V for audio)
|
||||||
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS))
|
audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype)
|
||||||
mx.eval(audio_latents)
|
mx.eval(audio_latents)
|
||||||
|
|
||||||
# Stage 1 denoising
|
# Stage 1 denoising
|
||||||
@@ -571,7 +577,8 @@ def generate_video_with_audio(
|
|||||||
|
|
||||||
# Stage 2: Refine at full resolution
|
# Stage 2: Refine at full resolution
|
||||||
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}")
|
||||||
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w)
|
# Position grids stay float32 for RoPE precision
|
||||||
|
video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32
|
||||||
mx.eval(video_positions)
|
mx.eval(video_positions)
|
||||||
|
|
||||||
# Apply I2V conditioning for stage 2 if provided
|
# Apply I2V conditioning for stage 2 if provided
|
||||||
@@ -581,7 +588,7 @@ def generate_video_with_audio(
|
|||||||
video_state2 = LatentState(
|
video_state2 = LatentState(
|
||||||
latent=video_latents, # Start with upscaled latent
|
latent=video_latents, # Start with upscaled latent
|
||||||
clean_latent=mx.zeros_like(video_latents),
|
clean_latent=mx.zeros_like(video_latents),
|
||||||
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)),
|
denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype),
|
||||||
)
|
)
|
||||||
conditioning = VideoConditionByLatentIndex(
|
conditioning = VideoConditionByLatentIndex(
|
||||||
latent=stage2_image_latent,
|
latent=stage2_image_latent,
|
||||||
@@ -591,11 +598,11 @@ def generate_video_with_audio(
|
|||||||
video_state2 = apply_conditioning(video_state2, [conditioning])
|
video_state2 = apply_conditioning(video_state2, [conditioning])
|
||||||
|
|
||||||
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
# Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise
|
||||||
video_noise = mx.random.normal(video_latents.shape)
|
video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
|
||||||
noise_scale = STAGE_2_SIGMAS[0]
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
scaled_mask = video_state2.denoise_mask * noise_scale
|
scaled_mask = video_state2.denoise_mask * noise_scale
|
||||||
video_state2 = LatentState(
|
video_state2 = LatentState(
|
||||||
latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask),
|
latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask),
|
||||||
clean_latent=video_state2.clean_latent,
|
clean_latent=video_state2.clean_latent,
|
||||||
denoise_mask=video_state2.denoise_mask,
|
denoise_mask=video_state2.denoise_mask,
|
||||||
)
|
)
|
||||||
@@ -603,16 +610,18 @@ def generate_video_with_audio(
|
|||||||
mx.eval(video_latents)
|
mx.eval(video_latents)
|
||||||
|
|
||||||
# Audio still gets noise (no I2V for audio)
|
# Audio still gets noise (no I2V for audio)
|
||||||
audio_noise = mx.random.normal(audio_latents.shape)
|
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
||||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
|
||||||
|
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
||||||
mx.eval(audio_latents)
|
mx.eval(audio_latents)
|
||||||
else:
|
else:
|
||||||
# T2V: add noise to all frames for refinement
|
# T2V: add noise to all frames for refinement
|
||||||
noise_scale = STAGE_2_SIGMAS[0]
|
noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype)
|
||||||
video_noise = mx.random.normal(video_latents.shape)
|
one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale
|
||||||
audio_noise = mx.random.normal(audio_latents.shape)
|
video_noise = mx.random.normal(video_latents.shape).astype(model_dtype)
|
||||||
video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale)
|
audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype)
|
||||||
audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale)
|
video_latents = video_noise * noise_scale + video_latents * one_minus_scale
|
||||||
|
audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale
|
||||||
mx.eval(video_latents, audio_latents)
|
mx.eval(video_latents, audio_latents)
|
||||||
|
|
||||||
video_latents, audio_latents = denoise_av(
|
video_latents, audio_latents = denoise_av(
|
||||||
@@ -671,27 +680,13 @@ def generate_video_with_audio(
|
|||||||
vocoder = load_vocoder(model_path)
|
vocoder = load_vocoder(model_path)
|
||||||
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
mx.eval(audio_decoder.parameters(), vocoder.parameters())
|
||||||
|
|
||||||
# Debug: check per-channel statistics are loaded
|
|
||||||
pcs = audio_decoder.per_channel_statistics
|
|
||||||
print(f"Per-channel stats: mean_of_means range=[{pcs._mean_of_means.min():.4f}, {pcs._mean_of_means.max():.4f}], std_of_means range=[{pcs._std_of_means.min():.4f}, {pcs._std_of_means.max():.4f}]")
|
|
||||||
|
|
||||||
# Debug: check audio latent statistics
|
|
||||||
print(f"Audio latents shape: {audio_latents.shape}")
|
|
||||||
print(f"Audio latents stats: min={audio_latents.min():.4f}, max={audio_latents.max():.4f}, mean={audio_latents.mean():.4f}, std={mx.std(audio_latents):.4f}")
|
|
||||||
|
|
||||||
mel_spectrogram = audio_decoder(audio_latents)
|
mel_spectrogram = audio_decoder(audio_latents)
|
||||||
mx.eval(mel_spectrogram)
|
mx.eval(mel_spectrogram)
|
||||||
|
|
||||||
print(f"Mel spectrogram shape: {mel_spectrogram.shape}")
|
|
||||||
print(f"Mel spectrogram stats: min={mel_spectrogram.min():.4f}, max={mel_spectrogram.max():.4f}, mean={mel_spectrogram.mean():.4f}")
|
|
||||||
|
|
||||||
# Audio decoder output is already in vocoder format (B, C, T, F)
|
# Audio decoder output is already in vocoder format (B, C, T, F)
|
||||||
audio_waveform = vocoder(mel_spectrogram)
|
audio_waveform = vocoder(mel_spectrogram)
|
||||||
mx.eval(audio_waveform)
|
mx.eval(audio_waveform)
|
||||||
|
|
||||||
print(f"Audio waveform shape: {audio_waveform.shape}")
|
|
||||||
print(f"Audio waveform stats: min={audio_waveform.min():.4f}, max={audio_waveform.max():.4f}, mean={audio_waveform.mean():.4f}")
|
|
||||||
|
|
||||||
audio_np = np.array(audio_waveform)
|
audio_np = np.array(audio_waveform)
|
||||||
if audio_np.ndim == 3:
|
if audio_np.ndim == 3:
|
||||||
audio_np = audio_np[0] # Remove batch dim
|
audio_np = audio_np[0] # Remove batch dim
|
||||||
|
|||||||
@@ -171,6 +171,7 @@ def load_image(
|
|||||||
image_path: Union[str, Path],
|
image_path: Union[str, Path],
|
||||||
height: Optional[int] = None,
|
height: Optional[int] = None,
|
||||||
width: Optional[int] = None,
|
width: Optional[int] = None,
|
||||||
|
dtype: mx.Dtype = mx.float32,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Load and preprocess an image for I2V conditioning.
|
"""Load and preprocess an image for I2V conditioning.
|
||||||
|
|
||||||
@@ -210,7 +211,7 @@ def load_image(
|
|||||||
|
|
||||||
# Convert to numpy then MLX
|
# Convert to numpy then MLX
|
||||||
image_np = np.array(image).astype(np.float32) / 255.0
|
image_np = np.array(image).astype(np.float32) / 255.0
|
||||||
return mx.array(image_np)
|
return mx.array(image_np, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def resize_image_aspect_ratio(
|
def resize_image_aspect_ratio(
|
||||||
|
|||||||
Reference in New Issue
Block a user