Refactor generate.py to ensure temporal coordinates and position grids are processed in bfloat16 for consistency with PyTorch's precision behavior. Update denoise_dev_av function to apply standard ratio rescaling for audio and video guidance, enhancing numerical fidelity and model compatibility.
This commit is contained in:
@@ -399,6 +399,14 @@ def precompute_freqs_cis(
|
||||
num_attention_heads, rope_type
|
||||
)
|
||||
|
||||
# Cast positions to bfloat16 to match PyTorch's behavior.
|
||||
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
|
||||
# generate_freqs computation — fractional positions, scaling, etc. are all
|
||||
# computed in bfloat16. The multiplication with float32 freq_indices then
|
||||
# upcasts to float32. This precision behavior is what the model was trained
|
||||
# with, so we must replicate it.
|
||||
indices_grid = indices_grid.astype(mx.bfloat16)
|
||||
|
||||
# Generate frequency indices
|
||||
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user