From 666e1f2e0c6a4eaccf310ef54168627a2add8f32 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 12 Jan 2026 15:54:32 +0100 Subject: [PATCH] Refactor model path handling: moved get_model_path function to utils.py and updated generate.py to use the new import. --- mlx_video/generate.py | 16 +--------------- mlx_video/utils.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index b6e4daa..83abd5f 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -14,7 +14,7 @@ from mlx_video.utils import to_denoised from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents -from huggingface_hub import snapshot_download +from mlx_video.utils import get_model_path # Distilled sigma schedules @@ -95,20 +95,6 @@ def create_position_grid( return mx.array(pixel_coords, dtype=mx.float32) -def get_model_path(model_repo: str): - """Get or download LTX-2 model path.""" - try: - return Path(snapshot_download(repo_id=model_repo, local_files_only=True)) - except Exception: - print("Downloading LTX-2 model weights...") - return Path(snapshot_download( - repo_id=model_repo, - local_files_only=False, - resume_download=True, - allow_patterns=["*.safetensors", "*.json"], - )) - - def denoise( latents: mx.array, positions: mx.array, diff --git a/mlx_video/utils.py b/mlx_video/utils.py index f44e2a7..1ef162f 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -6,6 +6,22 @@ from typing import Optional import mlx.core as mx import mlx.nn as nn from functools import partial +from pathlib import Path +from huggingface_hub import snapshot_download + +def get_model_path(model_repo: str): + """Get or download LTX-2 model path.""" + try: + return Path(snapshot_download(repo_id=model_repo, local_files_only=True)) + except Exception: + print("Downloading LTX-2 model weights...") + return Path(snapshot_download( + repo_id=model_repo, + local_files_only=False, + resume_download=True, + allow_patterns=["*.safetensors", "*.json"], + )) + @partial(mx.compile, shapeless=True) def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: