Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
from mlx_video.models.ltx_2.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
decode_with_tiling,
|
||||
split_in_spatial,
|
||||
@@ -75,7 +75,7 @@ class TestWan22TiledDecoding:
|
||||
|
||||
def _make_small_wan22_decoder(self):
|
||||
"""Create a small Wan2.2 decoder for testing."""
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
from mlx_video.models.wan2.vae22 import Wan22VAEDecoder
|
||||
|
||||
# Use very small dimensions for fast testing
|
||||
vae = Wan22VAEDecoder(z_dim=48, dim=16, dec_dim=16)
|
||||
@@ -139,7 +139,7 @@ class TestWan21TiledDecoding:
|
||||
|
||||
def _make_small_wan21_vae(self):
|
||||
"""Create a small Wan2.1 VAE for testing."""
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
from mlx_video.models.wan2.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
mx.eval(vae.parameters())
|
||||
@@ -192,7 +192,7 @@ class TestWan21TemporalScale:
|
||||
|
||||
def test_wan21_decoder_temporal_output(self):
|
||||
"""Wan2.1 Decoder3d should produce T*4 temporal output (non-causal doubling)."""
|
||||
from mlx_video.models.wan.vae import Decoder3d
|
||||
from mlx_video.models.wan2.vae import Decoder3d
|
||||
|
||||
# Small decoder for fast test
|
||||
dec = Decoder3d(
|
||||
|
||||
Reference in New Issue
Block a user