Enhance precision in denoising functions by ensuring all latents and calculations are consistently handled in float32. Update model input casting and return types to maintain dtype integrity across audio and video processing. Add precision parameter to video generation for improved memory management.
This commit is contained in:
@@ -249,6 +249,11 @@ def denoise_distilled(
|
|||||||
if state is not None:
|
if state is not None:
|
||||||
latents = state.latent
|
latents = state.latent
|
||||||
|
|
||||||
|
# Keep latents in float32 throughout to avoid quantization noise accumulation.
|
||||||
|
latents = latents.astype(mx.float32)
|
||||||
|
if enable_audio:
|
||||||
|
audio_latents = audio_latents.astype(mx.float32)
|
||||||
|
|
||||||
desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]"
|
desc = "[cyan]Denoising A/V[/]" if enable_audio else "[cyan]Denoising[/]"
|
||||||
num_steps = len(sigmas) - 1
|
num_steps = len(sigmas) - 1
|
||||||
|
|
||||||
@@ -268,7 +273,8 @@ def denoise_distilled(
|
|||||||
|
|
||||||
b, c, f, h, w = latents.shape
|
b, c, f, h, w = latents.shape
|
||||||
num_tokens = f * h * w
|
num_tokens = f * h * w
|
||||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
# Cast to model dtype for transformer input
|
||||||
|
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
||||||
|
|
||||||
if state is not None:
|
if state is not None:
|
||||||
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
||||||
@@ -291,7 +297,7 @@ def denoise_distilled(
|
|||||||
if enable_audio:
|
if enable_audio:
|
||||||
ab, ac, at, af = audio_latents.shape
|
ab, ac, at, af = audio_latents.shape
|
||||||
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
|
audio_flat = mx.transpose(audio_latents, (0, 2, 1, 3))
|
||||||
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af))
|
audio_flat = mx.reshape(audio_flat, (ab, at, ac * af)).astype(dtype)
|
||||||
|
|
||||||
audio_modality = Modality(
|
audio_modality = Modality(
|
||||||
latent=audio_flat,
|
latent=audio_flat,
|
||||||
@@ -307,34 +313,36 @@ def denoise_distilled(
|
|||||||
if audio_velocity is not None:
|
if audio_velocity is not None:
|
||||||
mx.eval(audio_velocity)
|
mx.eval(audio_velocity)
|
||||||
|
|
||||||
velocity = mx.reshape(mx.transpose(velocity, (0, 2, 1)), (b, c, f, h, w))
|
# Compute denoised (x0) using per-token timesteps in float32
|
||||||
denoised = to_denoised(latents, velocity, sigma)
|
# x0 = latent - timestep * velocity
|
||||||
|
# For conditioned tokens (timestep=0): x0 = latent
|
||||||
|
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
|
||||||
|
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
||||||
|
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||||
|
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
|
||||||
|
x0_f32 = latents_flat_f32 - timesteps_f32 * velocity.astype(mx.float32)
|
||||||
|
denoised = mx.reshape(mx.transpose(x0_f32, (0, 2, 1)), (b, c, f, h, w))
|
||||||
|
|
||||||
audio_denoised = None
|
audio_denoised = None
|
||||||
if enable_audio and audio_velocity is not None:
|
if enable_audio and audio_velocity is not None:
|
||||||
ab, ac, at, af = audio_latents.shape
|
ab, ac, at, af = audio_latents.shape
|
||||||
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
|
audio_velocity = mx.reshape(audio_velocity, (ab, at, ac, af))
|
||||||
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
|
audio_velocity = mx.transpose(audio_velocity, (0, 2, 1, 3))
|
||||||
audio_denoised = to_denoised(audio_latents, audio_velocity, sigma)
|
audio_denoised = audio_latents - sigma_f32 * audio_velocity.astype(mx.float32)
|
||||||
|
|
||||||
if state is not None:
|
if state is not None:
|
||||||
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
|
denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask)
|
||||||
|
|
||||||
mx.eval(denoised)
|
mx.eval(denoised)
|
||||||
if audio_denoised is not None:
|
if audio_denoised is not None:
|
||||||
mx.eval(audio_denoised)
|
mx.eval(audio_denoised)
|
||||||
|
|
||||||
|
# Euler step in float32 (latents stay in float32)
|
||||||
if sigma_next > 0:
|
if sigma_next > 0:
|
||||||
# Compute Euler step in float32 for precision (matching PyTorch behavior)
|
|
||||||
latents_f32 = latents.astype(mx.float32)
|
|
||||||
denoised_f32 = denoised.astype(mx.float32)
|
|
||||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||||
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
|
||||||
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
|
|
||||||
if enable_audio and audio_denoised is not None:
|
if enable_audio and audio_denoised is not None:
|
||||||
audio_latents_f32 = audio_latents.astype(mx.float32)
|
audio_latents = audio_denoised + sigma_next_f32 * (audio_latents - audio_denoised) / sigma_f32
|
||||||
audio_denoised_f32 = audio_denoised.astype(mx.float32)
|
|
||||||
audio_latents = (audio_denoised_f32 + sigma_next_f32 * (audio_latents_f32 - audio_denoised_f32) / sigma_f32).astype(dtype)
|
|
||||||
else:
|
else:
|
||||||
latents = denoised
|
latents = denoised
|
||||||
if enable_audio and audio_denoised is not None:
|
if enable_audio and audio_denoised is not None:
|
||||||
@@ -346,7 +354,7 @@ def denoise_distilled(
|
|||||||
|
|
||||||
progress.advance(task)
|
progress.advance(task)
|
||||||
|
|
||||||
return latents, audio_latents if enable_audio else None
|
return latents.astype(dtype), audio_latents.astype(dtype) if enable_audio else None
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -371,6 +379,11 @@ def denoise_dev(
|
|||||||
if state is not None:
|
if state is not None:
|
||||||
latents = state.latent
|
latents = state.latent
|
||||||
|
|
||||||
|
# Keep latents in float32 throughout the denoising loop to avoid
|
||||||
|
# quantization noise accumulation over many steps.
|
||||||
|
# Model input is cast to model dtype; all denoising math stays in float32.
|
||||||
|
latents = latents.astype(mx.float32)
|
||||||
|
|
||||||
sigmas_list = sigmas.tolist()
|
sigmas_list = sigmas.tolist()
|
||||||
use_cfg = cfg_scale != 1.0
|
use_cfg = cfg_scale != 1.0
|
||||||
num_steps = len(sigmas_list) - 1
|
num_steps = len(sigmas_list) - 1
|
||||||
@@ -405,7 +418,8 @@ def denoise_dev(
|
|||||||
|
|
||||||
b, c, f, h, w = latents.shape
|
b, c, f, h, w = latents.shape
|
||||||
num_tokens = f * h * w
|
num_tokens = f * h * w
|
||||||
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
# Cast to model dtype for transformer input
|
||||||
|
latents_flat = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1)).astype(dtype)
|
||||||
|
|
||||||
if state is not None:
|
if state is not None:
|
||||||
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1))
|
||||||
@@ -427,6 +441,14 @@ def denoise_dev(
|
|||||||
)
|
)
|
||||||
velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
|
velocity_pos, _ = transformer(video=video_modality_pos, audio=None)
|
||||||
|
|
||||||
|
# Convert velocity to x0 (denoised) using per-token timesteps
|
||||||
|
# Matches PyTorch's X0Model: x0 = latent - timestep * velocity
|
||||||
|
# For conditioned tokens (timestep=0): x0 = latent (correct regardless of velocity)
|
||||||
|
# For unconditioned tokens (timestep=sigma): x0 = latent - sigma * velocity
|
||||||
|
latents_flat_f32 = mx.transpose(mx.reshape(latents, (b, c, -1)), (0, 2, 1))
|
||||||
|
timesteps_f32 = mx.expand_dims(timesteps.astype(mx.float32), axis=-1)
|
||||||
|
x0_pos_f32 = latents_flat_f32 - timesteps_f32 * velocity_pos.astype(mx.float32)
|
||||||
|
|
||||||
if use_cfg:
|
if use_cfg:
|
||||||
# Negative conditioning pass
|
# Negative conditioning pass
|
||||||
video_modality_neg = Modality(
|
video_modality_neg = Modality(
|
||||||
@@ -440,31 +462,34 @@ def denoise_dev(
|
|||||||
)
|
)
|
||||||
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
|
velocity_neg, _ = transformer(video=video_modality_neg, audio=None)
|
||||||
|
|
||||||
# Apply CFG
|
# Convert negative velocity to x0 using per-token timesteps
|
||||||
velocity_flat = velocity_pos + (cfg_scale - 1.0) * (velocity_pos - velocity_neg)
|
x0_neg_f32 = latents_flat_f32 - timesteps_f32 * velocity_neg.astype(mx.float32)
|
||||||
else:
|
|
||||||
velocity_flat = velocity_pos
|
|
||||||
|
|
||||||
velocity = mx.reshape(mx.transpose(velocity_flat, (0, 2, 1)), (b, c, f, h, w))
|
# Apply CFG to x0 predictions (matches PyTorch CFGGuider)
|
||||||
denoised = to_denoised(latents, velocity, sigma)
|
# For conditioned tokens: x0_pos = x0_neg = latent, so delta = 0
|
||||||
|
x0_guided_f32 = x0_pos_f32 + (cfg_scale - 1.0) * (x0_pos_f32 - x0_neg_f32)
|
||||||
|
else:
|
||||||
|
x0_guided_f32 = x0_pos_f32
|
||||||
|
|
||||||
|
# Reshape x0 from token space (b, tokens, c) to spatial (b, c, f, h, w)
|
||||||
|
denoised = mx.reshape(mx.transpose(x0_guided_f32, (0, 2, 1)), (b, c, f, h, w))
|
||||||
|
|
||||||
|
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
||||||
|
|
||||||
if state is not None:
|
if state is not None:
|
||||||
denoised = apply_denoise_mask(denoised, state.clean_latent, state.denoise_mask)
|
denoised = apply_denoise_mask(denoised, state.clean_latent.astype(mx.float32), state.denoise_mask)
|
||||||
|
|
||||||
|
# Euler step in float32 (latents stay in float32)
|
||||||
if sigma_next > 0:
|
if sigma_next > 0:
|
||||||
# Compute Euler step in float32 for precision (matching PyTorch behavior)
|
|
||||||
latents_f32 = latents.astype(mx.float32)
|
|
||||||
denoised_f32 = denoised.astype(mx.float32)
|
|
||||||
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
sigma_next_f32 = mx.array(sigma_next, dtype=mx.float32)
|
||||||
sigma_f32 = mx.array(sigma, dtype=mx.float32)
|
latents = denoised + sigma_next_f32 * (latents - denoised) / sigma_f32
|
||||||
latents = (denoised_f32 + sigma_next_f32 * (latents_f32 - denoised_f32) / sigma_f32).astype(dtype)
|
|
||||||
else:
|
else:
|
||||||
latents = denoised
|
latents = denoised
|
||||||
|
|
||||||
mx.eval(latents)
|
mx.eval(latents)
|
||||||
progress.advance(task)
|
progress.advance(task)
|
||||||
|
|
||||||
return latents
|
return latents.astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
def denoise_dev_av(
|
def denoise_dev_av(
|
||||||
@@ -1055,9 +1080,8 @@ def generate_video(
|
|||||||
mx.clear_cache()
|
mx.clear_cache()
|
||||||
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
console.print("[green]✓[/] VAE encoder loaded and image encoded")
|
||||||
|
|
||||||
# Generate sigma schedule
|
# Generate sigma schedule (uses MAX_SHIFT_ANCHOR=4096 like the reference implementation)
|
||||||
# PyTorch LTX-2 does NOT pass the latent to the scheduler, so it uses
|
num_tokens = latent_frames * latent_h * latent_w
|
||||||
# the default MAX_SHIFT_ANCHOR (4096) for the shift calculation
|
|
||||||
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
sigmas = ltx2_scheduler(steps=num_inference_steps)
|
||||||
mx.eval(sigmas)
|
mx.eval(sigmas)
|
||||||
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
console.print(f"[dim]Sigma schedule: {sigmas[0].item():.4f} → {sigmas[-2].item():.4f} → {sigmas[-1].item():.4f}[/]")
|
||||||
@@ -1117,7 +1141,7 @@ def generate_video(
|
|||||||
latents, video_positions,
|
latents, video_positions,
|
||||||
video_embeddings_pos, video_embeddings_neg,
|
video_embeddings_pos, video_embeddings_neg,
|
||||||
transformer, sigmas, cfg_scale=cfg_scale,
|
transformer, sigmas, cfg_scale=cfg_scale,
|
||||||
cfg_rescale=cfg_rescale, verbose=verbose, state=video_state
|
verbose=verbose, state=video_state
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
# Load VAE decoder (for dev pipeline, loaded here instead of during upsampling)
|
||||||
@@ -1347,7 +1371,6 @@ Examples:
|
|||||||
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
|
parser.add_argument("--stream", action="store_true", help="Stream frames to output as they're decoded")
|
||||||
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
|
parser.add_argument("--audio", "-a", action="store_true", help="Enable synchronized audio generation")
|
||||||
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
|
parser.add_argument("--output-audio", type=str, default=None, help="Output audio path")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED
|
pipeline = PipelineType.DEV if args.pipeline == "dev" else PipelineType.DISTILLED
|
||||||
|
|||||||
Reference in New Issue
Block a user