diff --git a/mlx_video/models/ltx_2/__init__.py b/mlx_video/models/ltx_2/__init__.py index f382326..dd2b1e0 100644 --- a/mlx_video/models/ltx_2/__init__.py +++ b/mlx_video/models/ltx_2/__init__.py @@ -4,4 +4,4 @@ from mlx_video.models.ltx_2.config import ( LTXModelType, TransformerConfig, ) -from mlx_video.models.ltx_2.ltx import LTXModel, X0Model +from mlx_video.models.ltx_2.ltx_2 import LTXModel, X0Model diff --git a/mlx_video/models/ltx_2/generate.py b/mlx_video/models/ltx_2/generate.py index 6c3fc72..c6c592d 100644 --- a/mlx_video/models/ltx_2/generate.py +++ b/mlx_video/models/ltx_2/generate.py @@ -33,7 +33,7 @@ from mlx_video.models.ltx_2.conditioning import ( apply_conditioning, ) from mlx_video.models.ltx_2.conditioning.latent import LatentState, apply_denoise_mask -from mlx_video.models.ltx_2.ltx import LTXModel +from mlx_video.models.ltx_2.ltx_2 import LTXModel from mlx_video.models.ltx_2.transformer import Modality from mlx_video.models.ltx_2.upsampler import load_upsampler, upsample_latents from mlx_video.models.ltx_2.video_vae import VideoEncoder diff --git a/mlx_video/models/ltx_2/ltx.py b/mlx_video/models/ltx_2/ltx_2.py similarity index 100% rename from mlx_video/models/ltx_2/ltx.py rename to mlx_video/models/ltx_2/ltx_2.py diff --git a/mlx_video/models/wan2/__init__.py b/mlx_video/models/wan2/__init__.py index b9c08ac..90390e9 100644 --- a/mlx_video/models/wan2/__init__.py +++ b/mlx_video/models/wan2/__init__.py @@ -1,2 +1,2 @@ from mlx_video.models.wan2.config import WanModelConfig -from mlx_video.models.wan2.model import WanModel +from mlx_video.models.wan2.wan2 import WanModel diff --git a/mlx_video/models/wan2/convert.py b/mlx_video/models/wan2/convert.py index 8ae510f..ba2b79a 100644 --- a/mlx_video/models/wan2/convert.py +++ b/mlx_video/models/wan2/convert.py @@ -594,7 +594,7 @@ def _quantize_saved_model( import mlx.nn as nn - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel if source_dir is None: source_dir = output_dir diff --git a/mlx_video/models/wan2/utils.py b/mlx_video/models/wan2/utils.py index 6c9be4f..45964fe 100644 --- a/mlx_video/models/wan2/utils.py +++ b/mlx_video/models/wan2/utils.py @@ -21,7 +21,7 @@ def load_wan_model( If provided, creates QuantizedLinear stubs before loading. loras: Optional list of (lora_path, strength) tuples to apply. """ - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel model = WanModel(config) diff --git a/mlx_video/models/wan2/model.py b/mlx_video/models/wan2/wan2.py similarity index 100% rename from mlx_video/models/wan2/model.py rename to mlx_video/models/wan2/wan2.py diff --git a/tests/test_wan_generate.py b/tests/test_wan_generate.py index f4d1682..e586cce 100644 --- a/tests/test_wan_generate.py +++ b/tests/test_wan_generate.py @@ -14,7 +14,7 @@ class TestEndToEnd: def test_tiny_model_denoise_step(self): """Simulate one denoising step with tiny model.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(42) @@ -43,7 +43,7 @@ class TestEndToEnd: def test_tiny_model_full_loop(self): """Run a complete (tiny) diffusion loop.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(123) @@ -201,7 +201,7 @@ class TestDimensionAlignment: def test_patchify_valid_after_alignment(self): """After alignment, patchify should succeed without reshape errors.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) diff --git a/tests/test_wan_i2v.py b/tests/test_wan_i2v.py index b2a4bab..7c5e0cd 100644 --- a/tests/test_wan_i2v.py +++ b/tests/test_wan_i2v.py @@ -66,7 +66,7 @@ class TestModelYParameter: def test_forward_without_y(self): """Standard T2V forward pass (no y) still works.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -85,7 +85,7 @@ class TestModelYParameter: def test_forward_with_y(self): """I2V forward pass with y channel concatenation.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_i2v_config() model = WanModel(config) @@ -108,7 +108,7 @@ class TestModelYParameter: def test_y_none_is_noop(self): """Passing y=None should be identical to not passing y.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -129,7 +129,7 @@ class TestModelYParameter: def test_batched_cfg_with_y(self): """Batched CFG (B=2) with y should work.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_i2v_config() model = WanModel(config) @@ -307,7 +307,7 @@ class TestI2VEndToEndPipeline: def test_full_i2v_pipeline(self): """End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler from mlx_video.models.wan2.vae import WanVAE @@ -410,7 +410,7 @@ class TestDualModelSwitching: def test_model_selection_by_timestep(self): """Verify high_noise model used for timesteps >= boundary, low_noise otherwise.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(1) @@ -485,7 +485,7 @@ class TestDualModelSwitching: def test_guide_scale_tuple_applied_per_model(self): """Verify (low_gs, high_gs) tuple applies different scales per model.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(2) @@ -545,7 +545,7 @@ class TestDualModelSwitching: def test_single_model_fallback_with_tuple_guide_scale(self): """When dual_model=False, guide_scale tuple should use first element.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler mx.random.seed(3) diff --git a/tests/test_wan_model.py b/tests/test_wan_model.py index 650e0e5..e415052 100644 --- a/tests/test_wan_model.py +++ b/tests/test_wan_model.py @@ -12,7 +12,7 @@ from wan_test_helpers import _make_tiny_config class TestSinusoidalEmbedding: def test_output_shape(self): - from mlx_video.models.wan2.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d pos = mx.arange(10).astype(mx.float32) emb = sinusoidal_embedding_1d(256, pos) @@ -21,7 +21,7 @@ class TestSinusoidalEmbedding: def test_position_zero(self): """Position 0 should have cos=1 for all dims and sin=0.""" - from mlx_video.models.wan2.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d pos = mx.array([0.0]) emb = sinusoidal_embedding_1d(64, pos) @@ -33,7 +33,7 @@ class TestSinusoidalEmbedding: np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5) def test_different_positions_differ(self): - from mlx_video.models.wan2.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d pos = mx.array([0.0, 100.0, 999.0]) emb = sinusoidal_embedding_1d(128, pos) @@ -50,7 +50,7 @@ class TestSinusoidalEmbedding: class TestHead: def test_output_shape(self): - from mlx_video.models.wan2.model import Head + from mlx_video.models.wan2.wan2 import Head head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) B, L = 1, 24 @@ -62,7 +62,7 @@ class TestHead: assert out.shape == (B, L, expected_proj_dim) def test_modulation_shape(self): - from mlx_video.models.wan2.model import Head + from mlx_video.models.wan2.wan2 import Head head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2)) assert head.modulation.shape == (1, 2, 64) @@ -78,7 +78,7 @@ class TestWanModel: mx.random.seed(42) def test_instantiation(self): - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -86,7 +86,7 @@ class TestWanModel: assert num_params > 0 def test_patchify_shape(self): - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -99,7 +99,7 @@ class TestWanModel: assert patches.shape == (1, 1 * 2 * 2, config.dim) def test_patchify_various_sizes(self): - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -115,7 +115,7 @@ class TestWanModel: def test_unpatchify_inverse(self): """Patchify then unpatchify should reconstruct original spatial dims.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -131,7 +131,7 @@ class TestWanModel: assert out[0].shape == (config.out_dim, F, H, W) def test_forward_pass(self): - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -149,7 +149,7 @@ class TestWanModel: assert out[0].shape == (C, F, H, W) def test_forward_batch(self): - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -171,7 +171,7 @@ class TestWanModel: assert o.shape == (C, F, H, W) def test_output_is_float32(self): - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -234,7 +234,7 @@ class TestWan21Model: def test_wan21_tiny_model_forward(self): """Forward pass with Wan2.1 tiny config.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = self._make_tiny_wan21_config() model = WanModel(config) @@ -252,7 +252,7 @@ class TestWan21Model: def test_wan21_1_3b_tiny_model_forward(self): """Forward pass with Wan2.1 1.3B tiny config.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = self._make_tiny_wan21_1_3b_config() model = WanModel(config) @@ -270,7 +270,7 @@ class TestWan21Model: def test_wan21_single_model_loop(self): """Full diffusion loop with single model (Wan2.1 style).""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler config = self._make_tiny_wan21_config() @@ -333,21 +333,21 @@ class TestPerTokenTimestep: """Tests for per-token sinusoidal embedding.""" def test_1d_unchanged(self): - from mlx_video.models.wan2.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d pos = mx.array([0.0, 100.0, 500.0]) emb = sinusoidal_embedding_1d(256, pos) assert emb.shape == (3, 256) def test_2d_per_token(self): - from mlx_video.models.wan2.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]]) emb = sinusoidal_embedding_1d(256, pos) assert emb.shape == (2, 3, 256) def test_consistency(self): - from mlx_video.models.wan2.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d pos_1d = mx.array([0.0, 100.0]) emb_1d = sinusoidal_embedding_1d(256, pos_1d) diff --git a/tests/test_wan_quantization.py b/tests/test_wan_quantization.py index 1eb9622..14fe3ca 100644 --- a/tests/test_wan_quantization.py +++ b/tests/test_wan_quantization.py @@ -91,7 +91,7 @@ class TestQuantizeRoundTrip: def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64): """Helper: create model, quantize, save to tmp_path.""" from mlx_video.models.wan2.convert import _quantize_predicate - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel model = WanModel(config) nn.quantize( @@ -164,7 +164,7 @@ class TestQuantizeRoundTrip: def test_loading_without_quantization_flag(self, tmp_path): """Loading a non-quantized model should have standard Linear layers.""" - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -188,7 +188,7 @@ class TestQuantizeRoundTrip: class TestQuantizedInference: def _make_quantized_model(self, config, bits=4): from mlx_video.models.wan2.convert import _quantize_predicate - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel model = WanModel(config) nn.quantize( @@ -239,7 +239,7 @@ class TestQuantizedInference: def test_quantized_output_differs_from_unquantized(self): """Sanity check: quantization should change the weights.""" from mlx_video.models.wan2.convert import _quantize_predicate - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() mx.random.seed(42) @@ -272,7 +272,7 @@ class TestQuantizationConfig: def test_config_metadata_written(self, tmp_path): """Verify _quantize_saved_model writes quantization metadata to config.json.""" from mlx_video.models.wan2.convert import _quantize_saved_model - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -296,7 +296,7 @@ class TestQuantizationConfig: def test_config_metadata_8bit(self, tmp_path): from mlx_video.models.wan2.convert import _quantize_saved_model - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() model = WanModel(config) @@ -317,7 +317,7 @@ class TestQuantizationConfig: def test_dual_model_quantization(self, tmp_path): """Verify dual-model quantization writes both model files.""" from mlx_video.models.wan2.convert import _quantize_saved_model - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = _make_tiny_config() diff --git a/tests/test_wan_rope_freqs.py b/tests/test_wan_rope_freqs.py index 5da2a5f..93324a5 100644 --- a/tests/test_wan_rope_freqs.py +++ b/tests/test_wan_rope_freqs.py @@ -28,7 +28,7 @@ class TestRoPEFrequencyConstruction: def _get_model_freqs(self, dim=64, num_heads=4): """Instantiate a tiny WanModel and return its .freqs tensor.""" from mlx_video.models.wan2.config import WanModelConfig - from mlx_video.models.wan2.model import WanModel + from mlx_video.models.wan2.wan2 import WanModel config = WanModelConfig() config.dim = dim diff --git a/tests/test_wan_transformer.py b/tests/test_wan_transformer.py index 7d197c2..66df8c5 100644 --- a/tests/test_wan_transformer.py +++ b/tests/test_wan_transformer.py @@ -153,7 +153,7 @@ class TestFloat32Modulation: def test_head_modulation_float32(self): """Head modulation should be float32 even with bf16 e input.""" - from mlx_video.models.wan2.model import Head + from mlx_video.models.wan2.wan2 import Head head = Head(self.dim, 4, (1, 2, 2)) x = mx.random.normal((1, 8, self.dim)) @@ -164,7 +164,7 @@ class TestFloat32Modulation: def test_model_time_embedding_float32(self): """sinusoidal_embedding_1d output must be float32.""" - from mlx_video.models.wan2.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d t = mx.array([500.0]) emb = sinusoidal_embedding_1d(256, t) @@ -173,7 +173,7 @@ class TestFloat32Modulation: def test_model_per_token_time_embedding_float32(self): """Per-token time embeddings (I2V) should also be float32.""" - from mlx_video.models.wan2.model import sinusoidal_embedding_1d + from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4] emb = sinusoidal_embedding_1d(256, t)