format
This commit is contained in:
@@ -1,17 +1,17 @@
|
||||
"""Tests for LTX-2 dev model generation pipeline."""
|
||||
|
||||
import pytest
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from mlx_video.generate_dev import (
|
||||
ltx2_scheduler,
|
||||
create_position_grid,
|
||||
create_audio_position_grid,
|
||||
compute_audio_frames,
|
||||
cfg_delta,
|
||||
DEFAULT_NEGATIVE_PROMPT,
|
||||
AUDIO_SAMPLE_RATE,
|
||||
AUDIO_LATENTS_PER_SECOND,
|
||||
AUDIO_SAMPLE_RATE,
|
||||
DEFAULT_NEGATIVE_PROMPT,
|
||||
cfg_delta,
|
||||
compute_audio_frames,
|
||||
create_audio_position_grid,
|
||||
create_position_grid,
|
||||
ltx2_scheduler,
|
||||
)
|
||||
|
||||
|
||||
@@ -22,12 +22,16 @@ class TestLTX2Scheduler:
|
||||
"""Scheduler should return steps+1 sigma values."""
|
||||
steps = 20
|
||||
sigmas = ltx2_scheduler(steps=steps)
|
||||
assert sigmas.shape == (steps + 1,), f"Expected ({steps + 1},), got {sigmas.shape}"
|
||||
assert sigmas.shape == (
|
||||
steps + 1,
|
||||
), f"Expected ({steps + 1},), got {sigmas.shape}"
|
||||
|
||||
def test_scheduler_starts_at_one(self):
|
||||
"""Sigma schedule should start at 1.0."""
|
||||
sigmas = ltx2_scheduler(steps=20)
|
||||
assert abs(sigmas[0].item() - 1.0) < 1e-5, f"Expected 1.0, got {sigmas[0].item()}"
|
||||
assert (
|
||||
abs(sigmas[0].item() - 1.0) < 1e-5
|
||||
), f"Expected 1.0, got {sigmas[0].item()}"
|
||||
|
||||
def test_scheduler_ends_at_zero(self):
|
||||
"""Sigma schedule should end at 0.0."""
|
||||
@@ -39,8 +43,9 @@ class TestLTX2Scheduler:
|
||||
sigmas = ltx2_scheduler(steps=20)
|
||||
sigmas_list = sigmas.tolist()
|
||||
for i in range(len(sigmas_list) - 1):
|
||||
assert sigmas_list[i] >= sigmas_list[i + 1], \
|
||||
f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
||||
assert (
|
||||
sigmas_list[i] >= sigmas_list[i + 1]
|
||||
), f"Sigma not decreasing at index {i}: {sigmas_list[i]} < {sigmas_list[i + 1]}"
|
||||
|
||||
def test_scheduler_dtype(self):
|
||||
"""Scheduler should return float32 array."""
|
||||
@@ -84,14 +89,16 @@ class TestCreatePositionGrid:
|
||||
num_patches = num_frames * height * width
|
||||
|
||||
expected_shape = (batch_size, 3, num_patches, 2)
|
||||
assert positions.shape == expected_shape, \
|
||||
f"Expected {expected_shape}, got {positions.shape}"
|
||||
assert (
|
||||
positions.shape == expected_shape
|
||||
), f"Expected {expected_shape}, got {positions.shape}"
|
||||
|
||||
def test_position_grid_dtype(self):
|
||||
"""Position grid should be float32 for RoPE precision."""
|
||||
positions = create_position_grid(1, 5, 16, 24)
|
||||
assert positions.dtype == mx.float32, \
|
||||
f"Expected float32 for RoPE precision, got {positions.dtype}"
|
||||
assert (
|
||||
positions.dtype == mx.float32
|
||||
), f"Expected float32 for RoPE precision, got {positions.dtype}"
|
||||
|
||||
def test_position_grid_batch_size(self):
|
||||
"""Position grid should respect batch size."""
|
||||
@@ -165,7 +172,9 @@ class TestCFGDelta:
|
||||
mx.eval(delta)
|
||||
|
||||
# Scale=1.0 means (1.0 - 1.0) * (cond - uncond) = 0
|
||||
assert mx.max(mx.abs(delta)).item() < 1e-6, "CFG delta with scale=1.0 should be zero"
|
||||
assert (
|
||||
mx.max(mx.abs(delta)).item() < 1e-6
|
||||
), "CFG delta with scale=1.0 should be zero"
|
||||
|
||||
def test_cfg_delta_formula(self):
|
||||
"""CFG delta should follow the formula: (scale-1) * (cond - uncond)."""
|
||||
@@ -204,8 +213,9 @@ class TestDefaultNegativePrompt:
|
||||
|
||||
# Check for common negative quality terms
|
||||
assert "blurry" in prompt_lower, "Should contain 'blurry'"
|
||||
assert "low quality" in prompt_lower or "low contrast" in prompt_lower, \
|
||||
"Should contain quality-related terms"
|
||||
assert (
|
||||
"low quality" in prompt_lower or "low contrast" in prompt_lower
|
||||
), "Should contain quality-related terms"
|
||||
|
||||
|
||||
class TestInputValidation:
|
||||
@@ -248,15 +258,16 @@ class TestInputValidation:
|
||||
(30, 33), # 30 -> nearest valid is 33
|
||||
(35, 33), # 35 -> nearest valid is 33
|
||||
(40, 41), # 40 -> nearest valid is 41
|
||||
(1, 1), # 1 is already valid
|
||||
(1, 1), # 1 is already valid
|
||||
(33, 33), # 33 is already valid
|
||||
]
|
||||
|
||||
for input_frames, expected in test_cases:
|
||||
if input_frames % 8 != 1:
|
||||
adjusted = round((input_frames - 1) / 8) * 8 + 1
|
||||
assert adjusted == expected, \
|
||||
f"Expected {expected} for input {input_frames}, got {adjusted}"
|
||||
assert (
|
||||
adjusted == expected
|
||||
), f"Expected {expected} for input {input_frames}, got {adjusted}"
|
||||
|
||||
|
||||
class TestDenoiseWithCFGMocked:
|
||||
@@ -277,14 +288,16 @@ class TestTilingDefault:
|
||||
def test_tiling_default_is_none(self):
|
||||
"""Default tiling should be 'none' for performance."""
|
||||
import inspect
|
||||
|
||||
from mlx_video.generate_dev import generate_video_dev
|
||||
|
||||
sig = inspect.signature(generate_video_dev)
|
||||
|
||||
tiling_param = sig.parameters.get('tiling')
|
||||
tiling_param = sig.parameters.get("tiling")
|
||||
assert tiling_param is not None
|
||||
assert tiling_param.default == "none", \
|
||||
f"Expected default tiling='none', got '{tiling_param.default}'"
|
||||
assert (
|
||||
tiling_param.default == "none"
|
||||
), f"Expected default tiling='none', got '{tiling_param.default}'"
|
||||
|
||||
|
||||
class TestLatentDimensions:
|
||||
@@ -296,8 +309,9 @@ class TestLatentDimensions:
|
||||
|
||||
for height, expected_latent_h in test_cases:
|
||||
latent_h = height // 32
|
||||
assert latent_h == expected_latent_h, \
|
||||
f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
||||
assert (
|
||||
latent_h == expected_latent_h
|
||||
), f"Expected latent_h={expected_latent_h} for height={height}, got {latent_h}"
|
||||
|
||||
def test_latent_width_calculation(self):
|
||||
"""Latent width should be width // 32."""
|
||||
@@ -305,8 +319,9 @@ class TestLatentDimensions:
|
||||
|
||||
for width, expected_latent_w in test_cases:
|
||||
latent_w = width // 32
|
||||
assert latent_w == expected_latent_w, \
|
||||
f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
||||
assert (
|
||||
latent_w == expected_latent_w
|
||||
), f"Expected latent_w={expected_latent_w} for width={width}, got {latent_w}"
|
||||
|
||||
def test_latent_frames_calculation(self):
|
||||
"""Latent frames should be 1 + (num_frames - 1) // 8."""
|
||||
@@ -314,8 +329,9 @@ class TestLatentDimensions:
|
||||
|
||||
for num_frames, expected_latent_f in test_cases:
|
||||
latent_f = 1 + (num_frames - 1) // 8
|
||||
assert latent_f == expected_latent_f, \
|
||||
f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
||||
assert (
|
||||
latent_f == expected_latent_f
|
||||
), f"Expected latent_f={expected_latent_f} for num_frames={num_frames}, got {latent_f}"
|
||||
|
||||
def test_num_tokens_calculation(self):
|
||||
"""Number of tokens should be latent_f * latent_h * latent_w."""
|
||||
@@ -343,14 +359,14 @@ class TestAudioPositionGrid:
|
||||
positions = create_audio_position_grid(batch_size, audio_frames)
|
||||
expected_shape = (batch_size, 1, audio_frames, 2)
|
||||
|
||||
assert positions.shape == expected_shape, \
|
||||
f"Expected {expected_shape}, got {positions.shape}"
|
||||
assert (
|
||||
positions.shape == expected_shape
|
||||
), f"Expected {expected_shape}, got {positions.shape}"
|
||||
|
||||
def test_audio_position_grid_dtype(self):
|
||||
"""Audio position grid should be float32."""
|
||||
positions = create_audio_position_grid(1, 34)
|
||||
assert positions.dtype == mx.float32, \
|
||||
f"Expected float32, got {positions.dtype}"
|
||||
assert positions.dtype == mx.float32, f"Expected float32, got {positions.dtype}"
|
||||
|
||||
def test_audio_position_grid_batch_size(self):
|
||||
"""Audio position grid should respect batch size."""
|
||||
@@ -371,8 +387,12 @@ class TestAudioPositionGrid:
|
||||
"""Audio position grid should not contain NaN or Inf."""
|
||||
positions = create_audio_position_grid(1, 34)
|
||||
|
||||
assert not mx.any(mx.isnan(positions)).item(), "Audio position grid contains NaN"
|
||||
assert not mx.any(mx.isinf(positions)).item(), "Audio position grid contains Inf"
|
||||
assert not mx.any(
|
||||
mx.isnan(positions)
|
||||
).item(), "Audio position grid contains NaN"
|
||||
assert not mx.any(
|
||||
mx.isinf(positions)
|
||||
).item(), "Audio position grid contains Inf"
|
||||
|
||||
|
||||
class TestComputeAudioFrames:
|
||||
@@ -391,8 +411,9 @@ class TestComputeAudioFrames:
|
||||
audio_33 = compute_audio_frames(33, 24.0)
|
||||
audio_65 = compute_audio_frames(65, 24.0)
|
||||
|
||||
assert audio_65 > audio_33, \
|
||||
f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
||||
assert (
|
||||
audio_65 > audio_33
|
||||
), f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
|
||||
|
||||
def test_audio_frames_formula(self):
|
||||
"""Audio frames should match expected formula."""
|
||||
|
||||
Reference in New Issue
Block a user