Files
mlx-video/tests/test_rope.py
2026-03-15 23:08:12 +01:00

507 lines
21 KiB
Python

import pytest
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.rope import (
precompute_freqs_cis,
)
from mlx_video.models.ltx.config import LTXModelConfig, LTXRopeType
def create_video_position_grid(
    batch_size: int,
    num_frames: int,
    height: int,
    width: int,
    dtype: mx.Dtype = mx.float32,
    fps: float = 24.0,
    scale_factors: tuple = (8, 32, 32),
) -> mx.array:
    """Create a simple video position grid for testing.

    Every latent patch ``(t, h, w)`` is described by the interval
    ``[start, start + 1]`` along each axis. Intervals are scaled from latent
    to pixel space by ``scale_factors``, and the temporal axis is then divided
    by ``fps`` so it is expressed in seconds.

    Args:
        batch_size: Number of identical copies of the grid along axis 0.
        num_frames: Number of latent frames (temporal axis).
        height: Number of latent rows.
        width: Number of latent columns.
        dtype: dtype of the returned array. Values are computed in NumPy
            float32 and only converted to ``dtype`` at the very end, so a
            low-precision dtype shows exactly the quantization of that cast.
        fps: Frames per second used to convert temporal pixel coordinates to
            seconds. Defaults to 24.0.
        scale_factors: Latent-to-pixel scale for the ``(t, h, w)`` axes.
            Defaults to ``(8, 32, 32)``.

    Returns:
        Array of shape ``(batch_size, 3, num_frames * height * width, 2)``.
        Axis 1 is ``(t, h, w)`` and the last axis is ``(start, end)``.
    """
    t_grid, h_grid, w_grid = np.meshgrid(
        np.arange(num_frames), np.arange(height), np.arange(width), indexing="ij"
    )
    patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)  # (3, F, H, W)
    latent_coords = np.stack([patch_starts, patch_starts + 1], axis=-1)
    num_patches = num_frames * height * width
    latent_coords = latent_coords.reshape(3, num_patches, 2)
    latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1))

    # Scale to pixel space
    scale = np.asarray(scale_factors).reshape(1, 3, 1, 1)
    pixel_coords = (latent_coords * scale).astype(np.float32)
    pixel_coords[:, 0, :, :] /= fps  # Convert to seconds
    return mx.array(pixel_coords, dtype=dtype)
