444 lines
16 KiB
Python
444 lines
16 KiB
Python
"""Tests for LTX-2 dev model generation pipeline."""
|
|
|
|
import mlx.core as mx
|
|
import pytest
|
|
|
|
from mlx_video.generate_dev import (
|
|
AUDIO_LATENTS_PER_SECOND,
|
|
AUDIO_SAMPLE_RATE,
|
|
DEFAULT_NEGATIVE_PROMPT,
|
|
cfg_delta,
|
|
compute_audio_frames,
|
|
create_audio_position_grid,
|
|
create_position_grid,
|
|
ltx2_scheduler,
|
|
)
|
|
|
|
|
|
class TestLTX2Scheduler:
|
|
"""Tests for the LTX-2 sigma scheduler."""
|
|
|
|
def test_scheduler_output_shape(self):
|
|
"""Scheduler should return steps+1 sigma values."""
|
|
steps = 20
|
|
sigmas = ltx2_scheduler(steps=steps)
|
|
assert sigmas.shape == (
|
|
steps + 1,
|
|
), f"Expected ({steps + 1},), got {sigmas.shape}"
|
|
|
|
def test_scheduler_starts_at_one(self):
|
|
"""Sigma schedule should start at 1.0."""
|
|
sigmas = ltx2_scheduler(steps=20)
|
|
assert (
|
|
abs(sigmas[0].item() - 1.0) < 1e-5
|
|
), f"Expected 1.0, got {sigmas[0].item()}"
|
|
|
|
def test_scheduler_ends_at_zero(self):
|
|
"""Sigma schedule should end at 0.0."""
|
|
sigmas = ltx2_scheduler(steps=20)
|
|
assert abs(sigmas[-1].item()) < 1e-5, f"Expected 0.0, got {sigmas[-1].item()}"
|
|
|
|
def test_scheduler_monotonically_decreasing(self):
|
|
"""Sigma values should monotonically decrease."""
|
|
sigmas = ltx2_scheduler(steps=20)
|
|
sigmas_list = sigmas.tolist()
|
|
for i in range(len(sigmas_list) - 1):
|
|
assert (
|
|
sigmas_list[i] >= sigmas_list[i + 1]
|
|
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
|
|
|
def test_scheduler_dtype(self):
|
|
"""Scheduler should return float32 array."""
|
|
sigmas = ltx2_scheduler(steps=20)
|
|
assert sigmas.dtype == mx.float32, f"Expected float32, got {sigmas.dtype}"
|
|
|
|
def test_scheduler_with_num_tokens(self):
|
|
"""Scheduler should accept num_tokens parameter."""
|
|
sigmas_default = ltx2_scheduler(steps=20, num_tokens=None)
|
|
sigmas_custom = ltx2_scheduler(steps=20, num_tokens=1920)
|
|
|
|
# Both should be valid arrays
|
|
assert sigmas_default.shape == (21,)
|
|
assert sigmas_custom.shape == (21,)
|
|
|
|
def test_scheduler_no_stretch(self):
|
|
"""Scheduler without stretching should still work."""
|
|
sigmas = ltx2_scheduler(steps=20, stretch=False)
|
|
assert sigmas.shape == (21,)
|
|
assert sigmas[0].item() > 0
|
|
assert sigmas[-1].item() == 0.0
|
|
|
|
def test_scheduler_different_steps(self):
|
|
"""Scheduler should work with different step counts."""
|
|
for steps in [5, 10, 20, 40, 50]:
|
|
sigmas = ltx2_scheduler(steps=steps)
|
|
assert sigmas.shape == (steps + 1,), f"Failed for steps={steps}"
|
|
|
|
|
|
class TestCreatePositionGrid:
|
|
"""Tests for position grid creation."""
|
|
|
|
def test_position_grid_shape(self):
|
|
"""Position grid should have correct shape (B, 3, num_patches, 2)."""
|
|
batch_size = 1
|
|
num_frames = 5
|
|
height = 16
|
|
width = 24
|
|
|
|
positions = create_position_grid(batch_size, num_frames, height, width)
|
|
num_patches = num_frames * height * width
|
|
|
|
expected_shape = (batch_size, 3, num_patches, 2)
|
|
assert (
|
|
positions.shape == expected_shape
|
|
), f"Expected {expected_shape}, got {positions.shape}"
|
|
|
|
def test_position_grid_dtype(self):
|
|
"""Position grid should be float32 for RoPE precision."""
|
|
positions = create_position_grid(1, 5, 16, 24)
|
|
assert (
|
|
positions.dtype == mx.float32
|
|
), f"Expected float32 for RoPE precision, got {positions.dtype}"
|
|
|
|
def test_position_grid_batch_size(self):
|
|
"""Position grid should respect batch size."""
|
|
for batch_size in [1, 2, 4]:
|
|
positions = create_position_grid(batch_size, 5, 16, 24)
|
|
assert positions.shape[0] == batch_size
|
|
|
|
def test_position_grid_temporal_dimension(self):
|
|
"""Temporal dimension should have values scaled by fps."""
|
|
positions = create_position_grid(1, 5, 16, 24, fps=24.0)
|
|
temporal = positions[0, 0, :, :] # (num_patches, 2)
|
|
|
|
# Values should be in seconds (divided by fps)
|
|
max_temporal = mx.max(temporal).item()
|
|
# For 5 latent frames at scale 8, max pixel frame ~ 40, divided by 24 fps ~ 1.67s
|
|
assert max_temporal < 10, f"Temporal values too large: {max_temporal}"
|
|
|
|
def test_position_grid_spatial_dimensions(self):
|
|
"""Spatial dimensions should have pixel-space values."""
|
|
positions = create_position_grid(1, 5, 16, 24, spatial_scale=32)
|
|
|
|
# Height dimension
|
|
height_vals = positions[0, 1, :, :]
|
|
max_height = mx.max(height_vals).item()
|
|
# 16 latent * 32 scale = 512 pixels
|
|
assert max_height <= 512, f"Height values too large: {max_height}"
|
|
|
|
# Width dimension
|
|
width_vals = positions[0, 2, :, :]
|
|
max_width = mx.max(width_vals).item()
|
|
# 24 latent * 32 scale = 768 pixels
|
|
assert max_width <= 768, f"Width values too large: {max_width}"
|
|
|
|
def test_position_grid_causal_fix(self):
|
|
"""Causal fix should adjust first frame temporal values."""
|
|
positions_causal = create_position_grid(1, 5, 16, 24, causal_fix=True)
|
|
positions_no_causal = create_position_grid(1, 5, 16, 24, causal_fix=False)
|
|
|
|
# They should be different due to causal fix
|
|
diff = mx.abs(positions_causal - positions_no_causal)
|
|
assert mx.max(diff).item() > 0, "Causal fix should change position values"
|
|
|
|
def test_position_grid_no_nan_or_inf(self):
|
|
"""Position grid should not contain NaN or Inf values."""
|
|
positions = create_position_grid(1, 5, 16, 24)
|
|
|
|
assert not mx.any(mx.isnan(positions)).item(), "Position grid contains NaN"
|
|
assert not mx.any(mx.isinf(positions)).item(), "Position grid contains Inf"
|
|
|
|
|
|
class TestCFGDelta:
|
|
"""Tests for CFG (Classifier-Free Guidance) delta calculation."""
|
|
|
|
def test_cfg_delta_shape(self):
|
|
"""CFG delta should have same shape as inputs."""
|
|
shape = (1, 1920, 128)
|
|
cond = mx.random.normal(shape)
|
|
uncond = mx.random.normal(shape)
|
|
|
|
delta = cfg_delta(cond, uncond, scale=4.0)
|
|
assert delta.shape == shape
|
|
|
|
def test_cfg_delta_scale_one(self):
|
|
"""CFG with scale=1.0 should return zero delta."""
|
|
shape = (1, 1920, 128)
|
|
cond = mx.random.normal(shape)
|
|
uncond = mx.random.normal(shape)
|
|
mx.eval(cond, uncond)
|
|
|
|
delta = cfg_delta(cond, uncond, scale=1.0)
|
|
mx.eval(delta)
|
|
|
|
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
|
|
assert (
|
|
mx.max(mx.abs(delta)).item() < 1e-6
|
|
), "CFG delta with scale=1.0 should be zero"
|
|
|
|
def test_cfg_delta_formula(self):
|
|
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
|
|
cond = mx.array([[[1.0, 2.0, 3.0]]])
|
|
uncond = mx.array([[[0.5, 1.0, 1.5]]])
|
|
scale = 4.0
|
|
|
|
delta = cfg_delta(cond, uncond, scale)
|
|
expected = (scale - 1.0) * (cond - uncond)
|
|
|
|
mx.eval(delta, expected)
|
|
diff = mx.max(mx.abs(delta - expected)).item()
|
|
assert diff < 1e-6, f"CFG delta formula mismatch: diff={diff}"
|
|
|
|
def test_cfg_delta_dtype_preservation(self):
|
|
"""CFG delta should preserve input dtype."""
|
|
for dtype in [mx.float32, mx.bfloat16]:
|
|
cond = mx.random.normal((1, 100, 64)).astype(dtype)
|
|
uncond = mx.random.normal((1, 100, 64)).astype(dtype)
|
|
|
|
delta = cfg_delta(cond, uncond, scale=4.0)
|
|
assert delta.dtype == dtype, f"Expected {dtype}, got {delta.dtype}"
|
|
|
|
|
|
class TestDefaultNegativePrompt:
|
|
"""Tests for the default negative prompt."""
|
|
|
|
def test_default_negative_prompt_exists(self):
|
|
"""Default negative prompt should be defined."""
|
|
assert DEFAULT_NEGATIVE_PROMPT is not None
|
|
assert len(DEFAULT_NEGATIVE_PROMPT) > 0
|
|
|
|
def test_default_negative_prompt_contains_quality_terms(self):
|
|
"""Default negative prompt should contain quality-related terms."""
|
|
prompt_lower = DEFAULT_NEGATIVE_PROMPT.lower()
|
|
|
|
# Check for common negative quality terms
|
|
assert "blurry" in prompt_lower, "Should contain 'blurry'"
|
|
assert (
|
|
"low quality" in prompt_lower or "low contrast" in prompt_lower
|
|
), "Should contain quality-related terms"
|
|
|
|
|
|
class TestInputValidation:
|
|
"""Tests for input validation in generate_video_dev."""
|
|
|
|
def test_height_divisible_by_32(self):
|
|
"""Height must be divisible by 32."""
|
|
# This would be tested via the actual function, but we can test the validation logic
|
|
valid_heights = [256, 384, 512, 640, 768]
|
|
invalid_heights = [100, 300, 500, 700]
|
|
|
|
for h in valid_heights:
|
|
assert h % 32 == 0, f"Height {h} should be valid"
|
|
|
|
for h in invalid_heights:
|
|
assert h % 32 != 0, f"Height {h} should be invalid"
|
|
|
|
def test_width_divisible_by_32(self):
|
|
"""Width must be divisible by 32."""
|
|
valid_widths = [256, 384, 512, 640, 768, 1024]
|
|
invalid_widths = [100, 300, 500, 700]
|
|
|
|
for w in valid_widths:
|
|
assert w % 32 == 0, f"Width {w} should be valid"
|
|
|
|
for w in invalid_widths:
|
|
assert w % 32 != 0, f"Width {w} should be invalid"
|
|
|
|
def test_num_frames_formula(self):
|
|
"""Number of frames should be 1 + 8*k."""
|
|
valid_frames = [1, 9, 17, 25, 33, 41, 49, 57, 65]
|
|
|
|
for f in valid_frames:
|
|
assert (f - 1) % 8 == 0, f"Frames {f} should be valid (1 + 8*k)"
|
|
|
|
def test_num_frames_adjustment(self):
|
|
"""Invalid frame counts should be adjusted to nearest valid value."""
|
|
# Test the adjustment logic
|
|
test_cases = [
|
|
(30, 33), # 30 -> nearest valid is 33
|
|
(35, 33), # 35 -> nearest valid is 33
|
|
(40, 41), # 40 -> nearest valid is 41
|
|
(1, 1), # 1 is already valid
|
|
(33, 33), # 33 is already valid
|
|
]
|
|
|
|
for input_frames, expected in test_cases:
|
|
if input_frames % 8 != 1:
|
|
adjusted = round((input_frames - 1) / 8) * 8 + 1
|
|
assert (
|
|
adjusted == expected
|
|
), f"Expected {expected} for input {input_frames}, got {adjusted}"
|
|
|
|
|
|
class TestDenoiseWithCFGMocked:
|
|
"""Tests for denoise_with_cfg with mocked transformer."""
|
|
|
|
def test_sigmas_list_conversion(self):
|
|
"""Sigmas should be convertible to list."""
|
|
sigmas = ltx2_scheduler(steps=5)
|
|
sigmas_list = sigmas.tolist()
|
|
|
|
assert isinstance(sigmas_list, list)
|
|
assert len(sigmas_list) == 6 # steps + 1
|
|
|
|
|
|
class TestTilingDefault:
|
|
"""Tests for tiling default behavior."""
|
|
|
|
def test_tiling_default_is_none(self):
|
|
"""Default tiling should be 'none' for performance."""
|
|
import inspect
|
|
|
|
from mlx_video.generate_dev import generate_video_dev
|
|
|
|
sig = inspect.signature(generate_video_dev)
|
|
|
|
tiling_param = sig.parameters.get("tiling")
|
|
assert tiling_param is not None
|
|
assert (
|
|
tiling_param.default == "none"
|
|
), f"Expected default tiling='none', got '{tiling_param.default}'"
|
|
|
|
|
|
class TestLatentDimensions:
|
|
"""Tests for latent dimension calculations."""
|
|
|
|
def test_latent_height_calculation(self):
|
|
"""Latent height should be height // 32."""
|
|
test_cases = [(512, 16), (768, 24), (1024, 32)]
|
|
|
|
for height, expected_latent_h in test_cases:
|
|
latent_h = height // 32
|
|
assert (
|
|
latent_h == expected_latent_h
|
|
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
|
|
|
def test_latent_width_calculation(self):
|
|
"""Latent width should be width // 32."""
|
|
test_cases = [(512, 16), (768, 24), (1024, 32)]
|
|
|
|
for width, expected_latent_w in test_cases:
|
|
latent_w = width // 32
|
|
assert (
|
|
latent_w == expected_latent_w
|
|
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
|
|
|
def test_latent_frames_calculation(self):
|
|
"""Latent frames should be 1 + (num_frames - 1) // 8."""
|
|
test_cases = [(1, 1), (9, 2), (17, 3), (33, 5), (65, 9)]
|
|
|
|
for num_frames, expected_latent_f in test_cases:
|
|
latent_f = 1 + (num_frames - 1) // 8
|
|
assert (
|
|
latent_f == expected_latent_f
|
|
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
|
|
|
def test_num_tokens_calculation(self):
|
|
"""Number of tokens should be latent_f * latent_h * latent_w."""
|
|
# For 33 frames at 512x768
|
|
num_frames, height, width = 33, 512, 768
|
|
|
|
latent_f = 1 + (num_frames - 1) // 8 # 5
|
|
latent_h = height // 32 # 16
|
|
latent_w = width // 32 # 24
|
|
|
|
num_tokens = latent_f * latent_h * latent_w
|
|
expected = 5 * 16 * 24 # 1920
|
|
|
|
assert num_tokens == expected, f"Expected {expected} tokens, got {num_tokens}"
|
|
|
|
|
|
class TestAudioPositionGrid:
|
|
"""Tests for audio position grid creation."""
|
|
|
|
def test_audio_position_grid_shape(self):
|
|
"""Audio position grid should have correct shape (B, 1, T, 2)."""
|
|
batch_size = 1
|
|
audio_frames = 34 # ~1.36 seconds at 25 latent frames/sec
|
|
|
|
positions = create_audio_position_grid(batch_size, audio_frames)
|
|
expected_shape = (batch_size, 1, audio_frames, 2)
|
|
|
|
assert (
|
|
positions.shape == expected_shape
|
|
), f"Expected {expected_shape}, got {positions.shape}"
|
|
|
|
def test_audio_position_grid_dtype(self):
|
|
"""Audio position grid should be float32."""
|
|
positions = create_audio_position_grid(1, 34)
|
|
assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
|
|
|
|
def test_audio_position_grid_batch_size(self):
|
|
"""Audio position grid should respect batch size."""
|
|
for batch_size in [1, 2, 4]:
|
|
positions = create_audio_position_grid(batch_size, 34)
|
|
assert positions.shape[0] == batch_size
|
|
|
|
def test_audio_position_grid_temporal_values(self):
|
|
"""Audio positions should be in seconds."""
|
|
positions = create_audio_position_grid(1, 34)
|
|
|
|
# Values should be in seconds (small values for ~1 second of audio)
|
|
max_val = mx.max(positions).item()
|
|
assert max_val < 10, f"Audio positions seem too large: {max_val}"
|
|
assert max_val > 0, "Audio positions should be positive"
|
|
|
|
def test_audio_position_grid_no_nan_or_inf(self):
|
|
"""Audio position grid should not contain NaN or Inf."""
|
|
positions = create_audio_position_grid(1, 34)
|
|
|
|
assert not mx.any(
|
|
mx.isnan(positions)
|
|
).item(), "Audio position grid contains NaN"
|
|
assert not mx.any(
|
|
mx.isinf(positions)
|
|
).item(), "Audio position grid contains Inf"
|
|
|
|
|
|
class TestComputeAudioFrames:
|
|
"""Tests for audio frame count calculation."""
|
|
|
|
def test_audio_frames_basic(self):
|
|
"""Audio frames should be proportional to video duration."""
|
|
# 33 frames at 24 fps = ~1.375 seconds
|
|
# At 25 latent frames/sec = ~34 audio frames
|
|
audio_frames = compute_audio_frames(33, 24.0)
|
|
assert audio_frames > 0
|
|
assert isinstance(audio_frames, int)
|
|
|
|
def test_audio_frames_scales_with_video(self):
|
|
"""More video frames should produce more audio frames."""
|
|
audio_33 = compute_audio_frames(33, 24.0)
|
|
audio_65 = compute_audio_frames(65, 24.0)
|
|
|
|
assert (
|
|
audio_65 > audio_33
|
|
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
|
|
|
def test_audio_frames_formula(self):
|
|
"""Audio frames should match expected formula."""
|
|
num_video_frames = 33
|
|
fps = 24.0
|
|
|
|
duration = num_video_frames / fps # ~1.375 seconds
|
|
expected = round(duration * AUDIO_LATENTS_PER_SECOND)
|
|
|
|
actual = compute_audio_frames(num_video_frames, fps)
|
|
assert actual == expected, f"Expected {expected}, got {actual}"
|
|
|
|
|
|
class TestAudioConstants:
|
|
"""Tests for audio constants."""
|
|
|
|
def test_audio_sample_rate(self):
|
|
"""Audio sample rate should be 24000 Hz."""
|
|
assert AUDIO_SAMPLE_RATE == 24000
|
|
|
|
def test_audio_latents_per_second(self):
|
|
"""Audio latents per second should be 25."""
|
|
assert AUDIO_LATENTS_PER_SECOND == 25.0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v"])
|