Remove Wan2 model files, including configuration, attention mechanisms, and utility functions, to streamline the codebase and eliminate unused components. This cleanup enhances maintainability and focuses on the core functionality of the Wan2 module.
This commit is contained in:
@@ -12,7 +12,7 @@ import numpy as np
|
||||
|
||||
class TestCausalConv3d:
|
||||
def test_output_shape_stride1(self):
|
||||
from mlx_video.models.wan2.vae import CausalConv3d
|
||||
from mlx_video.models.wan_2.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, stride=1, padding=1)
|
||||
# Initialize weights
|
||||
@@ -28,7 +28,7 @@ class TestCausalConv3d:
|
||||
assert out.shape[4] == 8 # W preserved
|
||||
|
||||
def test_output_shape_kernel1(self):
|
||||
from mlx_video.models.wan2.vae import CausalConv3d
|
||||
from mlx_video.models.wan_2.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=1, stride=1, padding=0)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.02
|
||||
@@ -39,7 +39,7 @@ class TestCausalConv3d:
|
||||
|
||||
def test_causal_padding(self):
|
||||
"""Causal conv should only use past/current frames, not future."""
|
||||
from mlx_video.models.wan2.vae import CausalConv3d
|
||||
from mlx_video.models.wan_2.vae import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, stride=1, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
@@ -56,7 +56,7 @@ class TestCausalConv3d:
|
||||
|
||||
class TestResidualBlock:
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan2.vae import ResidualBlock
|
||||
from mlx_video.models.wan_2.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
@@ -65,7 +65,7 @@ class TestResidualBlock:
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan2.vae import ResidualBlock
|
||||
from mlx_video.models.wan_2.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
@@ -74,13 +74,13 @@ class TestResidualBlock:
|
||||
assert out.shape == (1, 16, 2, 4, 4)
|
||||
|
||||
def test_shortcut_exists_when_dims_differ(self):
|
||||
from mlx_video.models.wan2.vae import ResidualBlock
|
||||
from mlx_video.models.wan_2.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_when_dims_same(self):
|
||||
from mlx_video.models.wan2.vae import ResidualBlock
|
||||
from mlx_video.models.wan_2.vae import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
@@ -88,7 +88,7 @@ class TestResidualBlock:
|
||||
|
||||
class TestAttentionBlock:
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.vae import AttentionBlock
|
||||
from mlx_video.models.wan_2.vae import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 2, 4, 4))
|
||||
@@ -97,7 +97,7 @@ class TestAttentionBlock:
|
||||
assert out.shape == (1, 8, 2, 4, 4)
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan2.vae import AttentionBlock
|
||||
from mlx_video.models.wan_2.vae import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
x = mx.random.normal((1, 8, 1, 3, 3))
|
||||
@@ -109,7 +109,7 @@ class TestAttentionBlock:
|
||||
|
||||
class TestWanVAE:
|
||||
def test_instantiation(self):
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
from mlx_video.models.wan_2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
assert vae.z_dim == 16
|
||||
@@ -117,7 +117,7 @@ class TestWanVAE:
|
||||
assert vae.std.shape == (16,)
|
||||
|
||||
def test_normalization_stats(self):
|
||||
from mlx_video.models.wan2.vae import VAE_MEAN, VAE_STD
|
||||
from mlx_video.models.wan_2.vae import VAE_MEAN, VAE_STD
|
||||
|
||||
assert len(VAE_MEAN) == 16
|
||||
assert len(VAE_STD) == 16
|
||||
@@ -133,7 +133,7 @@ class TestVAE22CausalConv3d:
|
||||
"""Tests for vae22.CausalConv3d (channels-last)."""
|
||||
|
||||
def test_output_shape_k3(self):
|
||||
from mlx_video.models.wan2.vae22 import CausalConv3d
|
||||
from mlx_video.models.wan_2.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(8, 16, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8)) # [B, T, H, W, C]
|
||||
@@ -142,7 +142,7 @@ class TestVAE22CausalConv3d:
|
||||
assert out.shape == (1, 4, 8, 8, 16)
|
||||
|
||||
def test_output_shape_k1(self):
|
||||
from mlx_video.models.wan2.vae22 import CausalConv3d
|
||||
from mlx_video.models.wan_2.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(8, 16, kernel_size=1)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -152,7 +152,7 @@ class TestVAE22CausalConv3d:
|
||||
|
||||
def test_temporal_causal(self):
|
||||
"""Output at t=0 should not depend on t>0."""
|
||||
from mlx_video.models.wan2.vae22 import CausalConv3d
|
||||
from mlx_video.models.wan_2.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
|
||||
conv.weight = mx.random.normal(conv.weight.shape) * 0.1
|
||||
@@ -178,7 +178,7 @@ class TestVAE22CausalConv3d:
|
||||
|
||||
def test_channels_last_format(self):
|
||||
"""Verify input/output are channels-last [B, T, H, W, C]."""
|
||||
from mlx_video.models.wan2.vae22 import CausalConv3d
|
||||
from mlx_video.models.wan_2.vae22 import CausalConv3d
|
||||
|
||||
conv = CausalConv3d(4, 8, kernel_size=3, padding=1)
|
||||
x = mx.random.normal((2, 3, 6, 6, 4))
|
||||
@@ -191,7 +191,7 @@ class TestRMSNorm:
|
||||
"""Tests for vae22.RMS_norm (actually L2 normalization)."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import RMS_norm
|
||||
from mlx_video.models.wan_2.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(16)
|
||||
x = mx.random.normal((2, 4, 4, 4, 16))
|
||||
@@ -201,7 +201,7 @@ class TestRMSNorm:
|
||||
|
||||
def test_l2_normalization(self):
|
||||
"""RMS_norm should normalize to unit L2 norm * sqrt(dim)."""
|
||||
from mlx_video.models.wan2.vae22 import RMS_norm
|
||||
from mlx_video.models.wan_2.vae22 import RMS_norm
|
||||
|
||||
dim = 32
|
||||
norm = RMS_norm(dim)
|
||||
@@ -215,7 +215,7 @@ class TestRMSNorm:
|
||||
|
||||
def test_scale_invariant(self):
|
||||
"""Scaling input by constant should not change output (L2 norm property)."""
|
||||
from mlx_video.models.wan2.vae22 import RMS_norm
|
||||
from mlx_video.models.wan_2.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(8)
|
||||
x = mx.random.normal((1, 1, 1, 1, 8))
|
||||
@@ -226,7 +226,7 @@ class TestRMSNorm:
|
||||
|
||||
def test_gamma_effect(self):
|
||||
"""Non-unit gamma should scale output."""
|
||||
from mlx_video.models.wan2.vae22 import RMS_norm
|
||||
from mlx_video.models.wan_2.vae22 import RMS_norm
|
||||
|
||||
norm = RMS_norm(4)
|
||||
norm.gamma = mx.array([2.0, 2.0, 2.0, 2.0])
|
||||
@@ -241,7 +241,7 @@ class TestDupUp3D:
|
||||
"""Tests for vae22.DupUp3D spatial/temporal upsampling."""
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan2.vae22 import DupUp3D
|
||||
from mlx_video.models.wan_2.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
@@ -250,7 +250,7 @@ class TestDupUp3D:
|
||||
assert out.shape == (1, 3, 8, 8, 4)
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan2.vae22 import DupUp3D
|
||||
from mlx_video.models.wan_2.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(16, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 16))
|
||||
@@ -259,7 +259,7 @@ class TestDupUp3D:
|
||||
assert out.shape == (1, 6, 8, 8, 8)
|
||||
|
||||
def test_first_chunk_trims(self):
|
||||
from mlx_video.models.wan2.vae22 import DupUp3D
|
||||
from mlx_video.models.wan_2.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
@@ -271,7 +271,7 @@ class TestDupUp3D:
|
||||
assert out_trimmed.shape[1] == 5
|
||||
|
||||
def test_no_temporal_first_chunk_noop(self):
|
||||
from mlx_video.models.wan2.vae22 import DupUp3D
|
||||
from mlx_video.models.wan_2.vae22 import DupUp3D
|
||||
|
||||
up = DupUp3D(8, 4, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 3, 4, 4, 8))
|
||||
@@ -286,7 +286,7 @@ class TestVAE22Resample:
|
||||
"""Tests for vae22.Resample (spatial/temporal upsampling)."""
|
||||
|
||||
def test_upsample2d_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
from mlx_video.models.wan_2.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample2d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
@@ -296,7 +296,7 @@ class TestVAE22Resample:
|
||||
assert out.shape == (1, 2, 8, 8, 8) # 2x spatial, same temporal
|
||||
|
||||
def test_upsample3d_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
from mlx_video.models.wan_2.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
@@ -306,7 +306,7 @@ class TestVAE22Resample:
|
||||
assert out.shape == (1, 4, 8, 8, 8) # 2x spatial + 2x temporal
|
||||
|
||||
def test_upsample3d_first_chunk(self):
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
from mlx_video.models.wan_2.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
@@ -318,7 +318,7 @@ class TestVAE22Resample:
|
||||
|
||||
def test_upsample3d_first_chunk_single_frame(self):
|
||||
"""Single-frame input with first_chunk: no temporal upsample."""
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
from mlx_video.models.wan_2.vae22 import Resample
|
||||
|
||||
r = Resample(8, "upsample3d")
|
||||
r.resample_weight = mx.random.normal(r.resample_weight.shape) * 0.01
|
||||
@@ -336,7 +336,7 @@ class TestVAE22Resample:
|
||||
We verify this by checking that the first output frame depends only on
|
||||
the first input frame (not on time_conv parameters).
|
||||
"""
|
||||
from mlx_video.models.wan2.vae22 import Resample
|
||||
from mlx_video.models.wan_2.vae22 import Resample
|
||||
|
||||
C = 8
|
||||
r = Resample(C, "upsample3d")
|
||||
@@ -373,7 +373,7 @@ class TestVAE22ResidualBlock:
|
||||
"""Tests for vae22.ResidualBlock."""
|
||||
|
||||
def test_same_dim(self):
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -382,7 +382,7 @@ class TestVAE22ResidualBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 8)
|
||||
|
||||
def test_different_dim(self):
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -391,13 +391,13 @@ class TestVAE22ResidualBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_shortcut_when_dims_differ(self):
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 16)
|
||||
assert block.shortcut is not None
|
||||
|
||||
def test_no_shortcut_same_dim(self):
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import ResidualBlock
|
||||
|
||||
block = ResidualBlock(8, 8)
|
||||
assert block.shortcut is None
|
||||
@@ -408,7 +408,7 @@ class TestResidualBlockLayers:
|
||||
|
||||
def test_layer_names_no_underscore_prefix(self):
|
||||
"""Layer names must NOT start with underscore (MLX ignores them)."""
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
|
||||
from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 8)
|
||||
params = dict(block.parameters())
|
||||
@@ -417,7 +417,7 @@ class TestResidualBlockLayers:
|
||||
assert not key.startswith("_"), f"Parameter {key} starts with underscore"
|
||||
|
||||
def test_has_expected_layers(self):
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
|
||||
from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
assert hasattr(block, "layer_0") # first RMS_norm
|
||||
@@ -426,7 +426,7 @@ class TestResidualBlockLayers:
|
||||
assert hasattr(block, "layer_6") # second CausalConv3d
|
||||
|
||||
def test_forward_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import ResidualBlockLayers
|
||||
from mlx_video.models.wan_2.vae22 import ResidualBlockLayers
|
||||
|
||||
block = ResidualBlockLayers(8, 16)
|
||||
x = mx.random.normal((1, 2, 4, 4, 8))
|
||||
@@ -439,7 +439,7 @@ class TestVAE22AttentionBlock:
|
||||
"""Tests for vae22.AttentionBlock (per-frame 2D self-attention)."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import AttentionBlock
|
||||
from mlx_video.models.wan_2.vae22 import AttentionBlock
|
||||
|
||||
block = AttentionBlock(16)
|
||||
block.to_qkv_weight = mx.random.normal(block.to_qkv_weight.shape) * 0.01
|
||||
@@ -450,7 +450,7 @@ class TestVAE22AttentionBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_residual_connection(self):
|
||||
from mlx_video.models.wan2.vae22 import AttentionBlock
|
||||
from mlx_video.models.wan_2.vae22 import AttentionBlock
|
||||
|
||||
block = AttentionBlock(8)
|
||||
block.to_qkv_weight = mx.zeros(block.to_qkv_weight.shape)
|
||||
@@ -466,7 +466,7 @@ class TestHead22:
|
||||
"""Tests for vae22.Head22 output head."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import Head22
|
||||
from mlx_video.models.wan_2.vae22 import Head22
|
||||
|
||||
head = Head22(16, out_channels=12)
|
||||
x = mx.random.normal((1, 2, 4, 4, 16))
|
||||
@@ -476,7 +476,7 @@ class TestHead22:
|
||||
|
||||
def test_layer_names_no_underscore(self):
|
||||
"""Head layers must not use underscore prefix."""
|
||||
from mlx_video.models.wan2.vae22 import Head22
|
||||
from mlx_video.models.wan_2.vae22 import Head22
|
||||
|
||||
head = Head22(8)
|
||||
assert hasattr(head, "layer_0") # RMS_norm
|
||||
@@ -490,7 +490,7 @@ class TestUnpatchify:
|
||||
"""Tests for vae22._unpatchify."""
|
||||
|
||||
def test_basic_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import _unpatchify
|
||||
from mlx_video.models.wan_2.vae22 import _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 4, 4, 12)) # 12 = 3 * 2 * 2
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
@@ -498,7 +498,7 @@ class TestUnpatchify:
|
||||
assert out.shape == (1, 2, 8, 8, 3)
|
||||
|
||||
def test_patch_size_1_noop(self):
|
||||
from mlx_video.models.wan2.vae22 import _unpatchify
|
||||
from mlx_video.models.wan_2.vae22 import _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 4, 4, 3))
|
||||
out = _unpatchify(x, patch_size=1)
|
||||
@@ -507,7 +507,7 @@ class TestUnpatchify:
|
||||
|
||||
def test_preserves_content(self):
|
||||
"""Unpatchify should be a lossless rearrangement."""
|
||||
from mlx_video.models.wan2.vae22 import _unpatchify
|
||||
from mlx_video.models.wan_2.vae22 import _unpatchify
|
||||
|
||||
x = mx.arange(48).reshape(1, 1, 2, 2, 12).astype(mx.float32)
|
||||
out = _unpatchify(x, patch_size=2)
|
||||
@@ -521,7 +521,7 @@ class TestDenormalizeLatents:
|
||||
"""Tests for vae22.denormalize_latents."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import denormalize_latents
|
||||
from mlx_video.models.wan_2.vae22 import denormalize_latents
|
||||
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
out = denormalize_latents(z)
|
||||
@@ -529,7 +529,7 @@ class TestDenormalizeLatents:
|
||||
assert out.shape == (1, 2, 4, 4, 48)
|
||||
|
||||
def test_custom_mean_std(self):
|
||||
from mlx_video.models.wan2.vae22 import denormalize_latents
|
||||
from mlx_video.models.wan_2.vae22 import denormalize_latents
|
||||
|
||||
z = mx.ones((1, 1, 1, 1, 4))
|
||||
mean = mx.array([1.0, 2.0, 3.0, 4.0])
|
||||
@@ -542,7 +542,7 @@ class TestDenormalizeLatents:
|
||||
)
|
||||
|
||||
def test_uses_default_constants(self):
|
||||
from mlx_video.models.wan2.vae22 import (
|
||||
from mlx_video.models.wan_2.vae22 import (
|
||||
VAE22_MEAN,
|
||||
denormalize_latents,
|
||||
)
|
||||
@@ -563,14 +563,14 @@ class TestVAE22NormConstants:
|
||||
"""Tests for VAE22_MEAN and VAE22_STD constants."""
|
||||
|
||||
def test_dimensions(self):
|
||||
from mlx_video.models.wan2.vae22 import VAE22_MEAN, VAE22_STD
|
||||
from mlx_video.models.wan_2.vae22 import VAE22_MEAN, VAE22_STD
|
||||
|
||||
mx.eval(VAE22_MEAN, VAE22_STD)
|
||||
assert VAE22_MEAN.shape == (48,)
|
||||
assert VAE22_STD.shape == (48,)
|
||||
|
||||
def test_std_positive(self):
|
||||
from mlx_video.models.wan2.vae22 import VAE22_STD
|
||||
from mlx_video.models.wan_2.vae22 import VAE22_STD
|
||||
|
||||
mx.eval(VAE22_STD)
|
||||
assert (np.array(VAE22_STD) > 0).all()
|
||||
@@ -581,7 +581,7 @@ class TestWan22VAEDecoder:
|
||||
|
||||
def test_output_shape_small(self):
|
||||
"""Tiny decoder should produce correct spatial/temporal output."""
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
|
||||
|
||||
# Use very small dims to keep test fast
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
@@ -597,7 +597,7 @@ class TestWan22VAEDecoder:
|
||||
assert np.array(out).max() <= 1.0
|
||||
|
||||
def test_output_clipped(self):
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEDecoder
|
||||
|
||||
dec = Wan22VAEDecoder(z_dim=4, dim=8, dec_dim=8)
|
||||
z = mx.random.normal((1, 2, 2, 2, 4)) * 10.0 # large values
|
||||
@@ -611,7 +611,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
"""Tests for vae22.sanitize_wan22_vae_weights."""
|
||||
|
||||
def test_skip_encoder(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"encoder.layer.weight": mx.zeros((4,)),
|
||||
@@ -624,7 +624,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert "decoder.conv1.bias" in out
|
||||
|
||||
def test_sequential_index_remapping(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.upsamples.0.upsamples.0.residual.0.gamma": mx.ones((8,)),
|
||||
@@ -639,7 +639,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert "decoder.head.layer_2.bias" in out
|
||||
|
||||
def test_resample_conv_remapping(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.upsamples.1.upsamples.3.resample.1.weight": mx.zeros((8, 8, 3, 3)),
|
||||
@@ -650,7 +650,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert "decoder.upsamples.1.upsamples.3.resample_bias" in out
|
||||
|
||||
def test_attention_remapping(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
weights = {
|
||||
"decoder.middle.1.to_qkv.weight": mx.zeros((24, 8, 1, 1)),
|
||||
@@ -665,7 +665,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert "decoder.middle.1.proj_bias" in out
|
||||
|
||||
def test_conv3d_transpose(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# Conv3d weight: [O, I, D, H, W] → [O, D, H, W, I]
|
||||
w = mx.zeros((16, 8, 3, 3, 3))
|
||||
@@ -674,7 +674,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert out["decoder.conv1.weight"].shape == (16, 3, 3, 3, 8)
|
||||
|
||||
def test_conv2d_transpose(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# Conv2d weight: [O, I, H, W] → [O, H, W, I]
|
||||
w = mx.zeros((8, 8, 3, 3))
|
||||
@@ -684,7 +684,7 @@ class TestSanitizeWan22VAEWeights:
|
||||
assert out[key].shape == (8, 3, 3, 8)
|
||||
|
||||
def test_gamma_squeeze(self):
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan_2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
# gamma: (dim, 1, 1, 1) → (dim,)
|
||||
w = mx.ones((16, 1, 1, 1))
|
||||
@@ -698,7 +698,7 @@ class TestUpResidualBlock:
|
||||
"""Tests for vae22.Up_ResidualBlock."""
|
||||
|
||||
def test_no_upsample(self):
|
||||
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 8, num_res_blocks=1, temperal_upsample=False, up_flag=False
|
||||
@@ -710,7 +710,7 @@ class TestUpResidualBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 8)
|
||||
|
||||
def test_spatial_upsample(self):
|
||||
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 4, num_res_blocks=1, temperal_upsample=False, up_flag=True
|
||||
@@ -722,7 +722,7 @@ class TestUpResidualBlock:
|
||||
assert out.shape == (1, 2, 8, 8, 4)
|
||||
|
||||
def test_spatial_temporal_upsample(self):
|
||||
from mlx_video.models.wan2.vae22 import Up_ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import Up_ResidualBlock
|
||||
|
||||
block = Up_ResidualBlock(
|
||||
8, 4, num_res_blocks=1, temperal_upsample=True, up_flag=True
|
||||
@@ -738,7 +738,7 @@ class TestPatchify:
|
||||
"""Tests for _patchify and _unpatchify round-trip."""
|
||||
|
||||
def test_roundtrip(self):
|
||||
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
|
||||
from mlx_video.models.wan_2.vae22 import _patchify, _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 1, 64, 64, 3))
|
||||
p = _patchify(x, patch_size=2)
|
||||
@@ -748,7 +748,7 @@ class TestPatchify:
|
||||
assert float(mx.abs(x - back).max()) == 0.0
|
||||
|
||||
def test_identity_patch_1(self):
|
||||
from mlx_video.models.wan2.vae22 import _patchify, _unpatchify
|
||||
from mlx_video.models.wan_2.vae22 import _patchify, _unpatchify
|
||||
|
||||
x = mx.random.normal((1, 2, 8, 8, 3))
|
||||
assert _patchify(x, patch_size=1).shape == x.shape
|
||||
@@ -759,7 +759,7 @@ class TestAvgDown3D:
|
||||
"""Tests for AvgDown3D downsampling."""
|
||||
|
||||
def test_spatial_only(self):
|
||||
from mlx_video.models.wan2.vae22 import AvgDown3D
|
||||
from mlx_video.models.wan_2.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 16, factor_t=1, factor_s=2)
|
||||
x = mx.random.normal((1, 2, 8, 8, 8))
|
||||
@@ -768,7 +768,7 @@ class TestAvgDown3D:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_temporal_and_spatial(self):
|
||||
from mlx_video.models.wan2.vae22 import AvgDown3D
|
||||
from mlx_video.models.wan_2.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 16, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 4, 8, 8, 8))
|
||||
@@ -777,7 +777,7 @@ class TestAvgDown3D:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_single_frame(self):
|
||||
from mlx_video.models.wan2.vae22 import AvgDown3D
|
||||
from mlx_video.models.wan_2.vae22 import AvgDown3D
|
||||
|
||||
down = AvgDown3D(8, 8, factor_t=2, factor_s=2)
|
||||
x = mx.random.normal((1, 1, 8, 8, 8))
|
||||
@@ -791,7 +791,7 @@ class TestDownResidualBlock:
|
||||
"""Tests for Down_ResidualBlock."""
|
||||
|
||||
def test_no_downsample(self):
|
||||
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(
|
||||
8, 8, num_res_blocks=1, temperal_downsample=False, down_flag=False
|
||||
@@ -802,7 +802,7 @@ class TestDownResidualBlock:
|
||||
assert out.shape == (1, 2, 8, 8, 8)
|
||||
|
||||
def test_spatial_downsample(self):
|
||||
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(
|
||||
8, 16, num_res_blocks=1, temperal_downsample=False, down_flag=True
|
||||
@@ -813,7 +813,7 @@ class TestDownResidualBlock:
|
||||
assert out.shape == (1, 2, 4, 4, 16)
|
||||
|
||||
def test_spatial_temporal_downsample(self):
|
||||
from mlx_video.models.wan2.vae22 import Down_ResidualBlock
|
||||
from mlx_video.models.wan_2.vae22 import Down_ResidualBlock
|
||||
|
||||
block = Down_ResidualBlock(
|
||||
8, 16, num_res_blocks=1, temperal_downsample=True, down_flag=True
|
||||
@@ -828,7 +828,7 @@ class TestEncoder3d:
|
||||
"""Tests for Encoder3d."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import Encoder3d
|
||||
from mlx_video.models.wan_2.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8)
|
||||
x = mx.random.normal((1, 1, 16, 16, 12))
|
||||
@@ -839,7 +839,7 @@ class TestEncoder3d:
|
||||
assert out.shape == (1, 1, 2, 2, 8)
|
||||
|
||||
def test_multi_frame(self):
|
||||
from mlx_video.models.wan2.vae22 import Encoder3d
|
||||
from mlx_video.models.wan_2.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(True, True, False))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
@@ -854,7 +854,7 @@ class TestWan22VAEEncoder:
|
||||
"""Tests for Wan22VAEEncoder wrapper."""
|
||||
|
||||
def test_output_shape(self):
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
# Input: single image 32×32 (patchify÷2 → 16×16, then 3 spatial ÷8 → 2×2)
|
||||
@@ -865,7 +865,7 @@ class TestWan22VAEEncoder:
|
||||
assert z.shape == (1, 1, 2, 2, 48)
|
||||
|
||||
def test_full_dim(self):
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=160)
|
||||
img = mx.random.normal((1, 1, 64, 64, 3))
|
||||
@@ -880,7 +880,7 @@ class TestNormalizeLatents:
|
||||
"""Tests for normalize/denormalize latent roundtrip."""
|
||||
|
||||
def test_roundtrip(self):
|
||||
from mlx_video.models.wan2.vae22 import denormalize_latents, normalize_latents
|
||||
from mlx_video.models.wan_2.vae22 import denormalize_latents, normalize_latents
|
||||
|
||||
z = mx.random.normal((1, 2, 4, 4, 48))
|
||||
z_norm = normalize_latents(z)
|
||||
@@ -895,7 +895,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_encoder_temporal_downsample_pattern(self):
|
||||
"""Encoder3d with (False, True, True): T=5→5→3→2."""
|
||||
from mlx_video.models.wan2.vae22 import Encoder3d
|
||||
from mlx_video.models.wan_2.vae22 import Encoder3d
|
||||
|
||||
enc = Encoder3d(dim=16, z_dim=8, temperal_downsample=(False, True, True))
|
||||
x = mx.random.normal((1, 5, 16, 16, 12))
|
||||
@@ -906,7 +906,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_wrapper_uses_correct_pattern(self):
|
||||
"""Wan22VAEEncoder should use (False, True, True) temporal downsample."""
|
||||
from mlx_video.models.wan2.vae22 import Resample, Wan22VAEEncoder
|
||||
from mlx_video.models.wan_2.vae22 import Resample, Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
down_blocks = enc.encoder.downsamples
|
||||
@@ -921,7 +921,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_single_frame_encoder(self):
|
||||
"""Single frame (T=1) should work with (False, True, True) pattern."""
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEEncoder
|
||||
from mlx_video.models.wan_2.vae22 import Wan22VAEEncoder
|
||||
|
||||
enc = Wan22VAEEncoder(z_dim=48, dim=16)
|
||||
img = mx.random.normal((1, 1, 32, 32, 3))
|
||||
@@ -933,7 +933,7 @@ class TestVAEEncoderTemporalOrder:
|
||||
|
||||
def test_wrong_order_gives_different_result(self):
|
||||
"""(True, True, False) vs (False, True, True) produce different outputs."""
|
||||
from mlx_video.models.wan2.vae22 import Encoder3d
|
||||
from mlx_video.models.wan_2.vae22 import Encoder3d
|
||||
|
||||
enc_correct = Encoder3d(
|
||||
dim=16, z_dim=8, temperal_downsample=(False, True, True)
|
||||
@@ -963,7 +963,7 @@ class TestVAE21RoundTrip:
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Encoder3d → Decoder3d: output shape matches input, values are finite."""
|
||||
from mlx_video.models.wan2.vae import Decoder3d, Encoder3d
|
||||
from mlx_video.models.wan_2.vae import Decoder3d, Encoder3d
|
||||
|
||||
z_dim = 4
|
||||
dim = 8
|
||||
@@ -995,7 +995,7 @@ class TestVAE22RoundTrip:
|
||||
|
||||
def test_encode_decode_shape_and_values(self):
|
||||
"""Wan22VAEEncoder → Wan22VAEDecoder: shapes consistent, values in range."""
|
||||
from mlx_video.models.wan2.vae22 import (
|
||||
from mlx_video.models.wan_2.vae22 import (
|
||||
Wan22VAEDecoder,
|
||||
Wan22VAEEncoder,
|
||||
denormalize_latents,
|
||||
|
||||
Reference in New Issue
Block a user