def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
"""Compute RoPE cos/sin using NumPy float64 as ground truth reference.
This mirrors the regular (non-double-precision) path in rope.py exactly,
but uses float64 throughout, so we can verify that the float32 MLX path
stays close to the true values.
"""
# positions_np: (B, 3, T, 2) in float64
n_pos_dims = positions_np.shape[1]
n_elem = 2 * n_pos_dims
# Middle-of-interval positions
mid = (positions_np[..., 0] + positions_np[..., 1]) / 2.0 # (B, 3, T)
# Frequency grid — matches generate_freq_grid() in rope.py:
# log_start = log(1)/log(theta) = 0
# log_end = log(theta)/log(theta) = 1
# pow_indices = theta^linspace(0, 1, num_indices) * pi/2
num_indices = dim // n_elem
if num_indices == 0:
num_indices = 1
lin_space = np.linspace(0.0, 1.0, num_indices, dtype=np.float64)
freq_indices = np.power(theta, lin_space) * (np.pi / 2) # (num_indices,)
# Fractional positions and scaling — matches generate_freqs()
# frac = pos / max_pos for each dim, then scale to [-1, 1]
frac_list = []
for d in range(n_pos_dims):
frac = mid[:, d, :] / max_pos[d] # (B, T)
frac_list.append(frac)
fractional = np.stack(frac_list, axis=-1) # (B, T, n_dims)
scaled = fractional * 2 - 1 # [-1, 1]
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims)
cos_ref = np.cos(freqs)
sin_ref = np.sin(freqs)
# Split RoPE: pad to dim//2, reshape to (B, H, T, dim_per_head//2)
expected = dim // 2
pad_size = expected - cos_ref.shape[-1]
if pad_size > 0:
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
B, T, _ = cos_ref.shape
dim_per_head = dim // num_heads
cos_ref = cos_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
sin_ref = sin_ref.reshape(B, T, num_heads, dim_per_head // 2).transpose(0, 2, 1, 3)
return cos_ref, sin_ref
class TestRoPEPositionPrecision:
    """Test suite for RoPE position precision requirements."""

    def test_float32_positions_produce_consistent_output(self):
        """Float32 position grids should produce stable RoPE frequencies."""
        positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        cos_freq, sin_freq = precompute_freqs_cis(
            indices_grid=positions,
            dim=128,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=32,
            rope_type=LTXRopeType.SPLIT,
            double_precision=True,
        )
        # Verify output dtype is float32
        assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}"
        assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}"
        # Verify no NaN or Inf values
        assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN"
        assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN"
        assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf"
        assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
        # Verify cos/sin are in valid range [-1, 1]
        assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
            "cos_freq values out of [-1, 1] range"
        assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
            "sin_freq values out of [-1, 1] range"

    def test_bfloat16_positions_cause_precision_loss(self):
        """bfloat16 positions should produce different (less precise) results than float32.

        This test documents the known issue: bfloat16 has only 7 bits of mantissa
        vs 23 bits for float32, causing quantization errors that get amplified
        by sin/cos calculations in RoPE.
        """
        # Create identical position grids in different dtypes
        positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
        # Compute RoPE frequencies
        cos_f32, sin_f32 = precompute_freqs_cis(
            indices_grid=positions_f32,
            dim=128,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=32,
            rope_type=LTXRopeType.SPLIT,
            double_precision=True,
        )
        cos_bf16, sin_bf16 = precompute_freqs_cis(
            indices_grid=positions_bf16,
            dim=128,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=32,
            rope_type=LTXRopeType.SPLIT,
            double_precision=True,
        )
        # Calculate the difference
        max_cos_diff = mx.max(mx.abs(cos_f32 - cos_bf16)).item()
        max_sin_diff = mx.max(mx.abs(sin_f32 - sin_bf16)).item()
        # bfloat16 positions WILL cause measurable differences.
        # The threshold here is intentionally low to catch the issue.
        precision_threshold = 1e-6
        has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
        # Document the precision loss (this is expected behavior)
        if has_precision_loss:
            print("\nPrecision loss detected (expected):")
            print(f" Max cos difference: {max_cos_diff:.6e}")
            print(f" Max sin difference: {max_sin_diff:.6e}")
        # This assertion documents the issue - bfloat16 positions cause precision loss
        assert has_precision_loss, \
            "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"

    def test_double_precision_converts_to_float32_internally(self):
        """Verify that double_precision mode converts bfloat16 to float32 first.

        Converting bfloat16 to float32 (or anything wider) is exact. So if the
        positions are upcast before any arithmetic, the bfloat16 grid must give
        the same frequencies as its explicit float32 cast. Checking only the
        output dtype cannot detect arithmetic performed in bfloat16.
        """
        positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16)
        rope_kwargs = dict(
            dim=128,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=32,
            rope_type=LTXRopeType.SPLIT,
            double_precision=True,
        )
        cos_freq, sin_freq = precompute_freqs_cis(indices_grid=positions_bf16, **rope_kwargs)
        cos_ref, sin_ref = precompute_freqs_cis(
            indices_grid=positions_bf16.astype(mx.float32), **rope_kwargs
        )
        # Output should still be float32
        assert cos_freq.dtype == mx.float32
        assert sin_freq.dtype == mx.float32
        # ...and identical (up to float32 round-off) to the explicitly-cast input
        assert mx.allclose(cos_freq, cos_ref, rtol=0.0, atol=1e-6).item(), \
            "cos differs from explicit float32 cast — bfloat16 not upcast first"
        assert mx.allclose(sin_freq, sin_ref, rtol=0.0, atol=1e-6).item(), \
            "sin differs from explicit float32 cast — bfloat16 not upcast first"

    def test_position_grid_should_be_float32_recommendation(self):
        """Test that validates the recommended practice: positions should be float32.

        This test serves as documentation that position grids MUST be float32
        to avoid quality degradation in generated videos/audio.
        """
        # Recommended: create positions in float32
        positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        assert positions.dtype == mx.float32, \
            "Position grids should be created in float32 for RoPE precision"
        # Verify the position values are reasonable
        # Temporal positions should be small (seconds)
        temporal_positions = positions[:, 0, :, :]
        assert mx.max(temporal_positions).item() < 100, \
            "Temporal positions should be in seconds (small values)"
        # Spatial positions should be larger (pixels)
        spatial_h = positions[:, 1, :, :]
        spatial_w = positions[:, 2, :, :]
        assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
        assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"

    def test_float32_positions_match_numpy_float64_reference(self):
        """Regression test: float32 RoPE must closely match a NumPy float64 reference.

        This is the key correctness test. We compute RoPE with NumPy in float64
        (ground truth) and verify that the MLX float32 path produces nearly
        identical results. The max allowed diff (0.01) is well below the error
        we saw with bfloat16 positions (~2.0 max diff, cosine sim 0.88).
        """
        positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        positions_np = np.array(positions).astype(np.float64)
        dim = 128
        theta = 10000.0
        max_pos = [20, 2048, 2048]
        num_heads = 32
        # MLX result (float32 path, non-double-precision)
        cos_mlx, sin_mlx = precompute_freqs_cis(
            indices_grid=positions,
            dim=dim,
            theta=theta,
            max_pos=max_pos,
            use_middle_indices_grid=True,
            num_attention_heads=num_heads,
            rope_type=LTXRopeType.SPLIT,
            double_precision=False,
        )
        # NumPy float64 reference
        cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
        cos_mlx_np = np.array(cos_mlx)
        sin_mlx_np = np.array(sin_mlx)
        max_cos_diff = np.max(np.abs(cos_mlx_np - cos_ref))
        max_sin_diff = np.max(np.abs(sin_mlx_np - sin_ref))
        # Cosine similarity (flatten for single scalar)
        cos_flat = cos_mlx_np.flatten()
        ref_flat = cos_ref.flatten()
        cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))
        # float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
        # Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
        assert max_cos_diff < 0.01, \
            f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
        assert max_sin_diff < 0.01, \
            f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
        assert cosine_sim > 0.9999, \
            f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"

    def test_high_frequency_amplification_regression(self):
        """Regression test for the specific failure mode: high-frequency index amplification.

        With production-sized grids (5x16x16 = 1280 tokens), fractional positions
        like 0.000391 get multiplied by frequency indices up to ~15708. In bfloat16
        the fractional part is quantized, producing raw freq errors of ~6.14 and
        cos/sin sign flips (max_diff ~2.0). Float32 must keep max_diff < 0.01.
        """
        # Use a production-like grid size
        positions = create_video_position_grid(1, 5, 16, 16, dtype=mx.float32)
        positions_np = np.array(positions).astype(np.float64)
        dim = 128
        theta = 10000.0
        max_pos = [20, 2048, 2048]
        num_heads = 32
        cos_mlx, sin_mlx = precompute_freqs_cis(
            indices_grid=positions,
            dim=dim,
            theta=theta,
            max_pos=max_pos,
            use_middle_indices_grid=True,
            num_attention_heads=num_heads,
            rope_type=LTXRopeType.SPLIT,
            double_precision=False,
        )
        cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
        max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
        max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
        # Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
        assert max_cos_diff < 0.01, \
            f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
        assert max_sin_diff < 0.01, \
            f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
