This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,11 +1,9 @@
import pytest
import mlx.core as mx
import numpy as np
import pytest
from mlx_video.models.ltx_2.rope import (
precompute_freqs_cis,
)
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
def create_video_position_grid(
@@ -20,7 +18,7 @@ def create_video_position_grid(
h_coords = np.arange(0, height)
w_coords = np.arange(0, width)
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij")
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
patch_ends = patch_starts + 1
@@ -71,10 +69,14 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
scaled = fractional * 2 - 1 # [-1, 1]
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
freqs = (
scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
)
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
freqs = np.swapaxes(freqs, -1, -2)
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims)
freqs = freqs.reshape(
freqs.shape[0], freqs.shape[1], -1
) # (B, T, num_indices * n_dims)
cos_ref = np.cos(freqs)
sin_ref = np.sin(freqs)
@@ -84,8 +86,12 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
pad_size = expected - cos_ref.shape[-1]
if pad_size > 0:
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
cos_ref = np.concatenate(
[np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1
)
sin_ref = np.concatenate(
[np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1
)
B, T, _ = cos_ref.shape
dim_per_head = dim // num_heads
@@ -124,10 +130,12 @@ class TestRoPEPositionPrecision:
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
# Verify cos/sin are in valid range [-1, 1]
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
"cos_freq values out of [-1, 1] range"
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
"sin_freq values out of [-1, 1] range"
assert (
mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item()
), "cos_freq values out of [-1, 1] range"
assert (
mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item()
), "sin_freq values out of [-1, 1] range"
def test_bfloat16_positions_cause_precision_loss(self):
"""bfloat16 positions should produce different (less precise) results than float32.
@@ -175,7 +183,9 @@ class TestRoPEPositionPrecision:
# The threshold here is intentionally low to catch the issue
precision_threshold = 1e-6
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
has_precision_loss = (
max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
)
# Document the precision loss (this is expected behavior)
if has_precision_loss:
@@ -184,8 +194,9 @@ class TestRoPEPositionPrecision:
print(f" Max sin difference: {max_sin_diff:.6e}")
# This assertion documents the issue - bfloat16 positions cause precision loss
assert has_precision_loss, \
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
assert (
has_precision_loss
), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
def test_double_precision_converts_to_float32_internally(self):
"""Verify that double_precision mode converts bfloat16 to float32 first."""
@@ -215,20 +226,26 @@ class TestRoPEPositionPrecision:
# Recommended: create positions in float32
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
assert positions.dtype == mx.float32, \
"Position grids should be created in float32 for RoPE precision"
assert (
positions.dtype == mx.float32
), "Position grids should be created in float32 for RoPE precision"
# Verify the position values are reasonable
# Temporal positions should be small (seconds)
temporal_positions = positions[:, 0, :, :]
assert mx.max(temporal_positions).item() < 100, \
"Temporal positions should be in seconds (small values)"
assert (
mx.max(temporal_positions).item() < 100
), "Temporal positions should be in seconds (small values)"
# Spatial positions should be larger (pixels)
spatial_h = positions[:, 1, :, :]
spatial_w = positions[:, 2, :, :]
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
assert (
mx.max(spatial_h).item() > 0
), "Spatial height positions should be positive"
assert (
mx.max(spatial_w).item() > 0
), "Spatial width positions should be positive"
def test_float32_positions_match_numpy_float64_reference(self):
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
@@ -259,7 +276,9 @@ class TestRoPEPositionPrecision:
)
# NumPy float64 reference
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
cos_ref, sin_ref = _numpy_reference_rope(
positions_np, dim, theta, max_pos, num_heads
)
cos_mlx_np = np.array(cos_mlx)
sin_mlx_np = np.array(sin_mlx)
@@ -270,16 +289,21 @@ class TestRoPEPositionPrecision:
# Cosine similarity (flatten for single scalar)
cos_flat = cos_mlx_np.flatten()
ref_flat = cos_ref.flatten()
cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))
cosine_sim = np.dot(cos_flat, ref_flat) / (
np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)
)
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
assert max_cos_diff < 0.01, \
f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert max_sin_diff < 0.01, \
f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert cosine_sim > 0.9999, \
f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
assert (
max_cos_diff < 0.01
), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert (
max_sin_diff < 0.01
), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
assert (
cosine_sim > 0.9999
), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
def test_high_frequency_amplification_regression(self):
"""Regression test for the specific failure mode: high-frequency index amplification.
@@ -309,16 +333,20 @@ class TestRoPEPositionPrecision:
double_precision=False,
)
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
cos_ref, sin_ref = _numpy_reference_rope(
positions_np, dim, theta, max_pos, num_heads
)
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
assert max_cos_diff < 0.01, \
f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
assert max_sin_diff < 0.01, \
f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
assert (
max_cos_diff < 0.01
), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
assert (
max_sin_diff < 0.01
), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
class TestRoPEInterleaved:
@@ -359,9 +387,13 @@ class TestRoPEInputCasting:
positions_bf16 = positions_f32.astype(mx.bfloat16)
kwargs = dict(
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True, num_attention_heads=32,
rope_type=LTXRopeType.SPLIT, double_precision=False,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
)
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
@@ -383,9 +415,13 @@ class TestRoPEInputCasting:
positions_bf16 = positions_f32.astype(mx.bfloat16)
kwargs = dict(
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True, num_attention_heads=32,
rope_type=LTXRopeType.SPLIT, double_precision=True,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=True,
)
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
@@ -405,9 +441,13 @@ class TestRoPEInputCasting:
cos_freq, sin_freq = precompute_freqs_cis(
indices_grid=positions_f16,
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
use_middle_indices_grid=True, num_attention_heads=32,
rope_type=LTXRopeType.SPLIT, double_precision=False,
dim=128,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=True,
num_attention_heads=32,
rope_type=LTXRopeType.SPLIT,
double_precision=False,
)
assert cos_freq.dtype == mx.float32
@@ -421,20 +461,23 @@ class TestDoublePrecisionRopeConfig:
def test_ltx2_forces_double_precision_rope_false(self):
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
assert config.double_precision_rope is False, \
"LTX-2 should force double_precision_rope=False regardless of input"
assert (
config.double_precision_rope is False
), "LTX-2 should force double_precision_rope=False regardless of input"
def test_ltx23_preserves_double_precision_rope_true(self):
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
assert config.double_precision_rope is True, \
"LTX-2.3 should preserve double_precision_rope=True"
assert (
config.double_precision_rope is True
), "LTX-2.3 should preserve double_precision_rope=True"
def test_ltx23_preserves_double_precision_rope_false(self):
"""LTX-2.3 with double_precision_rope=False should stay False."""
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
assert config.double_precision_rope is False, \
"LTX-2.3 should respect double_precision_rope=False when explicitly set"
assert (
config.double_precision_rope is False
), "LTX-2.3 should respect double_precision_rope=False when explicitly set"
def test_ltx2_default_double_precision_rope(self):
"""LTX-2 default (double_precision_rope not set) should be False."""
@@ -449,20 +492,24 @@ class TestDoublePrecisionRopeConfig:
def test_config_from_dict_ltx2(self):
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
config = LTXModelConfig.from_dict({
"has_prompt_adaln": False,
"double_precision_rope": True,
"rope_type": "split",
})
config = LTXModelConfig.from_dict(
{
"has_prompt_adaln": False,
"double_precision_rope": True,
"rope_type": "split",
}
)
assert config.double_precision_rope is False
def test_config_from_dict_ltx23(self):
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
config = LTXModelConfig.from_dict({
"has_prompt_adaln": True,
"double_precision_rope": True,
"rope_type": "split",
})
config = LTXModelConfig.from_dict(
{
"has_prompt_adaln": True,
"double_precision_rope": True,
"rope_type": "split",
}
)
assert config.double_precision_rope is True
@@ -496,10 +543,12 @@ class TestRoPESplit:
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
dim_per_head = dim // num_heads
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
assert cos_freq.shape == expected_shape, \
f"Expected shape {expected_shape}, got {cos_freq.shape}"
assert sin_freq.shape == expected_shape, \
f"Expected shape {expected_shape}, got {sin_freq.shape}"
assert (
cos_freq.shape == expected_shape
), f"Expected shape {expected_shape}, got {cos_freq.shape}"
assert (
sin_freq.shape == expected_shape
), f"Expected shape {expected_shape}, got {sin_freq.shape}"
if __name__ == "__main__":