Add custom text encoder with quantization
Co-authored-by: HimanshU Mourya <40685364+codingstark-dev@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,4 @@
|
||||
"""Utility functions for MLX Video."""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -22,6 +19,28 @@ def get_model_path(model_repo: str):
|
||||
allow_patterns=["*.safetensors", "*.json"],
|
||||
))
|
||||
|
||||
def apply_quantization(model: nn.Module, weights: dict, quantization: dict) -> None:
    """Quantize the eligible layers of ``model`` via ``nn.quantize``.

    Nothing happens when ``quantization`` is ``None``. Otherwise
    ``nn.quantize`` is called on ``model`` with a per-layer predicate that
    decides, layer by layer, whether (and how) that layer is quantized.

    Args:
        model: The module handed to ``nn.quantize``.
        weights: Container of weight names. It is only used for membership
            tests of the form ``"<layer path>.scales" in weights``
            (presumably the name -> tensor dict loaded from the checkpoint;
            confirm against the caller).
        quantization: ``None`` to skip quantization, or a dict providing
            ``"group_size"`` and ``"bits"`` and optionally ``"mode"``
            (defaults to ``"affine"``). Any key equal to a layer path is
            treated as a per-layer override: its value is returned from the
            predicate as-is.

    Raises:
        KeyError: If ``quantization`` lacks ``"group_size"`` or ``"bits"``.
    """
    if quantization is None:
        return

    def get_class_predicate(path, module):
        # Handle custom per-layer quantizations: an explicit entry for this
        # layer path takes precedence over every other rule.
        if path in quantization:
            return quantization[path]

        # Modules that cannot be converted are never quantized.
        if not hasattr(module, "to_quantized"):
            return False

        # Skip layers whose weight's first dimension is not a multiple of 64.
        if hasattr(module, "weight") and module.weight.shape[0] % 64 != 0:
            return False

        # Handle legacy models which may not have everything quantized:
        # only quantize a layer when `weights` has its "<path>.scales" entry.
        return f"{path}.scales" in weights

    nn.quantize(
        model,
        group_size=quantization["group_size"],
        bits=quantization["bits"],
        mode=quantization.get("mode", "affine"),
        class_predicate=get_class_predicate,
    )
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||
|
||||
Reference in New Issue
Block a user