Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -12,7 +12,7 @@ import numpy as np
|
||||
|
||||
class TestCausalConv3d:
|
||||
def test_output_shape_stride1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
from mlx_video.models.wan2.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
||||
# Initialize weights
|
||||
@@ -28,7 +28,7 @@ class TestCausalConv3d:
|
||||
assert out.shape[4] == 8 # W preserved
|
||||
|
||||
def test_output_shape_kernel1(self):
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
from mlx_video.models.wan2.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
@@ -39,7 +39,7 @@ class TestCausalConv3d:
|
||||
|
||||
def test_causal_padding(self):
|
||||
"""Causal conv should only use past/current frames, not future."""
|
||||
from mlx_video.models.wan.vae import CausalConv3d
|
||||
from mlx_video.models.wan2.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
@@ -56,7 +56,7 @@ class TestCausalConv3d:
|
||||
|
||||
class TestResidualBlock:
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
from mlx_video.models.wan2.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
@@ -65,7 +65,7 @@ class TestResidualBlock:
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
from mlx_video.models.wan2.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
@@ -74,13 +74,13 @@ class TestResidualBlock:
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_shortcut_exists_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
from mlx_video.models.wan2.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_when_dims_same(self):
|
||||
from mlx_video.models.wan.vae import ResidualBlock
|
||||
from mlx_video.models.wan2.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
@@ -88,7 +88,7 @@ class TestResidualBlock:
|
||||
|
||||
class TestAttentionBlock:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
from mlx_video.models.wan2.vae import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
@@ -97,7 +97,7 @@ class TestAttentionBlock:
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae import AttentionBlock
|
||||
from mlx_video.models.wan2.vae import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 1, 3, 3))
|
||||
@@ -109,7 +109,7 @@ class TestAttentionBlock:
|
||||
|
||||
class TestWanVAE:
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
assert vae.z_dim == 16
|
||||
@@ -117,7 +117,7 @@ class TestWanVAE:
|
||||
assert vae.std.shape == (16,)
|
||||
|
||||
def test_normalization_stats(self):
|
||||
from mlx_video.models.wan.vae import VAE_MEAN, VAE_STD
|
||||
from mlx_video.models.wan2.vae import VAE_MEAN, VAE_STD
|
||||
|
||||
assert len(VAE_MEAN) == 16
|
||||
assert len(VAE_STD) == 16
|
||||
@@ -133,7 +133,7 @@ class TestVAE22CausalConv3d:
|
||||
"""Tests for vae22.CausalConv3d (channels-last)."""
|
||||
|
||||
def test_output_shape_k3(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
from mlx_video.models.wan2.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
||||
@@ -142,7 +142,7 @@ class TestVAE22CausalConv3d:
|
||||
assert out.shape == (1, 4, 8, 8, 16)
|
||||
|
||||
def test_output_shape_k1(self):
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
from mlx_video.models.wan2.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(8, 16, kernel_size=1)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -152,7 +152,7 @@ class TestVAE22CausalConv3d:
|
||||
|
||||
def test_temporal_causal(self):
|
||||
"""Output at t=0 should not depend on t>0."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
from mlx_video.models.wan2.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
@@ -178,7 +178,7 @@ class TestVAE22CausalConv3d:
|
||||
|
||||
def test_channels_last_format(self):
|
||||
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
||||
from mlx_video.models.wan.vae22 import CausalConv3d
|
||||
from mlx_video.models.wan2.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((2, 3, 6, 6, 4))
|
||||
@@ -191,7 +191,7 @@ class TestRMSNorm:
|
||||
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
from mlx_video.models.wan2.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(16)
|
||||
x = mx.random.normal((2, 4, 4, 4, 16))
|
||||
@@ -201,7 +201,7 @@ class TestRMSNorm:
|
||||
|
||||
def test_l2_normalization(self):
|
||||
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
from mlx_video.models.wan2.vae22 import RMS_norm
|
||||
|
||||
dim = 32
|
||||
norm = RMS_norm(dim)
|
||||
@@ -215,7 +215,7 @@ class TestRMSNorm:
|
||||
|
||||
def test_scale_invariant(self):
|
||||
"""Scaling input by constant should not change output (L2 norm property)."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
from mlx_video.models.wan2.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(8)
|
||||
x = mx.random.normal((1, 1, 1, 1, 8))
|
||||
@@ -226,7 +226,7 @@ class TestRMSNorm:
|
||||
|
||||
def test_gamma_effect(self):
|
||||
"""Non-unit gamma should scale output."""
|
||||
from mlx_video.models.wan.vae22 import RMS_norm
|
||||
from mlx_video.models.wan2.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(4)
|
||||
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
||||
@@ -241,7 +241,7 @@ class TestDupUp3D:
|
||||
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
from mlx_video.models.wan2.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
@@ -250,7 +250,7 @@ class TestDupUp3D:
|
||||
assert out.shape == (1, 3, 8, 8, 4)
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
from mlx_video.models.wan2.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 16))
|
||||
@@ -259,7 +259,7 @@ class TestDupUp3D:
|
||||
assert out.shape == (1, 6, 8, 8, 8)
|
||||
|
||||
def test_first_chunk_trims(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
from mlx_video.models.wan2.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
@@ -271,7 +271,7 @@ class TestDupUp3D:
|
||||
assert out_trimmed.shape[1] == 5
|
||||
|
||||
def test_no_temporal_first_chunk_noop(self):
|
||||
from mlx_video.models.wan.vae22 import DupUp3D
|
||||
from mlx_video.models.wan2.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
@@ -286,7 +286,7 @@ class TestVAE22Resample:
|
||||
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
|
||||
|
||||
def test_upsample2d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample2d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
@@ -296,7 +296,7 @@ class TestVAE22Resample:
|
||||
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
|
||||
|
||||
def test_upsample3d_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
@@ -306,7 +306,7 @@ class TestVAE22Resample:
|
||||
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
|
||||
|
||||
def test_upsample3d_first_chunk(self):
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
@@ -318,7 +318,7 @@ class TestVAE22Resample:
|
||||
|
||||
def test_upsample3d_first_chunk_single_frame(self):
|
||||
"""Single-frame input with first_chunk: no temporal upsample."""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
@@ -336,7 +336,7 @@ class TestVAE22Resample:
|
||||
We verify this by checking that the first output frame depends only on
|
||||
the first input frame (not on time_conv parameters).
|
||||
"""
|
||||
from mlx_video.models.wan.vae22 import Resample
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
|
||||
C = 8
|
||||
r = Resample(C, "upsample3d")
|
||||
@@ -373,7 +373,7 @@ class TestVAE22ResidualBlock:
|
||||
"""Tests for vae22.ResidualBlock."""
|
||||
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -382,7 +382,7 @@ class TestVAE22ResidualBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 8)
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -391,13 +391,13 @@ class TestVAE22ResidualBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_shortcut_when_dims_differ(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_same_dim(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
@@ -408,7 +408,7 @@ class TestResidualBlockLayers:
|
||||
|
||||
def test_layer_names_no_underscore_prefix(self):
|
||||
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 8)
|
||||
params = dict(block.parameters())
|
||||
@@ -417,7 +417,7 @@ class TestResidualBlockLayers:
|
||||
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
|
||||
|
||||
def test_has_expected_layers(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
assert hasattr(block, "layer_0") # first RMS_norm
|
||||
@@ -426,7 +426,7 @@ class TestResidualBlockLayers:
|
||||
assert hasattr(block, "layer_6") # second CausalConv3d
|
||||
|
||||
def test_forward_shape(self):
|
||||
from mlx_video.models.wan.vae22 import ResidualBlockLayers
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -439,7 +439,7 @@ class TestVAE22AttentionBlock:
|
||||
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
from mlx_video.models.wan2.vae22 import AttentionBlock
|
||||
|
||||
block = AttentionBlock(16)
|
||||
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
||||
@@ -450,7 +450,7 @@ class TestVAE22AttentionBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan.vae22 import AttentionBlock
|
||||
from mlx_video.models.wan2.vae22 import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
||||
@@ -466,7 +466,7 @@ class TestHead22:
|
||||
"""Tests for vae22.Head22 output head."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
from mlx_video.models.wan2.vae22 import Head22
|
||||
|
||||
head = Head22(16, out_channels=12)
|
||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||
@@ -476,7 +476,7 @@ class TestHead22:
|
||||
|
||||
def test_layer_names_no_underscore(self):
|
||||
"""Head layers must not use underscore prefix."""
|
||||
from mlx_video.models.wan.vae22 import Head22
|
||||
from mlx_video.models.wan2.vae22 import Head22
|
||||
|
||||
head = Head22(8)
|
||||
assert hasattr(head, "layer_0") # RMS_norm
|
||||
@@ -490,7 +490,7 @@ class TestUnpatchify:
|
||||
"""Tests for vae22._unpatchify."""
|
||||
|
||||
def test_basic_shape(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
from mlx_video.models.wan2.vae22 import _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
@@ -498,7 +498,7 @@ class TestUnpatchify:
|
||||
assert out.shape == (1, 2, 8, 8, 3)
|
||||
|
||||
def test_patch_size_1_noop(self):
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
from mlx_video.models.wan2.vae22 import _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 4, 4, 3))
|
||||
out = _unpatchify(x, patch_size=1)
|
||||
@@ -507,7 +507,7 @@ class TestUnpatchify:
|
||||
|
||||
def test_preserves_content(self):
|
||||
"""Unpatchify should be a lossless rearrangement."""
|
||||
from mlx_video.models.wan.vae22 import _unpatchify
|
||||
from mlx_video.models.wan2.vae22 import _unpatchify
|
||||
|
||||
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
@@ -521,7 +521,7 @@ class TestDenormalizeLatents:
|
||||
"""Tests for vae22.denormalize_latents."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
from mlx_video.models.wan2.vae22 import denormalize_latents
|
||||
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
out = denormalize_latents(z)
|
||||
@@ -529,7 +529,7 @@ class TestDenormalizeLatents:
|
||||
assert out.shape == (1, 2, 4, 4, 48)
|
||||
|
||||
def test_custom_mean_std(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents
|
||||
from mlx_video.models.wan2.vae22 import denormalize_latents
|
||||
|
||||
z = mx.ones((1, 1, 1, 1, 4))
|
||||
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
@@ -542,7 +542,7 @@ class TestDenormalizeLatents:
|
||||
)
|
||||
|
||||
def test_uses_default_constants(self):
|
||||
from mlx_video.models.wan.vae22 import (
|
||||
from mlx_video.models.wan2.vae22 import (
|
||||
VAE22_MEAN,
|
||||
denormalize_latents,
|
||||
)
|
||||
@@ -563,14 +563,14 @@ class TestVAE22NormConstants:
|
||||
"""Tests for VAE22_MEAN and VAE22_STD constants."""
|
||||
|
||||
def test_dimensions(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_MEAN, VAE22_STD
|
||||
from mlx_video.models.wan2.vae22 import VAE22_MEAN, VAE22_STD
|
||||
|
||||
mx.eval(VAE22_MEAN, VAE22_STD)
|
||||
assert VAE22_MEAN.shape == (48,)
|
||||
assert VAE22_STD.shape == (48,)
|
||||
|
||||
def test_std_positive(self):
|
||||
from mlx_video.models.wan.vae22 import VAE22_STD
|
||||
from mlx_video.models.wan2.vae22 import VAE22_STD
|
||||
|
||||
mx.eval(VAE22_STD)
|
||||
assert (np.array(VAE22_STD) > 0).all()
|
||||
@@ -581,7 +581,7 @@ class TestWan22VAEDecoder:
|
||||
|
||||
def test_output_shape_small(self):
|
||||
"""Tiny decoder should produce correct spatial/temporal output."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
|
||||
|
||||
# Use very small dims to keep test fast
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
@@ -597,7 +597,7 @@ class TestWan22VAEDecoder:
|
||||
assert np.array(out).max() <= 1.0
|
||||
|
||||
def test_output_clipped(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
|
||||
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
||||
@@ -611,7 +611,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
"""Tests for vae22.sanitize_wan22_vae_weights."""
|
||||
|
||||
def test_skip_encoder(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.layer.weight": mx.zeros((4,)),
|
||||
@@ -624,7 +624,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert "decoder.conv1.bias" in out
|
||||
|
||||
def test_sequential_index_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
||||
@@ -639,7 +639,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert "decoder.head.layer_2.bias" in out
|
||||
|
||||
def test_resample_conv_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
||||
@@ -650,7 +650,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
|
||||
|
||||
def test_attention_remapping(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
||||
@@ -665,7 +665,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert "decoder.middle.1.proj_bias" in out
|
||||
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
||||
w = mx.zeros((16, 8, 3, 3, 3))
|
||||
@@ -674,7 +674,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
||||
w = mx.zeros((8, 8, 3, 3))
|
||||
@@ -684,7 +684,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert out[key].shape == (8, 3, 3, 8)
|
||||
|
||||
def test_gamma_squeeze(self):
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# gamma: (dim, 1, 1, 1) → (dim,)
|
||||
w = mx.ones((16, 1, 1, 1))
|
||||
@@ -698,7 +698,7 @@ class TestUpResidualBlock:
|
||||
"""Tests for vae22.Up_ResidualBlock."""
|
||||
|
||||
def test_no_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
|
||||
@@ -710,7 +710,7 @@ class TestUpResidualBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 8)
|
||||
|
||||
def test_spatial_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
|
||||
@@ -722,7 +722,7 @@ class TestUpResidualBlock:
|
||||
assert out.shape == (1, 2, 8, 8, 4)
|
||||
|
||||
def test_spatial_temporal_upsample(self):
|
||||
from mlx_video.models.wan.vae22 import Up_ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
|
||||
@@ -738,7 +738,7 @@ class TestPatchify:
|
||||
"""Tests for _patchify and _unpatchify round-trip."""
|
||||
|
||||
def test_roundtrip(self):
|
||||
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
|
||||
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 1, 64, 64, 3))
|
||||
p = _patchify(x, patch_size=2)
|
||||
@@ -748,7 +748,7 @@ class TestPatchify:
|
||||
assert float(mx.abs(x - back).max()) == 0.0
|
||||
|
||||
def test_identity_patch_1(self):
|
||||
from mlx_video.models.wan.vae22 import _patchify, _unpatchify
|
||||
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 8, 8, 3))
|
||||
assert _patchify(x, patch_size=1).shape == x.shape
|
||||
@@ -759,7 +759,7 @@ class TestAvgDown3D:
|
||||
"""Tests for AvgDown3D downsampling."""
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
from mlx_video.models.wan2.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
@@ -768,7 +768,7 @@ class TestAvgDown3D:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
from mlx_video.models.wan2.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||
@@ -777,7 +777,7 @@ class TestAvgDown3D:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_single_frame(self):
|
||||
from mlx_video.models.wan.vae22 import AvgDown3D
|
||||
from mlx_video.models.wan2.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 1, 8, 8, 8))
|
||||
@@ -791,7 +791,7 @@ class TestDownResidualBlock:
|
||||
"""Tests for Down_ResidualBlock."""
|
||||
|
||||
def test_no_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(
|
||||
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
|
||||
@@ -802,7 +802,7 @@ class TestDownResidualBlock:
|
||||
assert out.shape == (1, 2, 8, 8, 8)
|
||||
|
||||
def test_spatial_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(
|
||||
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
|
||||
@@ -813,7 +813,7 @@ class TestDownResidualBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_spatial_temporal_downsample(self):
|
||||
from mlx_video.models.wan.vae22 import Down_ResidualBlock
|
||||
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(
|
||||
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
|
||||
@@ -828,7 +828,7 @@ class TestEncoder3d:
|
||||
"""Tests for Encoder3d."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
from mlx_video.models.wan2.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8)
|
||||
x = mx.random.normal((1, 1, 16, 16, 12))
|
||||
@@ -839,7 +839,7 @@ class TestEncoder3d:
|
||||
assert out.shape == (1, 1, 2, 2, 8)
|
||||
|
||||
def test_multi_frame(self):
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
from mlx_video.models.wan2.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
@@ -854,7 +854,7 @@ class TestWan22VAEEncoder:
|
||||
"""Tests for Wan22VAEEncoder wrapper."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
|
||||
@@ -865,7 +865,7 @@ class TestWan22VAEEncoder:
|
||||
assert z.shape == (1, 1, 2, 2, 48)
|
||||
|
||||
def test_full_dim(self):
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=160)
|
||||
img = mx.random.normal((1, 1, 64, 64, 3))
|
||||
@@ -880,7 +880,7 @@ class TestNormalizeLatents:
|
||||
"""Tests for normalize/denormalize latent roundtrip."""
|
||||
|
||||
def test_roundtrip(self):
|
||||
from mlx_video.models.wan.vae22 import denormalize_latents, normalize_latents
|
||||
from mlx_video.models.wan2.vae22 import denormalize_latents, normalize_latents
|
||||
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
z_norm = normalize_latents(z)
|
||||
@@ -895,7 +895,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_encoder_temporal_downsample_pattern(self):
|
||||
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
from mlx_video.models.wan2.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
@@ -906,7 +906,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_wrapper_uses_correct_pattern(self):
|
||||
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
||||
from mlx_video.models.wan.vae22 import Resample, Wan22VAEEncoder
|
||||
from mlx_video.models.wan2.vae22 import Resample, Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
down_blocks = enc.encoder.downsamples
|
||||
@@ -921,7 +921,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_single_frame_encoder(self):
|
||||
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEEncoder
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||
@@ -933,7 +933,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_wrong_order_gives_different_result(self):
|
||||
"""(True, True, False) vs (False, True, True) produce different outputs."""
|
||||
from mlx_video.models.wan.vae22 import Encoder3d
|
||||
from mlx_video.models.wan2.vae22 import Encoder3d
|
||||
|
||||
enc_correct = Encoder3d(
|
||||
dim=16, z_dim=8, temperal_downsample=(False, True, True)
|
||||
@@ -963,7 +963,7 @@ class TestVAE21RoundTrip:
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
|
||||
from mlx_video.models.wan.vae import Decoder3d, Encoder3d
|
||||
from mlx_video.models.wan2.vae import Decoder3d, Encoder3d
|
||||
|
||||
z_dim = 4
|
||||
dim = 8
|
||||
@@ -995,7 +995,7 @@ class TestVAE22RoundTrip:
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
|
||||
from mlx_video.models.wan.vae22 import (
|
||||
from mlx_video.models.wan2.vae22 import (
|
||||
Wan22VAEDecoder,
|
||||
Wan22VAEEncoder,
|
||||
denormalize_latents,
|
||||
|
||||
Reference in New Issue
Block a user