class TestRoPEInterleaved:
    """Checks for the interleaved RoPE layout."""

    def test_interleaved_rope_with_float32_positions(self):
        """Interleaved RoPE on a float32 grid yields float32, NaN-free tables."""
        grid = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        rope_kwargs = dict(
            dim=128,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=32,
            rope_type=LTXRopeType.INTERLEAVED,
            double_precision=False,
        )
        tables = precompute_freqs_cis(indices_grid=grid, **rope_kwargs)
        cos_table, sin_table = tables
        # dtype checks first (cos then sin), then NaN checks in the same order
        for table in (cos_table, sin_table):
            assert table.dtype == mx.float32
        for table in (cos_table, sin_table):
            assert not mx.any(mx.isnan(table)).item()
class TestRoPEInputCasting:
    """Tests that precompute_freqs_cis casts positions to float32 internally.

    The fix in rope.py ensures that regardless of the input dtype, positions are
    cast to float32 before any computation. This class verifies that behavior
    for both the regular and double-precision paths.

    Converting bfloat16/float16 to float32 is exact, so "cast first" is checked
    by comparing a low-precision input against its explicit float32 cast: the
    two must produce the same frequencies. Comparing against a grid *created*
    in float32 would instead measure the quantization introduced when the grid
    was built, which is a separate effect.
    """

    @staticmethod
    def _rope(positions, double_precision):
        """Run split-RoPE precompute with the settings shared by these tests."""
        return precompute_freqs_cis(
            indices_grid=positions,
            dim=128,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=32,
            rope_type=LTXRopeType.SPLIT,
            double_precision=double_precision,
        )

    @staticmethod
    def _assert_float32_finite(*arrays):
        """Assert every array is float32 and contains no NaN or Inf."""
        for arr in arrays:
            assert arr.dtype == mx.float32, f"Expected float32, got {arr.dtype}"
            assert not mx.any(mx.isnan(arr)).item(), "output contains NaN"
            assert not mx.any(mx.isinf(arr)).item(), "output contains Inf"

    def _assert_same_as_float32_cast(self, positions_low, cos_low, sin_low, double_precision):
        """Assert RoPE from ``positions_low`` equals RoPE from its float32 cast."""
        cos_ref, sin_ref = self._rope(positions_low.astype(mx.float32), double_precision)
        for name, got, ref in (("cos", cos_low, cos_ref), ("sin", sin_low, sin_ref)):
            max_diff = mx.max(mx.abs(got - ref)).item()
            assert max_diff <= 1e-6, (
                f"{name} differs from explicit float32 cast by {max_diff:.2e} — "
                "positions are not upcast before computation"
            )

    def _check_paths(self, double_precision):
        """Shared body: float32 and bfloat16 inputs both give float32, finite, cast-consistent output."""
        positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
        positions_bf16 = positions_f32.astype(mx.bfloat16)
        cos_f32, sin_f32 = self._rope(positions_f32, double_precision)
        cos_bf16, sin_bf16 = self._rope(positions_bf16, double_precision)
        # Both produce float32, finite output regardless of input dtype
        self._assert_float32_finite(cos_f32, sin_f32, cos_bf16, sin_bf16)
        # bfloat16 input must be upcast before any arithmetic
        self._assert_same_as_float32_cast(positions_bf16, cos_bf16, sin_bf16, double_precision)

    def test_regular_path_outputs_float32(self):
        """Regular path: both float32 and bfloat16 inputs produce float32 output."""
        self._check_paths(double_precision=False)

    def test_double_precision_path_outputs_float32(self):
        """Double-precision path: both float32 and bfloat16 inputs produce float32 output."""
        self._check_paths(double_precision=True)

    def test_float16_input_also_cast_to_float32(self):
        """Float16 input should also be handled correctly."""
        positions_f16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float16)
        cos_freq, sin_freq = self._rope(positions_f16, double_precision=False)
        self._assert_float32_finite(cos_freq, sin_freq)
        self._assert_same_as_float32_cast(positions_f16, cos_freq, sin_freq, double_precision=False)
