Enhance precision in denoising functions by ensuring all latents and calculations are consistently handled in float32. Update model input casting and return types to maintain dtype integrity across audio and video processing. Add precision parameter to video generation for improved memory management.

This commit is contained in:
Prince Canuma
2026-01-24 15:40:42 +01:00
parent cb2d19c84d
commit 87962c7f83

View File

@@ -249,6 +249,11 @@ def denoise_distilled(
if state is not None:
latents = state.latent
# Keep latents in float32 throughout to avoid quantization noise accumulation.
latents = latents.astype(mx.float32)
if enable_audio:
audio_latents = audio_latents.astype(mx.float32)
desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]"
num_steps = len(sigmas) - 1
@@ -268,7 +273,8 @@ def denoise_distilled(
b, c, f, h, w = latents.shape
num_tokens = f * h * w
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
# Cast to model dtype for transformer input
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
if state is not None:
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
@@ -291,7 +297,7 @@ def denoise_distilled(
if enable_audio:
ab, ac, at, af = audio_latents.shape
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
audio_modality = Modality(
latent=audio_flat,
@@ -307,34 +313,36 @@ def denoise_distilled(
if audio_velocity is not None:
mx.eval(audio_velocity)
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
# Compute denoised (x0) using per-token timesteps in float32
# x0 = latent - timestep * velocity
# For conditioned tokens (timestep=0): x0 = latent
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
sigma_f32 = mx.array(sigma, dtype=mx.float32)
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
x0_f32 = latents_flat_f32 - timesteps_f32 * velocity.astype(mx.float32)
denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w))
audio_denoised = None
if enable_audio and audio_velocity is not None:
ab, ac, at, af = audio_latents.shape
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
audio_denoised = audio_latents - sigma_f32 * audio_velocity.astype(mx.float32)
if state is not None:
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask)
mx.eval(denoised)
if audio_denoised is not None:
mx.eval(audio_denoised)
# Euler step in float32 (latents stay in float32)
if sigma_next > 0:
# Compute Euler step in float32 for precision (matching PyTorch behavior)
latents_f32 = latents.astype(mx.float32)
denoised_f32 = denoised.astype(mx.float32)
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
sigma_f32 = mx.array(sigma, dtype=mx.float32)
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
if enable_audio and audio_denoised is not None:
audio_latents_f32 = audio_latents.astype(mx.float32)
audio_denoised_f32 = audio_denoised.astype(mx.float32)
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32
else:
latents = denoised
if enable_audio and audio_denoised is not None:
@@ -346,7 +354,7 @@ def denoise_distilled(
progress.advance(task)
return latents, audio_latents if enable_audio else None
return latents.astype(dtype), audio_latents.astype(dtype) if enable_audio else None
# =============================================================================
@@ -371,6 +379,11 @@ def denoise_dev(
if state is not None:
latents = state.latent
# Keep latents in float32 throughout the denoising loop to avoid
# quantization noise accumulation over many steps.
# Model input is cast to model dtype; all denoising math stays in float32.
latents = latents.astype(mx.float32)
sigmas_list = sigmas.tolist()
use_cfg = cfg_scale != 1.0
num_steps = len(sigmas_list) - 1
@@ -405,7 +418,8 @@ def denoise_dev(
b, c, f, h, w = latents.shape
num_tokens = f * h * w
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
# Cast to model dtype for transformer input
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
if state is not None:
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
@@ -427,6 +441,14 @@ def denoise_dev(
)
velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
# Convert velocity to x0 (denoised) using per-token timesteps
# Matches PyTorch's X0Model: x0 = latent - timestep * velocity
# For conditioned tokens (timestep=0): x0 = latent (correct regardless of velocity)
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32)
if use_cfg:
# Negative conditioning pass
video_modality_neg = Modality(
@@ -440,31 +462,34 @@ def denoise_dev(
)
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
# Apply CFG
velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg)
else:
velocity_flat = velocity_pos
# Convert negative velocity to x0 using per-token timesteps
x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32)
velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w))
denoised = to_denoised(latents, velocity, sigma)
# Apply CFG to x0 predictions (matches PyTorch CFGGuider)
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
else:
x0_guided_f32 = x0_pos_f32
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
sigma_f32 = mx.array(sigma, dtype=mx.float32)
if state is not None:
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask)
# Euler step in float32 (latents stay in float32)
if sigma_next > 0:
# Compute Euler step in float32 for precision (matching PyTorch behavior)
latents_f32 = latents.astype(mx.float32)
denoised_f32 = denoised.astype(mx.float32)
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
sigma_f32 = mx.array(sigma, dtype=mx.float32)
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
else:
latents = denoised
mx.eval(latents)
progress.advance(task)
return latents
return latents.astype(dtype)
def denoise_dev_av(
@@ -1055,9 +1080,8 @@ def generate_video(
mx.clear_cache()
console.print("[green]✓[/] VAE encoder loaded and image encoded")
# Generate sigma schedule
# PyTorch LTX-2 does NOT pass the latent to the scheduler, so it uses
# the default MAX_SHIFT_ANCHOR (4096) for the shift calculation
# Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation)
num_tokens = latent_frames * latent_h * latent_w
sigmas = ltx2_scheduler(steps=num_inference_steps)
mx.eval(sigmas)
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f}{sigmas[-2].item():.4f}{sigmas[-1].item():.4f}[/]")
@@ -1117,7 +1141,7 @@ def generate_video(
latents, video_positions,
video_embeddings_pos, video_embeddings_neg,
transformer, sigmas, cfg_scale=cfg_scale,
cfg_rescale=cfg_rescale, verbose=verbose, state=video_state
verbose=verbose, state=video_state
)
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
@@ -1347,7 +1371,6 @@ Examples:
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
args = parser.parse_args()
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED