Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.

This commit is contained in:
Prince Canuma
2026-03-18 17:52:30 +01:00
parent 17397da70c
commit 6c63163671
28 changed files with 354 additions and 1033 deletions

View File

@@ -3,7 +3,7 @@
import mlx.core as mx
import numpy as np
from mlx_video.models.ltx.video_vae.tiling import (
from mlx_video.models.ltx_2.video_vae.tiling import (
TilingConfig,
decode_with_tiling,
split_in_spatial,
@@ -75,7 +75,7 @@ class TestWan22TiledDecoding:
def _make_small_wan22_decoder(self):
"""Create a small Wan2.2 decoder for testing."""
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
# Use very small dimensions for fast testing
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
@@ -139,7 +139,7 @@ class TestWan21TiledDecoding:
def _make_small_wan21_vae(self):
"""Create a small Wan2.1 VAE for testing."""
from mlx_video.models.wan.vae import WanVAE
from mlx_video.models.wan2.vae import WanVAE
vae = WanVAE(z_dim=16)
mx.eval(vae.parameters())
@@ -192,7 +192,7 @@ class TestWan21TemporalScale:
def test_wan21_decoder_temporal_output(self):
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
from mlx_video.models.wan.vae import Decoder3d
from mlx_video.models.wan2.vae import Decoder3d
# Small decoder for fast test
dec = Decoder3d(