Files
mlx-video/tests/test_generate_dev.py

423 lines
16 KiB
Python

"""Tests for LTX-2 dev model generation pipeline."""
import pytest
import mlx.core as mx
from mlx_video.generate_dev import (
ltx2_scheduler,
create_position_grid,
create_audio_position_grid,
compute_audio_frames,
cfg_delta,
DEFAULT_NEGATIVE_PROMPT,
AUDIO_SAMPLE_RATE,
AUDIO_LATENTS_PER_SECOND,
)
class TestLTX2Scheduler:
    """Tests for the LTX-2 sigma scheduler (``ltx2_scheduler``).

    The scheduler is expected to return ``steps + 1`` float32 sigma values
    that start at 1.0, end at 0.0 and never increase.
    """

    def test_scheduler_output_shape(self):
        """Scheduler should return steps+1 sigma values."""
        steps = 20
        sigmas = ltx2_scheduler(steps=steps)
        assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}"

    def test_scheduler_starts_at_one(self):
        """Sigma schedule should start at 1.0."""
        sigmas = ltx2_scheduler(steps=20)
        assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}"

    def test_scheduler_ends_at_zero(self):
        """Sigma schedule should end at 0.0."""
        sigmas = ltx2_scheduler(steps=20)
        assert abs(sigmas[-1].item()) < 1e-5, f"Expected 0.0, got {sigmas[-1].item()}"

    def test_scheduler_monotonically_decreasing(self):
        """Sigma values should be non-increasing from first to last."""
        sigmas_list = ltx2_scheduler(steps=20).tolist()
        # Pair each sigma with its successor; equal neighbours are allowed.
        for i, (current, following) in enumerate(zip(sigmas_list, sigmas_list[1:])):
            assert current >= following, \
                f"Sigma not decreasing at index {i}: {current} < {following}"

    def test_scheduler_dtype(self):
        """Scheduler should return float32 array."""
        sigmas = ltx2_scheduler(steps=20)
        assert sigmas.dtype == mx.float32, f"Expected float32, got {sigmas.dtype}"

    def test_scheduler_with_num_tokens(self):
        """Scheduler should accept num_tokens parameter."""
        steps = 20
        sigmas_default = ltx2_scheduler(steps=steps, num_tokens=None)
        sigmas_custom = ltx2_scheduler(steps=steps, num_tokens=1920)
        # Both calls must yield a full schedule of steps + 1 values.
        assert sigmas_default.shape == (steps + 1,)
        assert sigmas_custom.shape == (steps + 1,)

    def test_scheduler_no_stretch(self):
        """Scheduler without stretching should still produce a valid schedule."""
        steps = 20
        sigmas = ltx2_scheduler(steps=steps, stretch=False)
        assert sigmas.shape == (steps + 1,)
        assert sigmas[0].item() > 0
        # Tolerance-based check (as in test_scheduler_ends_at_zero) rather than
        # exact float equality, which is brittle for computed float values.
        assert abs(sigmas[-1].item()) < 1e-5, f"Expected 0.0, got {sigmas[-1].item()}"

    def test_scheduler_different_steps(self):
        """Scheduler should work with different step counts."""
        for steps in [5, 10, 20, 40, 50]:
            sigmas = ltx2_scheduler(steps=steps)
            assert sigmas.shape == (steps + 1,), f"Failed for steps={steps}"