class TestDoublePrecisionRopeConfig:
    """Checks how LTXModelConfig resolves double_precision_rope per model family."""

    def test_ltx2_forces_double_precision_rope_false(self):
        """LTX-2 (has_prompt_adaln=False) ends up with double_precision_rope=False."""
        cfg = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
        resolved = cfg.double_precision_rope
        assert resolved is False, \
            "LTX-2 should force double_precision_rope=False regardless of input"

    def test_ltx23_preserves_double_precision_rope_true(self):
        """LTX-2.3 (has_prompt_adaln=True) keeps an explicit True."""
        cfg = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
        resolved = cfg.double_precision_rope
        assert resolved is True, \
            "LTX-2.3 should preserve double_precision_rope=True"

    def test_ltx23_preserves_double_precision_rope_false(self):
        """LTX-2.3 keeps an explicit False."""
        cfg = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
        resolved = cfg.double_precision_rope
        assert resolved is False, \
            "LTX-2.3 should respect double_precision_rope=False when explicitly set"

    def test_ltx2_default_double_precision_rope(self):
        """LTX-2 without an explicit double_precision_rope resolves to False."""
        cfg = LTXModelConfig(has_prompt_adaln=False)
        assert cfg.double_precision_rope is False

    def test_ltx23_default_double_precision_rope(self):
        """LTX-2.3 without an explicit double_precision_rope resolves to False (field default)."""
        cfg = LTXModelConfig(has_prompt_adaln=True)
        # Expected: the field default (False) is left untouched for LTX-2.3
        assert cfg.double_precision_rope is False

    def test_config_from_dict_ltx2(self):
        """from_dict for LTX-2 also forces double_precision_rope=False."""
        raw = {
            "has_prompt_adaln": False,
            "double_precision_rope": True,
            "rope_type": "split",
        }
        cfg = LTXModelConfig.from_dict(raw)
        assert cfg.double_precision_rope is False

    def test_config_from_dict_ltx23(self):
        """from_dict for LTX-2.3 keeps double_precision_rope as given."""
        raw = {
            "has_prompt_adaln": True,
            "double_precision_rope": True,
            "rope_type": "split",
        }
        cfg = LTXModelConfig.from_dict(raw)
        assert cfg.double_precision_rope is True
class TestRoPESplit:
    """Checks for the split RoPE layout (used by LTX-2)."""

    def test_split_rope_output_shape(self):
        """Split RoPE tables must have shape (B, H, T, dim_per_head // 2)."""
        batch_size, num_frames, height, width = 1, 4, 4, 4
        num_heads, dim = 32, 128
        grid = create_video_position_grid(batch_size, num_frames, height, width)
        cos_freq, sin_freq = precompute_freqs_cis(
            indices_grid=grid,
            dim=dim,
            theta=10000.0,
            max_pos=[20, 2048, 2048],
            use_middle_indices_grid=True,
            num_attention_heads=num_heads,
            rope_type=LTXRopeType.SPLIT,
            double_precision=True,
        )
        # 128 channels over 32 heads -> 4 per head; split RoPE keeps half (2).
        half_head = (dim // num_heads) // 2
        token_count = num_frames * height * width
        expected_shape = (batch_size, num_heads, token_count, half_head)
        for table in (cos_freq, sin_freq):
            assert table.shape == expected_shape, \
                f"Expected shape {expected_shape}, got {table.shape}"
if __name__ == "__main__":
    # pytest.main() returns the exit status instead of exiting. Propagate it so
    # that running this file directly (e.g. in CI) fails when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))