Add custom text encoder with quantization

Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com>
This commit is contained in:
Prince Canuma
2026-01-13 22:56:51 +01:00
parent 01d895bc77
commit fc6ef20c1b
3 changed files with 87 additions and 85 deletions

View File

@@ -1,7 +1,4 @@
"""Utility functions for MLX Video."""
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
@@ -22,6 +19,28 @@ def get_model_path(model_repo: str):
allow_patterns=["*.safetensors", "*.json"],
))
def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
if quantization is not None:
def get_class_predicate(p, m):
# Handle custom per layer quantizations
if p in quantization:
return quantization[p]
if not hasattr(m, "to_quantized"):
return False
# Skip layers not divisible by 64
if hasattr(m, "weight") and m.weight.shape[0] % 64 != 0:
return False
# Handle legacy models which may not have everything quantized
return f"{p}.scales" in weights
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
mode=quantization.get("mode", "affine"),
class_predicate=get_class_predicate,
)
@partial(mx.compile, shapeless=True)
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: