format
This commit is contained in:
@@ -1,11 +1,9 @@
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mlx_video.models.ltx_2.rope import (
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
from mlx_video.models.ltx_2.config import LTXModelConfig, LTXRopeType
|
||||
from mlx_video.models.ltx_2.rope import precompute_freqs_cis
|
||||
|
||||
|
||||
def create_video_position_grid(
|
||||
@@ -20,7 +18,7 @@ def create_video_position_grid(
|
||||
h_coords = np.arange(0, height)
|
||||
w_coords = np.arange(0, width)
|
||||
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij')
|
||||
t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing="ij")
|
||||
patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0)
|
||||
patch_ends = patch_starts + 1
|
||||
|
||||
@@ -71,10 +69,14 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
||||
scaled = fractional * 2 - 1 # [-1, 1]
|
||||
|
||||
# Outer product: (B, T, n_dims, 1) * (1, 1, 1, num_indices)
|
||||
freqs = scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
||||
freqs = (
|
||||
scaled[..., np.newaxis] * freq_indices[np.newaxis, np.newaxis, np.newaxis, :]
|
||||
)
|
||||
# (B, T, n_dims, num_indices) -> swap last two -> (B, T, num_indices, n_dims) -> flatten
|
||||
freqs = np.swapaxes(freqs, -1, -2)
|
||||
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # (B, T, num_indices * n_dims)
|
||||
freqs = freqs.reshape(
|
||||
freqs.shape[0], freqs.shape[1], -1
|
||||
) # (B, T, num_indices * n_dims)
|
||||
|
||||
cos_ref = np.cos(freqs)
|
||||
sin_ref = np.sin(freqs)
|
||||
@@ -84,8 +86,12 @@ def _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads):
|
||||
pad_size = expected - cos_ref.shape[-1]
|
||||
if pad_size > 0:
|
||||
# Padding is prepended (ones for cos, zeros for sin) — matches split_freqs_cis()
|
||||
cos_ref = np.concatenate([np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1)
|
||||
sin_ref = np.concatenate([np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1)
|
||||
cos_ref = np.concatenate(
|
||||
[np.ones((*cos_ref.shape[:-1], pad_size)), cos_ref], axis=-1
|
||||
)
|
||||
sin_ref = np.concatenate(
|
||||
[np.zeros((*sin_ref.shape[:-1], pad_size)), sin_ref], axis=-1
|
||||
)
|
||||
|
||||
B, T, _ = cos_ref.shape
|
||||
dim_per_head = dim // num_heads
|
||||
@@ -124,10 +130,12 @@ class TestRoPEPositionPrecision:
|
||||
assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf"
|
||||
|
||||
# Verify cos/sin are in valid range [-1, 1]
|
||||
assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \
|
||||
"cos_freq values out of [-1, 1] range"
|
||||
assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \
|
||||
"sin_freq values out of [-1, 1] range"
|
||||
assert (
|
||||
mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item()
|
||||
), "cos_freq values out of [-1, 1] range"
|
||||
assert (
|
||||
mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item()
|
||||
), "sin_freq values out of [-1, 1] range"
|
||||
|
||||
def test_bfloat16_positions_cause_precision_loss(self):
|
||||
"""bfloat16 positions should produce different (less precise) results than float32.
|
||||
@@ -175,7 +183,9 @@ class TestRoPEPositionPrecision:
|
||||
# The threshold here is intentionally low to catch the issue
|
||||
precision_threshold = 1e-6
|
||||
|
||||
has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||
has_precision_loss = (
|
||||
max_cos_diff > precision_threshold or max_sin_diff > precision_threshold
|
||||
)
|
||||
|
||||
# Document the precision loss (this is expected behavior)
|
||||
if has_precision_loss:
|
||||
@@ -184,8 +194,9 @@ class TestRoPEPositionPrecision:
|
||||
print(f" Max sin difference: {max_sin_diff:.6e}")
|
||||
|
||||
# This assertion documents the issue - bfloat16 positions cause precision loss
|
||||
assert has_precision_loss, \
|
||||
"Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||
assert (
|
||||
has_precision_loss
|
||||
), "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed"
|
||||
|
||||
def test_double_precision_converts_to_float32_internally(self):
|
||||
"""Verify that double_precision mode converts bfloat16 to float32 first."""
|
||||
@@ -215,20 +226,26 @@ class TestRoPEPositionPrecision:
|
||||
# Recommended: create positions in float32
|
||||
positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32)
|
||||
|
||||
assert positions.dtype == mx.float32, \
|
||||
"Position grids should be created in float32 for RoPE precision"
|
||||
assert (
|
||||
positions.dtype == mx.float32
|
||||
), "Position grids should be created in float32 for RoPE precision"
|
||||
|
||||
# Verify the position values are reasonable
|
||||
# Temporal positions should be small (seconds)
|
||||
temporal_positions = positions[:, 0, :, :]
|
||||
assert mx.max(temporal_positions).item() < 100, \
|
||||
"Temporal positions should be in seconds (small values)"
|
||||
assert (
|
||||
mx.max(temporal_positions).item() < 100
|
||||
), "Temporal positions should be in seconds (small values)"
|
||||
|
||||
# Spatial positions should be larger (pixels)
|
||||
spatial_h = positions[:, 1, :, :]
|
||||
spatial_w = positions[:, 2, :, :]
|
||||
assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive"
|
||||
assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive"
|
||||
assert (
|
||||
mx.max(spatial_h).item() > 0
|
||||
), "Spatial height positions should be positive"
|
||||
assert (
|
||||
mx.max(spatial_w).item() > 0
|
||||
), "Spatial width positions should be positive"
|
||||
|
||||
def test_float32_positions_match_numpy_float64_reference(self):
|
||||
"""Regression test: float32 RoPE must closely match a NumPy float64 reference.
|
||||
@@ -259,7 +276,9 @@ class TestRoPEPositionPrecision:
|
||||
)
|
||||
|
||||
# NumPy float64 reference
|
||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
||||
cos_ref, sin_ref = _numpy_reference_rope(
|
||||
positions_np, dim, theta, max_pos, num_heads
|
||||
)
|
||||
|
||||
cos_mlx_np = np.array(cos_mlx)
|
||||
sin_mlx_np = np.array(sin_mlx)
|
||||
@@ -270,16 +289,21 @@ class TestRoPEPositionPrecision:
|
||||
# Cosine similarity (flatten for single scalar)
|
||||
cos_flat = cos_mlx_np.flatten()
|
||||
ref_flat = cos_ref.flatten()
|
||||
cosine_sim = np.dot(cos_flat, ref_flat) / (np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat))
|
||||
cosine_sim = np.dot(cos_flat, ref_flat) / (
|
||||
np.linalg.norm(cos_flat) * np.linalg.norm(ref_flat)
|
||||
)
|
||||
|
||||
# float32 vs float64: expect small diffs from 23-bit vs 52-bit mantissa.
|
||||
# Threshold 0.01 is well below the bfloat16 failure mode (~2.0 max diff).
|
||||
assert max_cos_diff < 0.01, \
|
||||
f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert max_sin_diff < 0.01, \
|
||||
f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert cosine_sim > 0.9999, \
|
||||
f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
||||
assert (
|
||||
max_cos_diff < 0.01
|
||||
), f"cos max diff {max_cos_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert (
|
||||
max_sin_diff < 0.01
|
||||
), f"sin max diff {max_sin_diff:.2e} exceeds 0.01 — float32 positions may not be preserved"
|
||||
assert (
|
||||
cosine_sim > 0.9999
|
||||
), f"cos cosine similarity {cosine_sim:.6f} too low — expected >0.9999"
|
||||
|
||||
def test_high_frequency_amplification_regression(self):
|
||||
"""Regression test for the specific failure mode: high-frequency index amplification.
|
||||
@@ -309,16 +333,20 @@ class TestRoPEPositionPrecision:
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
cos_ref, sin_ref = _numpy_reference_rope(positions_np, dim, theta, max_pos, num_heads)
|
||||
cos_ref, sin_ref = _numpy_reference_rope(
|
||||
positions_np, dim, theta, max_pos, num_heads
|
||||
)
|
||||
|
||||
max_cos_diff = np.max(np.abs(np.array(cos_mlx) - cos_ref))
|
||||
max_sin_diff = np.max(np.abs(np.array(sin_mlx) - sin_ref))
|
||||
|
||||
# Float32 should keep errors well below the bfloat16 failure threshold of ~2.0
|
||||
assert max_cos_diff < 0.01, \
|
||||
f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
||||
assert max_sin_diff < 0.01, \
|
||||
f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
||||
assert (
|
||||
max_cos_diff < 0.01
|
||||
), f"Production grid cos max diff {max_cos_diff:.4f} — high-freq amplification detected"
|
||||
assert (
|
||||
max_sin_diff < 0.01
|
||||
), f"Production grid sin max diff {max_sin_diff:.4f} — high-freq amplification detected"
|
||||
|
||||
|
||||
class TestRoPEInterleaved:
|
||||
@@ -359,9 +387,13 @@ class TestRoPEInputCasting:
|
||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||
|
||||
kwargs = dict(
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||
@@ -383,9 +415,13 @@ class TestRoPEInputCasting:
|
||||
positions_bf16 = positions_f32.astype(mx.bfloat16)
|
||||
|
||||
kwargs = dict(
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=True,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=True,
|
||||
)
|
||||
|
||||
cos_f32, sin_f32 = precompute_freqs_cis(indices_grid=positions_f32, **kwargs)
|
||||
@@ -405,9 +441,13 @@ class TestRoPEInputCasting:
|
||||
|
||||
cos_freq, sin_freq = precompute_freqs_cis(
|
||||
indices_grid=positions_f16,
|
||||
dim=128, theta=10000.0, max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True, num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT, double_precision=False,
|
||||
dim=128,
|
||||
theta=10000.0,
|
||||
max_pos=[20, 2048, 2048],
|
||||
use_middle_indices_grid=True,
|
||||
num_attention_heads=32,
|
||||
rope_type=LTXRopeType.SPLIT,
|
||||
double_precision=False,
|
||||
)
|
||||
|
||||
assert cos_freq.dtype == mx.float32
|
||||
@@ -421,20 +461,23 @@ class TestDoublePrecisionRopeConfig:
|
||||
def test_ltx2_forces_double_precision_rope_false(self):
|
||||
"""LTX-2 (no prompt adaln) must have double_precision_rope=False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=False, double_precision_rope=True)
|
||||
assert config.double_precision_rope is False, \
|
||||
"LTX-2 should force double_precision_rope=False regardless of input"
|
||||
assert (
|
||||
config.double_precision_rope is False
|
||||
), "LTX-2 should force double_precision_rope=False regardless of input"
|
||||
|
||||
def test_ltx23_preserves_double_precision_rope_true(self):
|
||||
"""LTX-2.3 (has_prompt_adaln=True) should keep double_precision_rope=True."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=True)
|
||||
assert config.double_precision_rope is True, \
|
||||
"LTX-2.3 should preserve double_precision_rope=True"
|
||||
assert (
|
||||
config.double_precision_rope is True
|
||||
), "LTX-2.3 should preserve double_precision_rope=True"
|
||||
|
||||
def test_ltx23_preserves_double_precision_rope_false(self):
|
||||
"""LTX-2.3 with double_precision_rope=False should stay False."""
|
||||
config = LTXModelConfig(has_prompt_adaln=True, double_precision_rope=False)
|
||||
assert config.double_precision_rope is False, \
|
||||
"LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
||||
assert (
|
||||
config.double_precision_rope is False
|
||||
), "LTX-2.3 should respect double_precision_rope=False when explicitly set"
|
||||
|
||||
def test_ltx2_default_double_precision_rope(self):
|
||||
"""LTX-2 default (double_precision_rope not set) should be False."""
|
||||
@@ -449,20 +492,24 @@ class TestDoublePrecisionRopeConfig:
|
||||
|
||||
def test_config_from_dict_ltx2(self):
|
||||
"""Config created from dict for LTX-2 should force double_precision_rope=False."""
|
||||
config = LTXModelConfig.from_dict({
|
||||
"has_prompt_adaln": False,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
})
|
||||
config = LTXModelConfig.from_dict(
|
||||
{
|
||||
"has_prompt_adaln": False,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
}
|
||||
)
|
||||
assert config.double_precision_rope is False
|
||||
|
||||
def test_config_from_dict_ltx23(self):
|
||||
"""Config created from dict for LTX-2.3 should preserve double_precision_rope."""
|
||||
config = LTXModelConfig.from_dict({
|
||||
"has_prompt_adaln": True,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
})
|
||||
config = LTXModelConfig.from_dict(
|
||||
{
|
||||
"has_prompt_adaln": True,
|
||||
"double_precision_rope": True,
|
||||
"rope_type": "split",
|
||||
}
|
||||
)
|
||||
assert config.double_precision_rope is True
|
||||
|
||||
|
||||
@@ -496,10 +543,12 @@ class TestRoPESplit:
|
||||
# dim=128, num_heads=32, so dim_per_head=4, and split uses half=2
|
||||
dim_per_head = dim // num_heads
|
||||
expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2)
|
||||
assert cos_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||
assert sin_freq.shape == expected_shape, \
|
||||
f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||
assert (
|
||||
cos_freq.shape == expected_shape
|
||||
), f"Expected shape {expected_shape}, got {cos_freq.shape}"
|
||||
assert (
|
||||
sin_freq.shape == expected_shape
|
||||
), f"Expected shape {expected_shape}, got {sin_freq.shape}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user