Refactor model path handling: moved get_model_path function to utils.py and updated generate.py to use the new import.
This commit is contained in:
@@ -6,6 +6,22 @@ from typing import Optional
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user