Refactor generate.py to ensure temporal coordinates and position grids are processed in bfloat16 for consistency with PyTorch's precision behavior. Update denoise_dev_av function to apply standard ratio rescaling for audio and video guidance, enhancing numerical fidelity and model compatibility.

This commit is contained in:
Prince Canuma
2026-03-12 21:26:38 +01:00
parent b07b1e3213
commit e0aafd72fc
3 changed files with 36 additions and 12 deletions

View File

@@ -147,6 +147,12 @@ class LTXModelConfig(BaseModelConfig):
if self.audio_positional_embedding_max_pos is None:
self.audio_positional_embedding_max_pos = [20]
# PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision"
# instead of "rope_double_precision" from the config, so double_precision_rope
# is always False in PyTorch regardless of what the config file says. Since the
# model was trained with this behavior, we must match it.
self.double_precision_rope = False
# Convert string enum values if loading from dict
if isinstance(self.model_type, str):
self.model_type = LTXModelType(self.model_type)

View File

@@ -399,6 +399,14 @@ def precompute_freqs_cis(
num_attention_heads, rope_type
)
# Cast positions to bfloat16 to match PyTorch's behavior.
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
# generate_freqs computation — fractional positions, scaling, etc. are all
# computed in bfloat16. The multiplication with float32 freq_indices then
# upcasts to float32. This precision behavior is what the model was trained
# with, so we must replicate it.
indices_grid = indices_grid.astype(mx.bfloat16)
# Generate frequency indices
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)