Refactor model path handling: moved get_model_path function to utils.py and updated generate.py to use the new import.
This commit is contained in:
@@ -14,7 +14,7 @@ from mlx_video.utils import to_denoised
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx_video.utils import get_model_path
|
||||
|
||||
|
||||
# Distilled sigma schedules
|
||||
@@ -95,20 +95,6 @@ def create_position_grid(
|
||||
return mx.array(pixel_coords, dtype=mx.float32)
|
||||
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
|
||||
|
||||
def denoise(
|
||||
latents: mx.array,
|
||||
positions: mx.array,
|
||||
|
||||
@@ -6,6 +6,22 @@ from typing import Optional
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user