Refactor model path handling: moved get_model_path function to utils.py and updated generate.py to use the new import.
This commit is contained in:
@@ -14,7 +14,7 @@ from mlx_video.utils import to_denoised
|
|||||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from mlx_video.utils import get_model_path
|
||||||
|
|
||||||
|
|
||||||
# Distilled sigma schedules
|
# Distilled sigma schedules
|
||||||
@@ -95,20 +95,6 @@ def create_position_grid(
|
|||||||
return mx.array(pixel_coords, dtype=mx.float32)
|
return mx.array(pixel_coords, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model_repo: str):
|
|
||||||
"""Get or download LTX-2 model path."""
|
|
||||||
try:
|
|
||||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
|
||||||
except Exception:
|
|
||||||
print("Downloading LTX-2 model weights...")
|
|
||||||
return Path(snapshot_download(
|
|
||||||
repo_id=model_repo,
|
|
||||||
local_files_only=False,
|
|
||||||
resume_download=True,
|
|
||||||
allow_patterns=["*.safetensors", "*.json"],
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
def denoise(
|
def denoise(
|
||||||
latents: mx.array,
|
latents: mx.array,
|
||||||
positions: mx.array,
|
positions: mx.array,
|
||||||
|
|||||||
@@ -6,6 +6,22 @@ from typing import Optional
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
def get_model_path(model_repo: str):
|
||||||
|
"""Get or download LTX-2 model path."""
|
||||||
|
try:
|
||||||
|
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||||
|
except Exception:
|
||||||
|
print("Downloading LTX-2 model weights...")
|
||||||
|
return Path(snapshot_download(
|
||||||
|
repo_id=model_repo,
|
||||||
|
local_files_only=False,
|
||||||
|
resume_download=True,
|
||||||
|
allow_patterns=["*.safetensors", "*.json"],
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||||
|
|||||||
Reference in New Issue
Block a user