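Refactor generate.py so that temporal coordinates and position grids are cast through bfloat16, matching PyTorch's precision behavior. Update the denoise_dev_av function to apply standard-deviation-ratio (std-ratio) CFG rescaling to the audio and video guidance, improving numerical fidelity and model compatibility. Also force double_precision_rope to False in LTXModelConfig and cast the position indices to bfloat16 in precompute_freqs_cis, for the same reason.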
This commit is contained in:
@@ -236,15 +236,16 @@ def create_position_grid(
|
|||||||
a_max=None
|
a_max=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute temporal division in bfloat16 to match PyTorch's precision behavior
|
# Divide temporal coords by fps
|
||||||
# This ensures RoPE frequencies are computed identically to the reference implementation
|
pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps
|
||||||
temporal_coords = mx.array(pixel_coords[:, 0, :, :], dtype=mx.bfloat16)
|
|
||||||
fps_bf16 = mx.array(fps, dtype=mx.bfloat16)
|
|
||||||
temporal_coords = temporal_coords / fps_bf16
|
|
||||||
mx.eval(temporal_coords)
|
|
||||||
pixel_coords[:, 0, :, :] = np.array(temporal_coords.astype(mx.float32))
|
|
||||||
|
|
||||||
return mx.array(pixel_coords, dtype=mx.float32)
|
# Cast entire position grid through bfloat16 to match PyTorch's behavior.
|
||||||
|
# PyTorch does: positions = positions.to(bfloat16) on ALL coordinates before
|
||||||
|
# passing to the transformer/RoPE. This quantization is what the model was
|
||||||
|
# trained with, so we must replicate it for numerical fidelity.
|
||||||
|
positions_bf16 = mx.array(pixel_coords, dtype=mx.bfloat16)
|
||||||
|
mx.eval(positions_bf16)
|
||||||
|
return positions_bf16.astype(mx.float32)
|
||||||
|
|
||||||
|
|
||||||
def create_audio_position_grid(
|
def create_audio_position_grid(
|
||||||
@@ -270,7 +271,10 @@ def create_audio_position_grid(
|
|||||||
positions = positions[np.newaxis, np.newaxis, :, :]
|
positions = positions[np.newaxis, np.newaxis, :, :]
|
||||||
positions = np.tile(positions, (batch_size, 1, 1, 1))
|
positions = np.tile(positions, (batch_size, 1, 1, 1))
|
||||||
|
|
||||||
return mx.array(positions, dtype=mx.float32)
|
# Cast through bfloat16 to match PyTorch's precision behavior
|
||||||
|
positions_bf16 = mx.array(positions, dtype=mx.bfloat16)
|
||||||
|
mx.eval(positions_bf16)
|
||||||
|
return positions_bf16.astype(mx.float32)
|
||||||
|
|
||||||
|
|
||||||
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
|
def compute_audio_frames(num_video_frames: int, fps: float) -> int:
|
||||||
@@ -735,10 +739,16 @@ def denoise_dev_av(
|
|||||||
# Always use standard CFG for audio
|
# Always use standard CFG for audio
|
||||||
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32)
|
||||||
|
|
||||||
# Apply CFG rescale if enabled
|
# Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation)
|
||||||
|
# factor = rescale * (cond_std / pred_std) + (1 - rescale)
|
||||||
|
# pred = pred * factor
|
||||||
if cfg_rescale > 0.0:
|
if cfg_rescale > 0.0:
|
||||||
video_x0_guided_f32 = cfg_rescale * video_x0_pos_f32 + (1.0 - cfg_rescale) * video_x0_guided_f32
|
v_factor = video_x0_pos_f32.std() / (video_x0_guided_f32.std() + 1e-8)
|
||||||
audio_x0_guided_f32 = cfg_rescale * audio_x0_pos_f32 + (1.0 - cfg_rescale) * audio_x0_guided_f32
|
v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale)
|
||||||
|
video_x0_guided_f32 = video_x0_guided_f32 * v_factor
|
||||||
|
a_factor = audio_x0_pos_f32.std() / (audio_x0_guided_f32.std() + 1e-8)
|
||||||
|
a_factor = cfg_rescale * a_factor + (1.0 - cfg_rescale)
|
||||||
|
audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor
|
||||||
else:
|
else:
|
||||||
video_x0_guided_f32 = video_x0_pos_f32
|
video_x0_guided_f32 = video_x0_pos_f32
|
||||||
audio_x0_guided_f32 = audio_x0_pos_f32
|
audio_x0_guided_f32 = audio_x0_pos_f32
|
||||||
|
|||||||
@@ -147,6 +147,12 @@ class LTXModelConfig(BaseModelConfig):
|
|||||||
if self.audio_positional_embedding_max_pos is None:
|
if self.audio_positional_embedding_max_pos is None:
|
||||||
self.audio_positional_embedding_max_pos = [20]
|
self.audio_positional_embedding_max_pos = [20]
|
||||||
|
|
||||||
|
# PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision"
|
||||||
|
# instead of "rope_double_precision" from the config, so double_precision_rope
|
||||||
|
# is always False in PyTorch regardless of what the config file says. Since the
|
||||||
|
# model was trained with this behavior, we must match it.
|
||||||
|
self.double_precision_rope = False
|
||||||
|
|
||||||
# Convert string enum values if loading from dict
|
# Convert string enum values if loading from dict
|
||||||
if isinstance(self.model_type, str):
|
if isinstance(self.model_type, str):
|
||||||
self.model_type = LTXModelType(self.model_type)
|
self.model_type = LTXModelType(self.model_type)
|
||||||
|
|||||||
@@ -399,6 +399,14 @@ def precompute_freqs_cis(
|
|||||||
num_attention_heads, rope_type
|
num_attention_heads, rope_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Cast positions to bfloat16 to match PyTorch's behavior.
|
||||||
|
# In PyTorch, positions are in bfloat16 (model dtype) during the entire
|
||||||
|
# generate_freqs computation — fractional positions, scaling, etc. are all
|
||||||
|
# computed in bfloat16. The multiplication with float32 freq_indices then
|
||||||
|
# upcasts to float32. This precision behavior is what the model was trained
|
||||||
|
# with, so we must replicate it.
|
||||||
|
indices_grid = indices_grid.astype(mx.bfloat16)
|
||||||
|
|
||||||
# Generate frequency indices
|
# Generate frequency indices
|
||||||
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
indices = generate_freq_grid(theta, indices_grid.shape[1], dim)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user