Refactor model path handling: moved get_model_path function to utils.py and updated generate.py to use the new import.

This commit is contained in:
Prince Canuma
2026-01-12 15:54:32 +01:00
parent 75511a0b17
commit 666e1f2e0c
2 changed files with 17 additions and 15 deletions

View File

@@ -6,6 +6,22 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from functools import partial
from pathlib import Path
from huggingface_hub import snapshot_download
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: