Refactor model path handling: moved get_model_path function to utils.py and updated generate.py to use the new import.

This commit is contained in:
Prince Canuma
2026-01-12 15:54:32 +01:00
parent 75511a0b17
commit 666e1f2e0c
2 changed files with 17 additions and 15 deletions

View File

@@ -14,7 +14,7 @@ from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from huggingface_hub import snapshot_download
from mlx_video.utils import get_model_path
# Distilled sigma schedules
@@ -95,20 +95,6 @@ def create_position_grid(
return mx.array(pixel_coords, dtype=mx.float32)
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
def denoise(
latents: mx.array,
positions: mx.array,

View File

@@ -6,6 +6,22 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from functools import partial
from pathlib import Path
from huggingface_hub import snapshot_download
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: