Refactor Wan model imports and update script paths in pyproject.toml; transition from wan to wan2 module structure for improved organization and clarity.
This commit is contained in:
@@ -247,7 +247,7 @@ def _load_lora_configs(
|
||||
|
||||
Shared between weight-merging and runtime-wrapping paths.
|
||||
"""
|
||||
from mlx_video.generate_wan import Colors
|
||||
from mlx_video.models.wan2.generate import Colors
|
||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||
|
||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||
@@ -282,7 +282,7 @@ def load_and_apply_loras(
|
||||
|
||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||
"""
|
||||
from mlx_video.generate_wan import Colors
|
||||
from mlx_video.models.wan2.generate import Colors
|
||||
from mlx_video.lora import apply_loras_to_weights
|
||||
|
||||
if not lora_configs:
|
||||
@@ -411,7 +411,7 @@ def convert_wan_checkpoint(
|
||||
print(" Warning: No transformer weights found!")
|
||||
|
||||
# Save config — detect model size from source config.json or transformer weights
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
def _detect_config():
|
||||
"""Detect config from source config.json or transformer weight shapes."""
|
||||
@@ -522,7 +522,7 @@ def convert_wan_checkpoint(
|
||||
print(f"Converting VAE ({'Wan2.2' if is_wan22_vae else 'Wan2.1'})...")
|
||||
weights = load_torch_weights(str(vae_path))
|
||||
if is_wan22_vae:
|
||||
from mlx_video.models.wan.vae22 import sanitize_wan22_vae_weights
|
||||
from mlx_video.models.wan2.vae22 import sanitize_wan22_vae_weights
|
||||
|
||||
include_encoder = config.model_type in ("ti2v", "i2v")
|
||||
weights = sanitize_wan22_vae_weights(
|
||||
@@ -594,7 +594,7 @@ def _quantize_saved_model(
|
||||
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.wan.model import WanModel
|
||||
from mlx_video.models.wan2.model import WanModel
|
||||
|
||||
if source_dir is None:
|
||||
source_dir = output_dir
|
||||
@@ -704,7 +704,7 @@ def quantize_mlx_model(
|
||||
).exists()
|
||||
|
||||
# Build model config
|
||||
from mlx_video.models.wan.config import WanModelConfig
|
||||
from mlx_video.models.wan2.config import WanModelConfig
|
||||
|
||||
config_dict = {
|
||||
k: v for k, v in cfg.items() if k in WanModelConfig.__dataclass_fields__
|
||||
|
||||
Reference in New Issue
Block a user