Enhance precision in denoising functions by ensuring all latents and calculations are consistently handled in float32. Update model input casting and return types to maintain dtype integrity across audio and video processing. Add precision parameter to video generation for improved memory management.

This commit is contained in:
Prince Canuma
2026-01-24 15:40:42 +01:00
parent cb2d19c84d
commit 87962c7f83

View File

@@ -249,6 +249,11 @@ def denoise_distilled(
if state is not None: if state is not None:
latents = state.latent latents = state.latent
# Keep latents in float32 throughout to avoid quantization noise accumulation.
latents = latents.astype(mx.float32)
if enable_audio:
audio_latents = audio_latents.astype(mx.float32)
desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]" desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]"
num_steps = len(sigmas) - 1 num_steps = len(sigmas) - 1
@@ -268,7 +273,8 @@ def denoise_distilled(
b, c, f, h, w = latents.shape b, c, f, h, w = latents.shape
num_tokens = f * h * w num_tokens = f * h * w
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) # Cast to model dtype for transformer input
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
if state is not None: if state is not None:
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
@@ -291,7 +297,7 @@ def denoise_distilled(
if enable_audio: if enable_audio:
ab, ac, at, af = audio_latents.shape ab, ac, at, af = audio_latents.shape
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3)) audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)) audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
audio_modality = Modality( audio_modality = Modality(
latent=audio_flat, latent=audio_flat,
@@ -307,34 +313,36 @@ def denoise_distilled(
if audio_velocity is not None: if audio_velocity is not None:
mx.eval(audio_velocity) mx.eval(audio_velocity)
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w)) # Compute denoised (x0) using per-token timesteps in float32
denoised = to_denoised(latents, velocity, sigma) # x0 = latent - timestep * velocity
# For conditioned tokens (timestep=0): x0 = latent
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
sigma_f32 = mx.array(sigma, dtype=mx.float32)
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
x0_f32 = latents_flat_f32 - timesteps_f32 * velocity.astype(mx.float32)
denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w))
audio_denoised = None audio_denoised = None
if enable_audio and audio_velocity is not None: if enable_audio and audio_velocity is not None:
ab, ac, at, af = audio_latents.shape ab, ac, at, af = audio_latents.shape
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af)) audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3)) audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma) audio_denoised = audio_latents - sigma_f32 * audio_velocity.astype(mx.float32)
if state is not None: if state is not None:
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask)
mx.eval(denoised) mx.eval(denoised)
if audio_denoised is not None: if audio_denoised is not None:
mx.eval(audio_denoised) mx.eval(audio_denoised)
# Euler step in float32 (latents stay in float32)
if sigma_next > 0: if sigma_next > 0:
# Compute Euler step in float32 for precision (matching PyTorch behavior)
latents_f32 = latents.astype(mx.float32)
denoised_f32 = denoised.astype(mx.float32)
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
sigma_f32 = mx.array(sigma, dtype=mx.float32) latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
if enable_audio and audio_denoised is not None: if enable_audio and audio_denoised is not None:
audio_latents_f32 = audio_latents.astype(mx.float32) audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32
audio_denoised_f32 = audio_denoised.astype(mx.float32)
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
else: else:
latents = denoised latents = denoised
if enable_audio and audio_denoised is not None: if enable_audio and audio_denoised is not None:
@@ -346,7 +354,7 @@ def denoise_distilled(
progress.advance(task) progress.advance(task)
return latents, audio_latents if enable_audio else None return latents.astype(dtype), audio_latents.astype(dtype) if enable_audio else None
# ============================================================================= # =============================================================================
@@ -371,6 +379,11 @@ def denoise_dev(
if state is not None: if state is not None:
latents = state.latent latents = state.latent
# Keep latents in float32 throughout the denoising loop to avoid
# quantization noise accumulation over many steps.
# Model input is cast to model dtype; all denoising math stays in float32.
latents = latents.astype(mx.float32)
sigmas_list = sigmas.tolist() sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0 use_cfg = cfg_scale != 1.0
num_steps = len(sigmas_list) - 1 num_steps = len(sigmas_list) - 1
@@ -405,7 +418,8 @@ def denoise_dev(
b, c, f, h, w = latents.shape b, c, f, h, w = latents.shape
num_tokens = f * h * w num_tokens = f * h * w
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)) # Cast to model dtype for transformer input
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
if state is not None: if state is not None:
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
@@ -427,6 +441,14 @@ def denoise_dev(
) )
velocity_pos, _ = transformer(video=video_modality_pos, audio=None) velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
# Convert velocity to x0 (denoised) using per-token timesteps
# Matches PyTorch's X0Model: x0 = latent - timestep * velocity
# For conditioned tokens (timestep=0): x0 = latent (correct regardless of velocity)
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32)
if use_cfg: if use_cfg:
# Negative conditioning pass # Negative conditioning pass
video_modality_neg = Modality( video_modality_neg = Modality(
@@ -440,31 +462,34 @@ def denoise_dev(
) )
velocity_neg, _ = transformer(video=video_modality_neg, audio=None) velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
# Apply CFG # Convert negative velocity to x0 using per-token timesteps
velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg) x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32)
else:
velocity_flat = velocity_pos
velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w)) # Apply CFG to x0 predictions (matches PyTorch CFGGuider)
denoised = to_denoised(latents, velocity, sigma) # For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
else:
x0_guided_f32 = x0_pos_f32
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
sigma_f32 = mx.array(sigma, dtype=mx.float32)
if state is not None: if state is not None:
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask) denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask)
# Euler step in float32 (latents stay in float32)
if sigma_next > 0: if sigma_next > 0:
# Compute Euler step in float32 for precision (matching PyTorch behavior)
latents_f32 = latents.astype(mx.float32)
denoised_f32 = denoised.astype(mx.float32)
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32) sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
sigma_f32 = mx.array(sigma, dtype=mx.float32) latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
else: else:
latents = denoised latents = denoised
mx.eval(latents) mx.eval(latents)
progress.advance(task) progress.advance(task)
return latents return latents.astype(dtype)
def denoise_dev_av( def denoise_dev_av(
@@ -1055,9 +1080,8 @@ def generate_video(
mx.clear_cache() mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded") console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Generate sigma schedule # Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation)
# PyTorch LTX-2 does NOT pass the latent to the scheduler, so it uses num_tokens = latent_frames * latent_h * latent_w
# the default MAX_SHIFT_ANCHOR (4096) for the shift calculation
sigmas = ltx2_scheduler(steps=num_inference_steps) sigmas = ltx2_scheduler(steps=num_inference_steps)
mx.eval(sigmas) mx.eval(sigmas)
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]") console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]")
@@ -1117,7 +1141,7 @@ def generate_video(
latents, video_positions, latents, video_positions,
video_embeddings_pos, video_embeddings_neg, video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale, transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, state=video_state verbose=verbose, state=video_state
) )
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling) # Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
@@ -1347,7 +1371,6 @@ Examples:
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded") parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation") parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path") parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
args = parser.parse_args() args = parser.parse_args()
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED