This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,17 +1,17 @@
"""Tests for LTX-2 dev model generation pipeline."""
import pytest
import mlx.core as mx
import pytest
from mlx_video.generate_dev import (
ltx2_scheduler,
create_position_grid,
create_audio_position_grid,
compute_audio_frames,
cfg_delta,
DEFAULT_NEGATIVE_PROMPT,
AUDIO_SAMPLE_RATE,
AUDIO_LATENTS_PER_SECOND,
AUDIO_SAMPLE_RATE,
DEFAULT_NEGATIVE_PROMPT,
cfg_delta,
compute_audio_frames,
create_audio_position_grid,
create_position_grid,
ltx2_scheduler,
)
@@ -22,12 +22,16 @@ class TestLTX2Scheduler:
"""Scheduler should return steps+1 sigma values."""
steps = 20
sigmas = ltx2_scheduler(steps=steps)
assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}"
assert sigmas.shape == (
steps + 1,
), f"Expected ({steps + 1},), got {sigmas.shape}"
def test_scheduler_starts_at_one(self):
"""Sigma schedule should start at 1.0."""
sigmas = ltx2_scheduler(steps=20)
assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}"
assert (
abs(sigmas[0].item() - 1.0) < 1e-5
), f"Expected 1.0, got {sigmas[0].item()}"
def test_scheduler_ends_at_zero(self):
"""Sigma schedule should end at 0.0."""
@@ -39,8 +43,9 @@ class TestLTX2Scheduler:
sigmas = ltx2_scheduler(steps=20)
sigmas_list = sigmas.tolist()
for i in range(len(sigmas_list) - 1):
assert sigmas_list[i] >= sigmas_list[i + 1], \
f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
assert (
sigmas_list[i] >= sigmas_list[i + 1]
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
def test_scheduler_dtype(self):
"""Scheduler should return float32 array."""
@@ -84,14 +89,16 @@ class TestCreatePositionGrid:
num_patches = num_frames * height * width
expected_shape = (batch_size, 3, num_patches, 2)
assert positions.shape == expected_shape, \
f"Expected {expected_shape}, got {positions.shape}"
assert (
positions.shape == expected_shape
), f"Expected {expected_shape}, got {positions.shape}"
def test_position_grid_dtype(self):
"""Position grid should be float32 for RoPE precision."""
positions = create_position_grid(1, 5, 16, 24)
assert positions.dtype == mx.float32, \
f"Expected float32 for RoPE precision, got {positions.dtype}"
assert (
positions.dtype == mx.float32
), f"Expected float32 for RoPE precision, got {positions.dtype}"
def test_position_grid_batch_size(self):
"""Position grid should respect batch size."""
@@ -165,7 +172,9 @@ class TestCFGDelta:
mx.eval(delta)
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
assert mx.max(mx.abs(delta)).item() < 1e-6, "CFG delta with scale=1.0 should be zero"
assert (
mx.max(mx.abs(delta)).item() < 1e-6
), "CFG delta with scale=1.0 should be zero"
def test_cfg_delta_formula(self):
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
@@ -204,8 +213,9 @@ class TestDefaultNegativePrompt:
# Check for common negative quality terms
assert "blurry" in prompt_lower, "Should contain 'blurry'"
assert "low quality" in prompt_lower or "low contrast" in prompt_lower, \
"Should contain quality-related terms"
assert (
"low quality" in prompt_lower or "low contrast" in prompt_lower
), "Should contain quality-related terms"
class TestInputValidation:
@@ -248,15 +258,16 @@ class TestInputValidation:
(30, 33), # 30 -> nearest valid is 33
(35, 33), # 35 -> nearest valid is 33
(40, 41), # 40 -> nearest valid is 41
(1, 1), # 1 is already valid
(1, 1), # 1 is already valid
(33, 33), # 33 is already valid
]
for input_frames, expected in test_cases:
if input_frames % 8 != 1:
adjusted = round((input_frames - 1) / 8) * 8 + 1
assert adjusted == expected, \
f"Expected {expected} for input {input_frames}, got {adjusted}"
assert (
adjusted == expected
), f"Expected {expected} for input {input_frames}, got {adjusted}"
class TestDenoiseWithCFGMocked:
@@ -277,14 +288,16 @@ class TestTilingDefault:
def test_tiling_default_is_none(self):
"""Default tiling should be 'none' for performance."""
import inspect
from mlx_video.generate_dev import generate_video_dev
sig = inspect.signature(generate_video_dev)
tiling_param = sig.parameters.get('tiling')
tiling_param = sig.parameters.get("tiling")
assert tiling_param is not None
assert tiling_param.default == "none", \
f"Expected default tiling='none', got '{tiling_param.default}'"
assert (
tiling_param.default == "none"
), f"Expected default tiling='none', got '{tiling_param.default}'"
class TestLatentDimensions:
@@ -296,8 +309,9 @@ class TestLatentDimensions:
for height, expected_latent_h in test_cases:
latent_h = height // 32
assert latent_h == expected_latent_h, \
f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
assert (
latent_h == expected_latent_h
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
def test_latent_width_calculation(self):
"""Latent width should be width // 32."""
@@ -305,8 +319,9 @@ class TestLatentDimensions:
for width, expected_latent_w in test_cases:
latent_w = width // 32
assert latent_w == expected_latent_w, \
f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
assert (
latent_w == expected_latent_w
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
def test_latent_frames_calculation(self):
"""Latent frames should be 1 + (num_frames - 1) // 8."""
@@ -314,8 +329,9 @@ class TestLatentDimensions:
for num_frames, expected_latent_f in test_cases:
latent_f = 1 + (num_frames - 1) // 8
assert latent_f == expected_latent_f, \
f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
assert (
latent_f == expected_latent_f
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
def test_num_tokens_calculation(self):
"""Number of tokens should be latent_f * latent_h * latent_w."""
@@ -343,14 +359,14 @@ class TestAudioPositionGrid:
positions = create_audio_position_grid(batch_size, audio_frames)
expected_shape = (batch_size, 1, audio_frames, 2)
assert positions.shape == expected_shape, \
f"Expected {expected_shape}, got {positions.shape}"
assert (
positions.shape == expected_shape
), f"Expected {expected_shape}, got {positions.shape}"
def test_audio_position_grid_dtype(self):
"""Audio position grid should be float32."""
positions = create_audio_position_grid(1, 34)
assert positions.dtype == mx.float32, \
f"Expected float32, got {positions.dtype}"
assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
def test_audio_position_grid_batch_size(self):
"""Audio position grid should respect batch size."""
@@ -371,8 +387,12 @@ class TestAudioPositionGrid:
"""Audio position grid should not contain NaN or Inf."""
positions = create_audio_position_grid(1, 34)
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN"
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf"
assert not mx.any(
mx.isnan(positions)
).item(), "Audio position grid contains NaN"
assert not mx.any(
mx.isinf(positions)
).item(), "Audio position grid contains Inf"
class TestComputeAudioFrames:
@@ -391,8 +411,9 @@ class TestComputeAudioFrames:
audio_33 = compute_audio_frames(33, 24.0)
audio_65 = compute_audio_frames(65, 24.0)
assert audio_65 > audio_33, \
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
assert (
audio_65 > audio_33
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
def test_audio_frames_formula(self):
"""Audio frames should match expected formula."""