This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,14 +1,15 @@
import math
from functools import partial
from pathlib import Path
from typing import Optional, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from functools import partial
from pathlib import Path
from huggingface_hub import snapshot_download
from PIL import Image
def get_model_path(model_repo: str):
"""Get or download LTX-2 model path."""
try:
@@ -17,15 +18,19 @@ def get_model_path(model_repo: str):
return Path(snapshot_download(repo_id=model_repo, local_files_only=True))
except Exception:
print("Downloading LTX-2 model weights...")
return Path(snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
))
return Path(
snapshot_download(
repo_id=model_repo,
local_files_only=False,
resume_download=True,
allow_patterns=["*.safetensors", "*.json"],
)
)
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
if quantization is not None:
def get_class_predicate(p, m):
# Handle custom per layer quantizations
if p in quantization:
@@ -46,17 +51,15 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
class_predicate=get_class_predicate,
)
@partial(mx.compile, shapeless=True)
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
    """Apply RMS normalization to ``x`` using an all-ones weight.

    Args:
        x: Input array; the weight is sized from its last dimension.
        eps: Epsilon forwarded to ``mx.fast.rms_norm``.

    Returns:
        The result of ``mx.fast.rms_norm`` with a unit weight in ``x``'s dtype.
    """
    # NOTE(review): the weight length is read from x.shape inside a
    # shapeless-compiled function — confirm this behaves as intended when
    # the last dimension differs between calls.
    feature_dim = x.shape[-1]
    unit_weight = mx.ones((feature_dim,), dtype=x.dtype)
    return mx.fast.rms_norm(x, unit_weight, eps)
@partial(mx.compile, shapeless=True)
def to_denoised(
noisy: mx.array,
velocity: mx.array,
sigma: mx.array | float
noisy: mx.array, velocity: mx.array, sigma: mx.array | float
) -> mx.array:
"""Convert velocity prediction to denoised output.
@@ -284,7 +287,9 @@ def prepare_image_for_encoding(
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
pil_image = Image.fromarray(image_np)
pil_image = pil_image.resize((target_width, target_height), Image.Resampling.LANCZOS)
pil_image = pil_image.resize(
(target_width, target_height), Image.Resampling.LANCZOS
)
image = mx.array(np.array(pil_image).astype(np.float32) / 255.0)
# Normalize to [-1, 1]