class TestCreatePositionGrid:
    """Tests for ``create_position_grid`` (video position coordinates for RoPE)."""

    def test_position_grid_shape(self):
        """Position grid should have correct shape (B, 3, num_patches, 2)."""
        b, f, h, w = 1, 5, 16, 24
        grid = create_position_grid(b, f, h, w)
        # One patch per latent (frame, row, column) cell.
        want = (b, 3, f * h * w, 2)
        assert grid.shape == want, \
            f"Expected {want}, got {grid.shape}"

    def test_position_grid_dtype(self):
        """Position grid should be float32 for RoPE precision."""
        grid = create_position_grid(1, 5, 16, 24)
        assert grid.dtype == mx.float32, \
            f"Expected float32 for RoPE precision, got {grid.dtype}"

    def test_position_grid_batch_size(self):
        """Position grid should respect batch size."""
        for b in (1, 2, 4):
            assert create_position_grid(b, 5, 16, 24).shape[0] == b

    def test_position_grid_temporal_dimension(self):
        """Temporal dimension should have values scaled by fps."""
        grid = create_position_grid(1, 5, 16, 24, fps=24.0)
        # Index 0 on axis 1 holds the temporal coordinates, expected in seconds.
        # 5 latent frames at temporal scale 8 is ~40 pixel frames, i.e. ~1.67 s
        # at 24 fps, so anything >= 10 indicates a scaling problem.
        peak = mx.max(grid[0, 0, :, :]).item()
        assert peak < 10, f"Temporal values too large: {peak}"

    def test_position_grid_spatial_dimensions(self):
        """Spatial dimensions should have pixel-space values."""
        grid = create_position_grid(1, 5, 16, 24, spatial_scale=32)
        # Height axis: 16 latent rows * 32 = 512 pixels at most.
        peak_h = mx.max(grid[0, 1, :, :]).item()
        assert peak_h <= 512, f"Height values too large: {peak_h}"
        # Width axis: 24 latent columns * 32 = 768 pixels at most.
        peak_w = mx.max(grid[0, 2, :, :]).item()
        assert peak_w <= 768, f"Width values too large: {peak_w}"

    def test_position_grid_causal_fix(self):
        """Causal fix should adjust first frame temporal values."""
        with_fix = create_position_grid(1, 5, 16, 24, causal_fix=True)
        without_fix = create_position_grid(1, 5, 16, 24, causal_fix=False)
        # Any non-zero element-wise gap means the flag had an effect.
        largest_gap = mx.max(mx.abs(with_fix - without_fix)).item()
        assert largest_gap > 0, "Causal fix should change position values"

    def test_position_grid_no_nan_or_inf(self):
        """Position grid should not contain NaN or Inf values."""
        grid = create_position_grid(1, 5, 16, 24)
        has_nan = mx.any(mx.isnan(grid)).item()
        assert not has_nan, "Position grid contains NaN"
        has_inf = mx.any(mx.isinf(grid)).item()
        assert not has_inf, "Position grid contains Inf"
