Refactor generate.py to ensure temporal coordinates and position grids are processed in bfloat16 for consistency with PyTorch's precision behavior. Update denoise_dev_av function to apply standard ratio rescaling for audio and video guidance, enhancing numerical fidelity and model compatibility.

This commit is contained in:
Prince Canuma
2026-03-12 21:26:38 +01:00
parent b07b1e3213
commit e0aafd72fc
3 changed files with 36 additions and 12 deletions

View File

@@ -236,15 +236,16 @@ def create_position_grid(
a_max=None
)
# Compute temporal division in bfloat16 to match PyTorch's precision behavior
# This ensures RoPE frequencies are computed identically to the reference implementation
temporal_coords = mx.array(pixel_coords[:, 0, :, :], dtype=mx.bfloat16)
fps_bf16 = mx.array(fps, dtype=mx.bfloat16)
temporal_coords = temporal_coords / fps_bf16
mx.eval(temporal_coords)
pixel_coords[:, 0, :, :] = np.array(temporal_coords.astype(mx.float32))
# Divide temporal coords by fps
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
return mx.array(pixel_coords, dtype=mx.float32)
# Cast entire position grid through bfloat16 to match PyTorch's behavior.
# PyTorch does: positions = positions.to(bfloat16) on ALL coordinates before
# passing to the transformer/RoPE. This quantization is what the model was
# trained with, so we must replicate it for numerical fidelity.
positions_bf16 = mx.array(pixel_coords, dtype=mx.bfloat16)
mx.eval(positions_bf16)
return positions_bf16.astype(mx.float32)
def create_audio_position_grid(
@@ -270,7 +271,10 @@ def create_audio_position_grid(
positions = positions[np.newaxis, np.newaxis, :, :]
positions = np.tile(positions, (batch_size, 1, 1, 1))
return mx.array(positions, dtype=mx.float32)
# Cast through bfloat16 to match PyTorch's precision behavior
positions_bf16 = mx.array(positions, dtype=mx.bfloat16)
mx.eval(positions_bf16)
return positions_bf16.astype(mx.float32)
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
@@ -735,10 +739,16 @@ def denoise_dev_av(
# Always use standard CFG for audio
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
# Apply CFG rescale if enabled
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation)
# factor = rescale * (cond_std / pred_std) + (1 - rescale)
# pred = pred * factor
if cfg_rescale > 0.0:
video_x0_guided_f32 = cfg_rescale * video_x0_pos_f32 + (1.0 - cfg_rescale) * video_x0_guided_f32
audio_x0_guided_f32 = cfg_rescale * audio_x0_pos_f32 + (1.0 - cfg_rescale) * audio_x0_guided_f32
v_factor = video_x0_pos_f32.std() / (video_x0_guided_f32.std() + 1e-8)
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
video_x0_guided_f32 = video_x0_guided_f32 * v_factor
a_factor = audio_x0_pos_f32.std() / (audio_x0_guided_f32.std() + 1e-8)
a_factor = cfg_rescale * a_factor + (1.0 - cfg_rescale)
audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor
else:
video_x0_guided_f32 = video_x0_pos_f32
audio_x0_guided_f32 = audio_x0_pos_f32