format
This commit is contained in:
@@ -1,14 +1,15 @@
|
||||
import math
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_model_path(model_repo: str):
|
||||
"""Get or download LTX-2 model path."""
|
||||
try:
|
||||
@@ -17,15 +18,19 @@ def get_model_path(model_repo: str):
|
||||
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
|
||||
except Exception:
|
||||
print("Downloading LTX-2 model weights...")
|
||||
return Path(snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=model_repo,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||
if quantization is not None:
|
||||
|
||||
def get_class_predicate(p, m):
|
||||
# Handle custom per layer quantizations
|
||||
if p in quantization:
|
||||
@@ -46,17 +51,15 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
||||
class_predicate=get_class_predicate,
|
||||
)
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
|
||||
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def to_denoised(
|
||||
noisy: mx.array,
|
||||
velocity: mx.array,
|
||||
sigma: mx.array | float
|
||||
noisy: mx.array, velocity: mx.array, sigma: mx.array | float
|
||||
) -> mx.array:
|
||||
"""Convert velocity prediction to denoised output.
|
||||
|
||||
@@ -284,7 +287,9 @@ def prepare_image_for_encoding(
|
||||
if image_np.max() <= 1.0:
|
||||
image_np = (image_np * 255).astype(np.uint8)
|
||||
pil_image = Image.fromarray(image_np)
|
||||
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)
|
||||
pil_image = pil_image.resize(
|
||||
(target_width, target_height), Image.Resampling.LANCZOS
|
||||
)
|
||||
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
|
||||
|
||||
# Normalize to [-1, 1]
|
||||
|
||||
Reference in New Issue
Block a user