format
This commit is contained in:
@@ -6,7 +6,12 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
def load_wan_model(model_path: Path, config, quantization: dict | None = None, loras: list | None = None):
|
||||
def load_wan_model(
|
||||
model_path: Path,
|
||||
config,
|
||||
quantization: dict | None = None,
|
||||
loras: list | None = None,
|
||||
):
|
||||
"""Load and initialize WanModel, with optional quantization and LoRA support.
|
||||
|
||||
Args:
|
||||
@@ -93,9 +98,11 @@ def load_vae_decoder(model_path: Path, config=None):
|
||||
|
||||
if is_wan22:
|
||||
from mlx_video.models.wan.vae22 import Wan22VAEDecoder
|
||||
|
||||
vae = Wan22VAEDecoder(z_dim=48)
|
||||
else:
|
||||
from mlx_video.models.wan.vae import WanVAE
|
||||
|
||||
vae = WanVAE(z_dim=16)
|
||||
|
||||
weights = mx.load(str(model_path))
|
||||
@@ -140,6 +147,7 @@ def _clean_text(text: str) -> str:
|
||||
|
||||
try:
|
||||
import ftfy
|
||||
|
||||
text = ftfy.fix_text(text)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user