From e0aafd72fc705c0fee3c84a55648a99f4dfc5480 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 12 Mar 2026 21:26:38 +0100 Subject: [PATCH] Refactor generate.py to ensure temporal coordinates and position grids are processed in bfloat16 for consistency with PyTorch's precision behavior. Update denoise_dev_av function to apply standard-deviation-ratio (std-ratio) CFG rescaling for audio and video guidance, enhancing numerical fidelity and model compatibility. --- mlx_video/generate.py | 34 ++++++++++++++++++++++------------ mlx_video/models/ltx/config.py | 6 ++++++ mlx_video/models/ltx/rope.py | 8 ++++++++ 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 37f0824..998b7f8 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -236,15 +236,16 @@ def create_position_grid( a_max=None ) - # Compute temporal division in bfloat16 to match PyTorch's precision behavior - # This ensures RoPE frequencies are computed identically to the reference implementation - temporal_coords = mx.array(pixel_coords[:, 0, :, :], dtype=mx.bfloat16) - fps_bf16 = mx.array(fps, dtype=mx.bfloat16) - temporal_coords = temporal_coords / fps_bf16 - mx.eval(temporal_coords) - pixel_coords[:, 0, :, :] = np.array(temporal_coords.astype(mx.float32)) + # Divide temporal coords by fps + pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps - return mx.array(pixel_coords, dtype=mx.float32) + # Cast entire position grid through bfloat16 to match PyTorch's behavior. + # PyTorch does: positions = positions.to(bfloat16) on ALL coordinates before + # passing to the transformer/RoPE. This quantization is what the model was + # trained with, so we must replicate it for numerical fidelity. 
+ positions_bf16 = mx.array(pixel_coords, dtype=mx.bfloat16) + mx.eval(positions_bf16) + return positions_bf16.astype(mx.float32) def create_audio_position_grid( @@ -270,7 +271,10 @@ def create_audio_position_grid( positions = positions[np.newaxis, np.newaxis, :, :] positions = np.tile(positions, (batch_size, 1, 1, 1)) - return mx.array(positions, dtype=mx.float32) + # Cast through bfloat16 to match PyTorch's precision behavior + positions_bf16 = mx.array(positions, dtype=mx.bfloat16) + mx.eval(positions_bf16) + return positions_bf16.astype(mx.float32) def compute_audio_frames(num_video_frames: int, fps: float) -> int: @@ -735,10 +739,16 @@ def denoise_dev_av( # Always use standard CFG for audio audio_x0_guided_f32 = audio_x0_pos_f32 + (cfg_scale - 1.0) * (audio_x0_pos_f32 - audio_x0_neg_f32) - # Apply CFG rescale if enabled + # Apply CFG rescale if enabled (std-ratio rescaling to reduce over-saturation) + # factor = rescale * (cond_std / pred_std) + (1 - rescale) + # pred = pred * factor if cfg_rescale > 0.0: - video_x0_guided_f32 = cfg_rescale * video_x0_pos_f32 + (1.0 - cfg_rescale) * video_x0_guided_f32 - audio_x0_guided_f32 = cfg_rescale * audio_x0_pos_f32 + (1.0 - cfg_rescale) * audio_x0_guided_f32 + v_factor = video_x0_pos_f32.std() / (video_x0_guided_f32.std() + 1e-8) + v_factor = cfg_rescale * v_factor + (1.0 - cfg_rescale) + video_x0_guided_f32 = video_x0_guided_f32 * v_factor + a_factor = audio_x0_pos_f32.std() / (audio_x0_guided_f32.std() + 1e-8) + a_factor = cfg_rescale * a_factor + (1.0 - cfg_rescale) + audio_x0_guided_f32 = audio_x0_guided_f32 * a_factor else: video_x0_guided_f32 = video_x0_pos_f32 audio_x0_guided_f32 = audio_x0_pos_f32 diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index 40bb9ef..b7dfa0a 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -147,6 +147,12 @@ class LTXModelConfig(BaseModelConfig): if self.audio_positional_embedding_max_pos is None: 
self.audio_positional_embedding_max_pos = [20] + # PyTorch LTX-2 configurator has a bug: it reads "frequencies_precision" + # instead of "rope_double_precision" from the config, so double_precision_rope + # is always False in PyTorch regardless of what the config file says. Since the + # model was trained with this behavior, we must match it. + self.double_precision_rope = False + # Convert string enum values if loading from dict if isinstance(self.model_type, str): self.model_type = LTXModelType(self.model_type) diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index 66a8710..d9ae359 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -399,6 +399,14 @@ def precompute_freqs_cis( num_attention_heads, rope_type ) + # Cast positions to bfloat16 to match PyTorch's behavior. + # In PyTorch, positions are in bfloat16 (model dtype) during the entire + # generate_freqs computation — fractional positions, scaling, etc. are all + # computed in bfloat16. The multiplication with float32 freq_indices then + # upcasts to float32. This precision behavior is what the model was trained + # with, so we must replicate it. + indices_grid = indices_grid.astype(mx.bfloat16) + # Generate frequency indices indices = generate_freq_grid(theta, indices_grid.shape[1], dim)