class TestCFGDelta:
    """Tests for ``cfg_delta`` (classifier-free guidance delta)."""

    def test_cfg_delta_shape(self):
        """CFG delta should have same shape as inputs."""
        dims = (1, 1920, 128)
        cond, uncond = mx.random.normal(dims), mx.random.normal(dims)
        assert cfg_delta(cond, uncond, scale=4.0).shape == dims

    def test_cfg_delta_scale_one(self):
        """CFG with scale=1.0 should return zero delta."""
        dims = (1, 1920, 128)
        cond = mx.random.normal(dims)
        uncond = mx.random.normal(dims)
        mx.eval(cond, uncond)
        result = cfg_delta(cond, uncond, scale=1.0)
        mx.eval(result)
        # (scale - 1) * (cond - uncond) vanishes when scale is exactly 1.0.
        peak = mx.max(mx.abs(result)).item()
        assert peak < 1e-6, "CFG delta with scale=1.0 should be zero"

    def test_cfg_delta_formula(self):
        """CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
        cond = mx.array([[[1.0, 2.0, 3.0]]])
        uncond = mx.array([[[0.5, 1.0, 1.5]]])
        guidance = 4.0
        got = cfg_delta(cond, uncond, guidance)
        want = (guidance - 1.0) * (cond - uncond)
        mx.eval(got, want)
        err = mx.max(mx.abs(got - want)).item()
        assert err < 1e-6, f"CFG delta formula mismatch: diff={err}"

    def test_cfg_delta_dtype_preservation(self):
        """CFG delta should preserve input dtype."""
        for dt in (mx.float32, mx.bfloat16):
            cond = mx.random.normal((1, 100, 64)).astype(dt)
            uncond = mx.random.normal((1, 100, 64)).astype(dt)
            out = cfg_delta(cond, uncond, scale=4.0)
            assert out.dtype == dt, f"Expected {dt}, got {out.dtype}"
class TestDefaultNegativePrompt:
    """Tests for the ``DEFAULT_NEGATIVE_PROMPT`` constant."""

    def test_default_negative_prompt_exists(self):
        """Default negative prompt should be defined and non-empty."""
        prompt = DEFAULT_NEGATIVE_PROMPT
        assert prompt is not None
        assert len(prompt) > 0

    def test_default_negative_prompt_contains_quality_terms(self):
        """Default negative prompt should mention common quality defects."""
        text = DEFAULT_NEGATIVE_PROMPT.lower()
        assert "blurry" in text, "Should contain 'blurry'"
        # At least one of the quality-related phrases must be present.
        has_quality_term = any(term in text for term in ("low quality", "low contrast"))
        assert has_quality_term, "Should contain quality-related terms"
class TestInputValidation:
    """Tests for the input-validation rules of ``generate_video_dev``.

    NOTE(review): these tests restate the validation arithmetic locally
    instead of calling ``generate_video_dev``. They document the expected
    rules but do not exercise the real validation code.
    """

    def test_height_divisible_by_32(self):
        """Height must be divisible by 32."""
        valid_heights = [256, 384, 512, 640, 768]
        invalid_heights = [100, 300, 500, 700]
        for h in valid_heights:
            assert h % 32 == 0, f"Height {h} should be valid"
        for h in invalid_heights:
            assert h % 32 != 0, f"Height {h} should be invalid"

    def test_width_divisible_by_32(self):
        """Width must be divisible by 32."""
        valid_widths = [256, 384, 512, 640, 768, 1024]
        invalid_widths = [100, 300, 500, 700]
        for w in valid_widths:
            assert w % 32 == 0, f"Width {w} should be valid"
        for w in invalid_widths:
            assert w % 32 != 0, f"Width {w} should be invalid"

    def test_num_frames_formula(self):
        """Number of frames should be 1 + 8*k."""
        valid_frames = [1, 9, 17, 25, 33, 41, 49, 57, 65]
        for f in valid_frames:
            assert (f - 1) % 8 == 0, f"Frames {f} should be valid (1 + 8*k)"

    def test_num_frames_adjustment(self):
        """Frame counts should snap to the nearest valid ``1 + 8*k`` value.

        The snapping formula is applied to every case, including counts that
        are already valid; those must map to themselves. A guard that skipped
        valid inputs would leave the identity cases unchecked.

        Note: Python's ``round`` uses round-half-to-even, so exact ties (e.g.
        5 or 13 frames) round to the even multiple of 8.
        """
        test_cases = [
            (30, 33),  # 30 -> nearest valid is 33
            (35, 33),  # 35 -> nearest valid is 33
            (40, 41),  # 40 -> nearest valid is 41
            (1, 1),    # 1 is already valid
            (33, 33),  # 33 is already valid
        ]
        for input_frames, expected in test_cases:
            adjusted = round((input_frames - 1) / 8) * 8 + 1
            assert adjusted == expected, \
                f"Expected {expected} for input {input_frames}, got {adjusted}"
class TestDenoiseWithCFGMocked:
    """Tests for denoise_with_cfg with mocked transformer.

    NOTE(review): despite the class name, the only test here checks that the
    scheduler output converts to a Python list. No transformer is mocked yet.
    """

    def test_sigmas_list_conversion(self):
        """Sigmas should be convertible to list."""
        as_list = ltx2_scheduler(steps=5).tolist()
        assert isinstance(as_list, list)
        # 5 steps produce 6 sigma values (steps + 1).
        assert len(as_list) == 6
class TestTilingDefault:
    """Tests for the default value of the ``tiling`` option."""

    def test_tiling_default_is_none(self):
        """Default tiling should be 'none' for performance."""
        from inspect import signature
        from mlx_video.generate_dev import generate_video_dev

        params = signature(generate_video_dev).parameters
        # The keyword must exist before its default can be inspected.
        assert 'tiling' in params
        default = params['tiling'].default
        assert default == "none", \
            f"Expected default tiling='none', got '{default}'"
class TestLatentDimensions:
    """Tests for the pixel-to-latent dimension arithmetic."""

    def test_latent_height_calculation(self):
        """Latent height should be height // 32."""
        for pixels, want in ((512, 16), (768, 24), (1024, 32)):
            got = pixels // 32
            assert got == want, \
                f"Expected latent_h={want} for height={pixels}, got {got}"

    def test_latent_width_calculation(self):
        """Latent width should be width // 32."""
        for pixels, want in ((512, 16), (768, 24), (1024, 32)):
            got = pixels // 32
            assert got == want, \
                f"Expected latent_w={want} for width={pixels}, got {got}"

    def test_latent_frames_calculation(self):
        """Latent frames should be 1 + (num_frames - 1) // 8."""
        for frames, want in ((1, 1), (9, 2), (17, 3), (33, 5), (65, 9)):
            got = 1 + (frames - 1) // 8
            assert got == want, \
                f"Expected latent_f={want} for num_frames={frames}, got {got}"

    def test_num_tokens_calculation(self):
        """Number of tokens should be latent_f * latent_h * latent_w."""
        # 33 frames at 512x768 gives a 5 x 16 x 24 latent grid.
        frames, h, w = 33, 512, 768
        tokens = (1 + (frames - 1) // 8) * (h // 32) * (w // 32)
        want = 5 * 16 * 24  # 1920
        assert tokens == want, f"Expected {want} tokens, got {tokens}"
class TestAudioPositionGrid:
    """Tests for ``create_audio_position_grid`` (audio position coordinates)."""

    def test_audio_position_grid_shape(self):
        """Audio position grid should have correct shape (B, 1, T, 2)."""
        b = 1
        t = 34  # ~1.36 seconds at 25 latent frames/sec
        grid = create_audio_position_grid(b, t)
        want = (b, 1, t, 2)
        assert grid.shape == want, \
            f"Expected {want}, got {grid.shape}"

    def test_audio_position_grid_dtype(self):
        """Audio position grid should be float32."""
        grid = create_audio_position_grid(1, 34)
        assert grid.dtype == mx.float32, \
            f"Expected float32, got {grid.dtype}"

    def test_audio_position_grid_batch_size(self):
        """Audio position grid should respect batch size."""
        for b in (1, 2, 4):
            grid = create_audio_position_grid(b, 34)
            assert grid.shape[0] == b

    def test_audio_position_grid_temporal_values(self):
        """Audio positions should be in seconds."""
        # About 1.36 s of audio, so the largest position should be small but positive.
        peak = mx.max(create_audio_position_grid(1, 34)).item()
        assert peak < 10, f"Audio positions seem too large: {peak}"
        assert peak > 0, "Audio positions should be positive"

    def test_audio_position_grid_no_nan_or_inf(self):
        """Audio position grid should not contain NaN or Inf."""
        grid = create_audio_position_grid(1, 34)
        for check, label in ((mx.isnan, "NaN"), (mx.isinf, "Inf")):
            assert not mx.any(check(grid)).item(), f"Audio position grid contains {label}"
class TestComputeAudioFrames:
    """Tests for ``compute_audio_frames`` (video length to audio latent count)."""

    def test_audio_frames_basic(self):
        """Audio frames should be proportional to video duration."""
        # 33 frames / 24 fps ~= 1.375 s; at 25 latents/s that is ~34 audio frames.
        n_audio = compute_audio_frames(33, 24.0)
        assert n_audio > 0
        assert isinstance(n_audio, int)

    def test_audio_frames_scales_with_video(self):
        """More video frames should produce more audio frames."""
        shorter = compute_audio_frames(33, 24.0)
        longer = compute_audio_frames(65, 24.0)
        assert longer > shorter, \
            f"Expected more audio frames for longer video: {longer} <= {shorter}"

    def test_audio_frames_formula(self):
        """Audio frames should match expected formula."""
        frames, rate = 33, 24.0
        seconds = frames / rate  # ~1.375 seconds
        want = round(seconds * AUDIO_LATENTS_PER_SECOND)
        got = compute_audio_frames(frames, rate)
        assert got == want, f"Expected {want}, got {got}"
class TestAudioConstants:
    """Tests for the module-level audio constants."""

    def test_audio_sample_rate(self):
        """Audio sample rate should be 24000 Hz."""
        expected_hz = 24000
        assert AUDIO_SAMPLE_RATE == expected_hz

    def test_audio_latents_per_second(self):
        """Audio latents per second should be 25."""
        expected_rate = 25.0
        assert AUDIO_LATENTS_PER_SECOND == expected_rate
if __name__ == "__main__":
    # pytest.main() only *returns* its exit code. Raise it as SystemExit so that
    # running this file directly (e.g. from CI) fails when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))