Refactor Wan model structure by renaming and relocating model imports from model.py to wan2.py, enhancing code organization and clarity across the Wan2 module.
This commit is contained in:
@@ -4,4 +4,4 @@ from mlx_video.models.ltx_2.config import (
|
|||||||
LTXModelType,
|
LTXModelType,
|
||||||
TransformerConfig,
|
TransformerConfig,
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx_2.ltx import LTXModel, X0Model
|
from mlx_video.models.ltx_2.ltx_2 import LTXModel, X0Model
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from mlx_video.models.ltx_2.conditioning import (
|
|||||||
apply_conditioning,
|
apply_conditioning,
|
||||||
)
|
)
|
||||||
from mlx_video.models.ltx_2.conditioning.latent import LatentState, apply_denoise_mask
|
from mlx_video.models.ltx_2.conditioning.latent import LatentState, apply_denoise_mask
|
||||||
from mlx_video.models.ltx_2.ltx import LTXModel
|
from mlx_video.models.ltx_2.ltx_2 import LTXModel
|
||||||
from mlx_video.models.ltx_2.transformer import Modality
|
from mlx_video.models.ltx_2.transformer import Modality
|
||||||
from mlx_video.models.ltx_2.upsampler import load_upsampler, upsample_latents
|
from mlx_video.models.ltx_2.upsampler import load_upsampler, upsample_latents
|
||||||
from mlx_video.models.ltx_2.video_vae import VideoEncoder
|
from mlx_video.models.ltx_2.video_vae import VideoEncoder
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
from mlx_video.models.wan2.config import WanModelConfig
|
from mlx_video.models.wan2.config import WanModelConfig
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|||||||
@@ -594,7 +594,7 @@ def _quantize_saved_model(
|
|||||||
|
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
if source_dir is None:
|
if source_dir is None:
|
||||||
source_dir = output_dir
|
source_dir = output_dir
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def load_wan_model(
|
|||||||
If provided, creates QuantizedLinear stubs before loading.
|
If provided, creates QuantizedLinear stubs before loading.
|
||||||
loras: Optional list of (lora_path, strength) tuples to apply.
|
loras: Optional list of (lora_path, strength) tuples to apply.
|
||||||
"""
|
"""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class TestEndToEnd:
|
|||||||
|
|
||||||
def test_tiny_model_denoise_step(self):
|
def test_tiny_model_denoise_step(self):
|
||||||
"""Simulate one denoising step with tiny model."""
|
"""Simulate one denoising step with tiny model."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
mx.random.seed(42)
|
mx.random.seed(42)
|
||||||
@@ -43,7 +43,7 @@ class TestEndToEnd:
|
|||||||
|
|
||||||
def test_tiny_model_full_loop(self):
|
def test_tiny_model_full_loop(self):
|
||||||
"""Run a complete (tiny) diffusion loop."""
|
"""Run a complete (tiny) diffusion loop."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
mx.random.seed(123)
|
mx.random.seed(123)
|
||||||
@@ -201,7 +201,7 @@ class TestDimensionAlignment:
|
|||||||
|
|
||||||
def test_patchify_valid_after_alignment(self):
|
def test_patchify_valid_after_alignment(self):
|
||||||
"""After alignment, patchify should succeed without reshape errors."""
|
"""After alignment, patchify should succeed without reshape errors."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class TestModelYParameter:
|
|||||||
|
|
||||||
def test_forward_without_y(self):
|
def test_forward_without_y(self):
|
||||||
"""Standard T2V forward pass (no y) still works."""
|
"""Standard T2V forward pass (no y) still works."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -85,7 +85,7 @@ class TestModelYParameter:
|
|||||||
|
|
||||||
def test_forward_with_y(self):
|
def test_forward_with_y(self):
|
||||||
"""I2V forward pass with y channel concatenation."""
|
"""I2V forward pass with y channel concatenation."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_i2v_config()
|
config = _make_tiny_i2v_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -108,7 +108,7 @@ class TestModelYParameter:
|
|||||||
|
|
||||||
def test_y_none_is_noop(self):
|
def test_y_none_is_noop(self):
|
||||||
"""Passing y=None should be identical to not passing y."""
|
"""Passing y=None should be identical to not passing y."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -129,7 +129,7 @@ class TestModelYParameter:
|
|||||||
|
|
||||||
def test_batched_cfg_with_y(self):
|
def test_batched_cfg_with_y(self):
|
||||||
"""Batched CFG (B=2) with y should work."""
|
"""Batched CFG (B=2) with y should work."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_i2v_config()
|
config = _make_tiny_i2v_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -307,7 +307,7 @@ class TestI2VEndToEndPipeline:
|
|||||||
|
|
||||||
def test_full_i2v_pipeline(self):
|
def test_full_i2v_pipeline(self):
|
||||||
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
"""End-to-end I2V: synthetic image → VAE encode → build y → denoise → VAE decode."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||||
from mlx_video.models.wan2.vae import WanVAE
|
from mlx_video.models.wan2.vae import WanVAE
|
||||||
|
|
||||||
@@ -410,7 +410,7 @@ class TestDualModelSwitching:
|
|||||||
|
|
||||||
def test_model_selection_by_timestep(self):
|
def test_model_selection_by_timestep(self):
|
||||||
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
"""Verify high_noise model used for timesteps >= boundary, low_noise otherwise."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
mx.random.seed(1)
|
mx.random.seed(1)
|
||||||
@@ -485,7 +485,7 @@ class TestDualModelSwitching:
|
|||||||
|
|
||||||
def test_guide_scale_tuple_applied_per_model(self):
|
def test_guide_scale_tuple_applied_per_model(self):
|
||||||
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
"""Verify (low_gs, high_gs) tuple applies different scales per model."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
mx.random.seed(2)
|
mx.random.seed(2)
|
||||||
@@ -545,7 +545,7 @@ class TestDualModelSwitching:
|
|||||||
|
|
||||||
def test_single_model_fallback_with_tuple_guide_scale(self):
|
def test_single_model_fallback_with_tuple_guide_scale(self):
|
||||||
"""When dual_model=False, guide_scale tuple should use first element."""
|
"""When dual_model=False, guide_scale tuple should use first element."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
mx.random.seed(3)
|
mx.random.seed(3)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from wan_test_helpers import _make_tiny_config
|
|||||||
|
|
||||||
class TestSinusoidalEmbedding:
|
class TestSinusoidalEmbedding:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos = mx.arange(10).astype(mx.float32)
|
pos = mx.arange(10).astype(mx.float32)
|
||||||
emb = sinusoidal_embedding_1d(256, pos)
|
emb = sinusoidal_embedding_1d(256, pos)
|
||||||
@@ -21,7 +21,7 @@ class TestSinusoidalEmbedding:
|
|||||||
|
|
||||||
def test_position_zero(self):
|
def test_position_zero(self):
|
||||||
"""Position 0 should have cos=1 for all dims and sin=0."""
|
"""Position 0 should have cos=1 for all dims and sin=0."""
|
||||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos = mx.array([0.0])
|
pos = mx.array([0.0])
|
||||||
emb = sinusoidal_embedding_1d(64, pos)
|
emb = sinusoidal_embedding_1d(64, pos)
|
||||||
@@ -33,7 +33,7 @@ class TestSinusoidalEmbedding:
|
|||||||
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
|
np.testing.assert_allclose(emb_np[32:], 0.0, atol=1e-5)
|
||||||
|
|
||||||
def test_different_positions_differ(self):
|
def test_different_positions_differ(self):
|
||||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos = mx.array([0.0, 100.0, 999.0])
|
pos = mx.array([0.0, 100.0, 999.0])
|
||||||
emb = sinusoidal_embedding_1d(128, pos)
|
emb = sinusoidal_embedding_1d(128, pos)
|
||||||
@@ -50,7 +50,7 @@ class TestSinusoidalEmbedding:
|
|||||||
|
|
||||||
class TestHead:
|
class TestHead:
|
||||||
def test_output_shape(self):
|
def test_output_shape(self):
|
||||||
from mlx_video.models.wan2.model import Head
|
from mlx_video.models.wan2.wan2 import Head
|
||||||
|
|
||||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||||
B, L = 1, 24
|
B, L = 1, 24
|
||||||
@@ -62,7 +62,7 @@ class TestHead:
|
|||||||
assert out.shape == (B, L, expected_proj_dim)
|
assert out.shape == (B, L, expected_proj_dim)
|
||||||
|
|
||||||
def test_modulation_shape(self):
|
def test_modulation_shape(self):
|
||||||
from mlx_video.models.wan2.model import Head
|
from mlx_video.models.wan2.wan2 import Head
|
||||||
|
|
||||||
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
head = Head(dim=64, out_dim=16, patch_size=(1, 2, 2))
|
||||||
assert head.modulation.shape == (1, 2, 64)
|
assert head.modulation.shape == (1, 2, 64)
|
||||||
@@ -78,7 +78,7 @@ class TestWanModel:
|
|||||||
mx.random.seed(42)
|
mx.random.seed(42)
|
||||||
|
|
||||||
def test_instantiation(self):
|
def test_instantiation(self):
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -86,7 +86,7 @@ class TestWanModel:
|
|||||||
assert num_params > 0
|
assert num_params > 0
|
||||||
|
|
||||||
def test_patchify_shape(self):
|
def test_patchify_shape(self):
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -99,7 +99,7 @@ class TestWanModel:
|
|||||||
assert patches.shape == (1, 1 * 2 * 2, config.dim)
|
assert patches.shape == (1, 1 * 2 * 2, config.dim)
|
||||||
|
|
||||||
def test_patchify_various_sizes(self):
|
def test_patchify_various_sizes(self):
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -115,7 +115,7 @@ class TestWanModel:
|
|||||||
|
|
||||||
def test_unpatchify_inverse(self):
|
def test_unpatchify_inverse(self):
|
||||||
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
"""Patchify then unpatchify should reconstruct original spatial dims."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -131,7 +131,7 @@ class TestWanModel:
|
|||||||
assert out[0].shape == (config.out_dim, F, H, W)
|
assert out[0].shape == (config.out_dim, F, H, W)
|
||||||
|
|
||||||
def test_forward_pass(self):
|
def test_forward_pass(self):
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -149,7 +149,7 @@ class TestWanModel:
|
|||||||
assert out[0].shape == (C, F, H, W)
|
assert out[0].shape == (C, F, H, W)
|
||||||
|
|
||||||
def test_forward_batch(self):
|
def test_forward_batch(self):
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -171,7 +171,7 @@ class TestWanModel:
|
|||||||
assert o.shape == (C, F, H, W)
|
assert o.shape == (C, F, H, W)
|
||||||
|
|
||||||
def test_output_is_float32(self):
|
def test_output_is_float32(self):
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -234,7 +234,7 @@ class TestWan21Model:
|
|||||||
|
|
||||||
def test_wan21_tiny_model_forward(self):
|
def test_wan21_tiny_model_forward(self):
|
||||||
"""Forward pass with Wan2.1 tiny config."""
|
"""Forward pass with Wan2.1 tiny config."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = self._make_tiny_wan21_config()
|
config = self._make_tiny_wan21_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -252,7 +252,7 @@ class TestWan21Model:
|
|||||||
|
|
||||||
def test_wan21_1_3b_tiny_model_forward(self):
|
def test_wan21_1_3b_tiny_model_forward(self):
|
||||||
"""Forward pass with Wan2.1 1.3B tiny config."""
|
"""Forward pass with Wan2.1 1.3B tiny config."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = self._make_tiny_wan21_1_3b_config()
|
config = self._make_tiny_wan21_1_3b_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -270,7 +270,7 @@ class TestWan21Model:
|
|||||||
|
|
||||||
def test_wan21_single_model_loop(self):
|
def test_wan21_single_model_loop(self):
|
||||||
"""Full diffusion loop with single model (Wan2.1 style)."""
|
"""Full diffusion loop with single model (Wan2.1 style)."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
from mlx_video.models.wan2.scheduler import FlowMatchEulerScheduler
|
||||||
|
|
||||||
config = self._make_tiny_wan21_config()
|
config = self._make_tiny_wan21_config()
|
||||||
@@ -333,21 +333,21 @@ class TestPerTokenTimestep:
|
|||||||
"""Tests for per-token sinusoidal embedding."""
|
"""Tests for per-token sinusoidal embedding."""
|
||||||
|
|
||||||
def test_1d_unchanged(self):
|
def test_1d_unchanged(self):
|
||||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos = mx.array([0.0, 100.0, 500.0])
|
pos = mx.array([0.0, 100.0, 500.0])
|
||||||
emb = sinusoidal_embedding_1d(256, pos)
|
emb = sinusoidal_embedding_1d(256, pos)
|
||||||
assert emb.shape == (3, 256)
|
assert emb.shape == (3, 256)
|
||||||
|
|
||||||
def test_2d_per_token(self):
|
def test_2d_per_token(self):
|
||||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
|
pos = mx.array([[0.0, 100.0, 100.0], [50.0, 50.0, 50.0]])
|
||||||
emb = sinusoidal_embedding_1d(256, pos)
|
emb = sinusoidal_embedding_1d(256, pos)
|
||||||
assert emb.shape == (2, 3, 256)
|
assert emb.shape == (2, 3, 256)
|
||||||
|
|
||||||
def test_consistency(self):
|
def test_consistency(self):
|
||||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||||
|
|
||||||
pos_1d = mx.array([0.0, 100.0])
|
pos_1d = mx.array([0.0, 100.0])
|
||||||
emb_1d = sinusoidal_embedding_1d(256, pos_1d)
|
emb_1d = sinusoidal_embedding_1d(256, pos_1d)
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ class TestQuantizeRoundTrip:
|
|||||||
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
def _quantize_and_save(self, config, tmp_path, bits=4, group_size=64):
|
||||||
"""Helper: create model, quantize, save to tmp_path."""
|
"""Helper: create model, quantize, save to tmp_path."""
|
||||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
nn.quantize(
|
nn.quantize(
|
||||||
@@ -164,7 +164,7 @@ class TestQuantizeRoundTrip:
|
|||||||
|
|
||||||
def test_loading_without_quantization_flag(self, tmp_path):
|
def test_loading_without_quantization_flag(self, tmp_path):
|
||||||
"""Loading a non-quantized model should have standard Linear layers."""
|
"""Loading a non-quantized model should have standard Linear layers."""
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -188,7 +188,7 @@ class TestQuantizeRoundTrip:
|
|||||||
class TestQuantizedInference:
|
class TestQuantizedInference:
|
||||||
def _make_quantized_model(self, config, bits=4):
|
def _make_quantized_model(self, config, bits=4):
|
||||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
nn.quantize(
|
nn.quantize(
|
||||||
@@ -239,7 +239,7 @@ class TestQuantizedInference:
|
|||||||
def test_quantized_output_differs_from_unquantized(self):
|
def test_quantized_output_differs_from_unquantized(self):
|
||||||
"""Sanity check: quantization should change the weights."""
|
"""Sanity check: quantization should change the weights."""
|
||||||
from mlx_video.models.wan2.convert import _quantize_predicate
|
from mlx_video.models.wan2.convert import _quantize_predicate
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
mx.random.seed(42)
|
mx.random.seed(42)
|
||||||
@@ -272,7 +272,7 @@ class TestQuantizationConfig:
|
|||||||
def test_config_metadata_written(self, tmp_path):
|
def test_config_metadata_written(self, tmp_path):
|
||||||
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
"""Verify _quantize_saved_model writes quantization metadata to config.json."""
|
||||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -296,7 +296,7 @@ class TestQuantizationConfig:
|
|||||||
|
|
||||||
def test_config_metadata_8bit(self, tmp_path):
|
def test_config_metadata_8bit(self, tmp_path):
|
||||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
model = WanModel(config)
|
model = WanModel(config)
|
||||||
@@ -317,7 +317,7 @@ class TestQuantizationConfig:
|
|||||||
def test_dual_model_quantization(self, tmp_path):
|
def test_dual_model_quantization(self, tmp_path):
|
||||||
"""Verify dual-model quantization writes both model files."""
|
"""Verify dual-model quantization writes both model files."""
|
||||||
from mlx_video.models.wan2.convert import _quantize_saved_model
|
from mlx_video.models.wan2.convert import _quantize_saved_model
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = _make_tiny_config()
|
config = _make_tiny_config()
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class TestRoPEFrequencyConstruction:
|
|||||||
def _get_model_freqs(self, dim=64, num_heads=4):
|
def _get_model_freqs(self, dim=64, num_heads=4):
|
||||||
"""Instantiate a tiny WanModel and return its .freqs tensor."""
|
"""Instantiate a tiny WanModel and return its .freqs tensor."""
|
||||||
from mlx_video.models.wan2.config import WanModelConfig
|
from mlx_video.models.wan2.config import WanModelConfig
|
||||||
from mlx_video.models.wan2.model import WanModel
|
from mlx_video.models.wan2.wan2 import WanModel
|
||||||
|
|
||||||
config = WanModelConfig()
|
config = WanModelConfig()
|
||||||
config.dim = dim
|
config.dim = dim
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class TestFloat32Modulation:
|
|||||||
|
|
||||||
def test_head_modulation_float32(self):
|
def test_head_modulation_float32(self):
|
||||||
"""Head modulation should be float32 even with bf16 e input."""
|
"""Head modulation should be float32 even with bf16 e input."""
|
||||||
from mlx_video.models.wan2.model import Head
|
from mlx_video.models.wan2.wan2 import Head
|
||||||
|
|
||||||
head = Head(self.dim, 4, (1, 2, 2))
|
head = Head(self.dim, 4, (1, 2, 2))
|
||||||
x = mx.random.normal((1, 8, self.dim))
|
x = mx.random.normal((1, 8, self.dim))
|
||||||
@@ -164,7 +164,7 @@ class TestFloat32Modulation:
|
|||||||
|
|
||||||
def test_model_time_embedding_float32(self):
|
def test_model_time_embedding_float32(self):
|
||||||
"""sinusoidal_embedding_1d output must be float32."""
|
"""sinusoidal_embedding_1d output must be float32."""
|
||||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||||
|
|
||||||
t = mx.array([500.0])
|
t = mx.array([500.0])
|
||||||
emb = sinusoidal_embedding_1d(256, t)
|
emb = sinusoidal_embedding_1d(256, t)
|
||||||
@@ -173,7 +173,7 @@ class TestFloat32Modulation:
|
|||||||
|
|
||||||
def test_model_per_token_time_embedding_float32(self):
|
def test_model_per_token_time_embedding_float32(self):
|
||||||
"""Per-token time embeddings (I2V) should also be float32."""
|
"""Per-token time embeddings (I2V) should also be float32."""
|
||||||
from mlx_video.models.wan2.model import sinusoidal_embedding_1d
|
from mlx_video.models.wan2.wan2 import sinusoidal_embedding_1d
|
||||||
|
|
||||||
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
t = mx.array([[0.0, 100.0, 200.0, 300.0]]) # [B=1, L=4]
|
||||||
emb = sinusoidal_embedding_1d(256, t)
|
emb = sinusoidal_embedding_1d(256, t)
|
||||||
|
|||||||
Reference in New Issue
Block a user