Add audio generation capabilities to the video pipeline, including audio position grid creation, audio frame computation, and integration of the audio VAE and vocoder. Update tests to cover the new audio functionality.

This commit is contained in:
Prince Canuma
2026-01-18 21:28:56 +01:00
parent b36ad1e22d
commit 7069cc39c9
2 changed files with 667 additions and 127 deletions

View File

@@ -2,14 +2,16 @@
import pytest
import mlx.core as mx
import numpy as np
from mlx_video.generate_dev import (
ltx2_scheduler,
create_position_grid,
create_audio_position_grid,
compute_audio_frames,
cfg_delta,
denoise_with_cfg,
DEFAULT_NEGATIVE_PROMPT,
AUDIO_SAMPLE_RATE,
AUDIO_LATENTS_PER_SECOND,
)
@@ -260,28 +262,6 @@ class TestInputValidation:
class TestDenoiseWithCFGMocked:
"""Tests for denoise_with_cfg with mocked transformer."""
def test_denoise_returns_correct_shape(self):
    """Placeholder: denoised output should have the same shape as the input latents.

    No assertion is implemented yet. The test is reported as skipped
    instead of silently passing, so the missing coverage stays visible
    in the test report.
    """

    # Sketch of the stand-in transformer the real test is meant to use.
    # It is not wired into denoise_with_cfg yet.
    class MockTransformer:
        inner_dim = 4096
        positional_embedding_theta = 10000.0
        positional_embedding_max_pos = [20, 2048, 2048]
        use_middle_indices_grid = True
        num_attention_heads = 32
        rope_type = None

        class config:
            double_precision_rope = True

        def __call__(self, video, audio):
            # Identity on the video latent; no audio output.
            return video.latent, None

    # A bare `pass` here would make the test count as a success even though
    # nothing is checked, so skip explicitly with a reason.
    pytest.skip("denoise_with_cfg shape check is not implemented yet")
def test_sigmas_list_conversion(self):
"""Sigmas should be convertible to list."""
sigmas = ltx2_scheduler(steps=5)
@@ -296,16 +276,10 @@ class TestTilingDefault:
def test_tiling_default_is_none(self):
"""Default tiling should be 'none' for performance."""
# Import and check the default
import argparse
from mlx_video.generate_dev import main
# The default is set in the argparse definition
# We verify this by checking the function signature
import inspect
sig = inspect.signature(
__import__('mlx_video.generate_dev', fromlist=['generate_video_dev']).generate_video_dev
)
from mlx_video.generate_dev import generate_video_dev
sig = inspect.signature(generate_video_dev)
tiling_param = sig.parameters.get('tiling')
assert tiling_param is not None
@@ -358,5 +332,91 @@ class TestLatentDimensions:
assert num_tokens == expected, f"Expected {expected} tokens, got {num_tokens}"
class TestAudioPositionGrid:
    """Checks on the grid returned by ``create_audio_position_grid``."""

    # 34 latent frames is ~1.36 seconds at 25 latent frames/sec.
    _AUDIO_FRAMES = 34

    def test_audio_position_grid_shape(self):
        """The grid is expected to be shaped (B, 1, T, 2)."""
        n_batch = 1
        n_frames = self._AUDIO_FRAMES
        expected_shape = (n_batch, 1, n_frames, 2)

        positions = create_audio_position_grid(n_batch, n_frames)

        assert positions.shape == expected_shape, (
            f"Expected {expected_shape}, got {positions.shape}"
        )

    def test_audio_position_grid_dtype(self):
        """The grid is expected to be float32."""
        positions = create_audio_position_grid(1, self._AUDIO_FRAMES)

        assert positions.dtype == mx.float32, (
            f"Expected float32, got {positions.dtype}"
        )

    def test_audio_position_grid_batch_size(self):
        """The leading dimension should follow the requested batch size."""
        for n_batch in (1, 2, 4):
            grid = create_audio_position_grid(n_batch, self._AUDIO_FRAMES)
            leading_dim = grid.shape[0]
            assert leading_dim == n_batch

    def test_audio_position_grid_temporal_values(self):
        """Positions should be in seconds, so small for ~1 second of audio."""
        positions = create_audio_position_grid(1, self._AUDIO_FRAMES)
        max_val = mx.max(positions).item()

        # Upper bound first, then strict positivity.
        assert max_val < 10, f"Audio positions seem too large: {max_val}"
        assert max_val > 0, "Audio positions should be positive"

    def test_audio_position_grid_no_nan_or_inf(self):
        """The grid must contain neither NaN nor Inf entries."""
        grid = create_audio_position_grid(1, self._AUDIO_FRAMES)

        has_nan = mx.any(mx.isnan(grid)).item()
        assert not has_nan, "Audio position grid contains NaN"

        has_inf = mx.any(mx.isinf(grid)).item()
        assert not has_inf, "Audio position grid contains Inf"
class TestComputeAudioFrames:
    """Checks on the frame count returned by ``compute_audio_frames``."""

    def test_audio_frames_basic(self):
        """A short clip should yield a positive, integer frame count."""
        # 33 video frames at 24 fps is ~1.375 seconds, which is roughly
        # 34 audio frames at 25 latent frames/sec.
        count = compute_audio_frames(33, 24.0)

        assert count > 0
        assert isinstance(count, int)

    def test_audio_frames_scales_with_video(self):
        """A longer video should map to more audio frames."""
        audio_33, audio_65 = (
            compute_audio_frames(n_video, 24.0) for n_video in (33, 65)
        )

        assert audio_65 > audio_33, (
            f"Expected more audio frames for longer video: {audio_65} <= {audio_33}"
        )

    def test_audio_frames_formula(self):
        """The count should equal round(duration * AUDIO_LATENTS_PER_SECOND)."""
        video_frames, rate = 33, 24.0

        actual = compute_audio_frames(video_frames, rate)

        seconds = video_frames / rate  # ~1.375 seconds
        expected = round(seconds * AUDIO_LATENTS_PER_SECOND)

        assert actual == expected, f"Expected {expected}, got {actual}"
class TestAudioConstants:
    """Pins the values of the audio constants imported from generate_dev."""

    def test_audio_sample_rate(self):
        """AUDIO_SAMPLE_RATE is expected to be 24000 Hz."""
        expected_hz = 24000
        assert AUDIO_SAMPLE_RATE == expected_hz

    def test_audio_latents_per_second(self):
        """AUDIO_LATENTS_PER_SECOND is expected to be 25."""
        expected_rate = 25.0
        assert AUDIO_LATENTS_PER_SECOND == expected_rate
if __name__ == "__main__":
    # Run this file's tests verbosely when executed as a script, and pass
    # pytest's exit code on to the shell so a failing run exits non-zero.
    raise SystemExit(pytest.main([__file__, "-v"]))