Refactor model path handling: moved get_model_path function to utils.py and updated generate.py to use the new import.

This commit is contained in:
Prince Canuma
2026-01-12 15:54:32 +01:00
parent 75511a0b17
commit 666e1f2e0c
2 changed files with 17 additions and 15 deletions

View File

@@ -14,7 +14,7 @@ from mlx_video.utils import to_denoised
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from huggingface_hub import snapshot_download from mlx_video.utils import get_model_path
# Distilled sigma schedules # Distilled sigma schedules
@@ -95,20 +95,6 @@ def create_position_grid(
return mx.array(pixel_coords, dtype=mx.float32) return mx.array(pixel_coords, dtype=mx.float32)
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
def denoise( def denoise(
latents: mx.array, latents: mx.array,
positions: mx.array, positions: mx.array,

View File

@@ -6,6 +6,22 @@ from typing import Optional
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from functools import partial from functools import partial
from pathlib import Path
from huggingface_hub import snapshot_download
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
@partial(mx.compile, shapeless=True) @partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: