Files
mlx-video/mlx_video/utils.py
2026-01-11 23:48:33 +01:00

128 lines
3.4 KiB
Python

"""Utility functions for MLX Video."""
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from functools import partial
@mx.compile
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
    """Apply RMS normalization over the last axis of ``x`` with no learned scale.

    Delegates to ``mx.fast.rms_norm`` with an all-ones weight, i.e. it computes
    ``x / sqrt(mean(x**2, axis=-1) + eps)``.

    Args:
        x: Input array; normalized along its last axis.
        eps: Small constant added to the mean square for numerical stability.

    Returns:
        Array with the same shape as ``x``.
    """
    # The weight's length depends on x.shape[-1]. Under shapeless=True the
    # compiled graph would keep the length seen on the first call, so a later
    # call with a different last dimension would get a mismatched weight.
    # Plain mx.compile re-traces when input shapes change.
    # Using x.dtype keeps a float32 weight (the mx.ones default) from
    # promoting low-precision inputs.
    weight = mx.ones((x.shape[-1],), dtype=x.dtype)
    return mx.fast.rms_norm(x, weight, eps)
@partial(mx.compile, shapeless=True)
def _to_denoised_kernel(
    noisy: mx.array,
    velocity: mx.array,
    sigma: mx.array,
) -> mx.array:
    """Compiled core of ``to_denoised``; ``sigma`` is always an mx.array here."""
    # Append trailing singleton axes so a per-sample sigma (e.g. shape (B,))
    # broadcasts against velocity's trailing dimensions. This depends only on
    # rank, which is safe under shapeless compilation.
    while sigma.ndim < velocity.ndim:
        sigma = mx.expand_dims(sigma, axis=-1)
    return noisy - sigma * velocity


def to_denoised(
    noisy: mx.array,
    velocity: mx.array,
    sigma: mx.array | float
) -> mx.array:
    """Convert a velocity prediction to the denoised output.

    Given noisy input x_t and velocity prediction v, computes
        x_0 = x_t - sigma * v

    Args:
        noisy: Noisy input tensor x_t.
        velocity: Velocity prediction v.
        sigma: Noise level, either a Python scalar or an array (scalar or
            per-sample). An array with fewer dims than ``velocity`` gets
            trailing singleton axes appended before broadcasting.

    Returns:
        Denoised tensor x_0.
    """
    if isinstance(sigma, (int, float)):
        # Python scalars passed into an mx.compile'd function are baked in as
        # constants, so every new sigma value would force a recompile. Wrap the
        # scalar as a 0-d array in velocity's dtype instead. This matches the
        # original result dtype, because a Python scalar did not promote
        # velocity either.
        sigma = mx.array(sigma, dtype=velocity.dtype)
    return _to_denoised_kernel(noisy, velocity, sigma)
def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array:
    """Repeat each element ``repeats`` times along ``axis`` (like torch.repeat_interleave).

    For example, ``[a, b]`` with ``repeats=2`` becomes ``[a, a, b, b]``.

    Args:
        x: Input tensor.
        repeats: How many consecutive copies of each element to emit.
        axis: Axis along which elements are repeated; negative values count
            from the end.

    Returns:
        Tensor equal to ``x`` except that the size of ``axis`` is multiplied
        by ``repeats``.
    """
    if axis < 0:
        axis += x.ndim
    source_shape = list(x.shape)

    # Insert a unit axis right after the target axis and tile along it, so
    # each element's copies end up adjacent in memory order.
    expanded = mx.expand_dims(x, axis=axis + 1)
    reps = [1 for _ in range(expanded.ndim)]
    reps[axis + 1] = repeats
    tiled = mx.tile(expanded, reps)

    # Fold the copy axis back into the target axis.
    merged_shape = list(source_shape)
    merged_shape[axis] = merged_shape[axis] * repeats
    return mx.reshape(tiled, merged_shape)
class PixelNorm(nn.Module):
    """Scale inputs by the inverse root-mean-square taken over axis 1.

    Computes ``x / sqrt(mean(x**2, axis=1, keepdims=True) + eps)``. The module
    has no learnable parameters. Axis 1 is presumably the channel axis of a
    channels-first layout; TODO confirm against callers.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        # Added to the mean square before the square root to avoid dividing by zero.
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        mean_square = mx.mean(mx.square(x), axis=1, keepdims=True)
        denominator = mx.sqrt(mean_square + self.eps)
        return x / denominator
def get_timestep_embedding(
    timesteps: mx.array,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1.0,
    scale: float = 1.0,
    max_period: int = 10000,
) -> mx.array:
    """Create sinusoidal timestep embeddings.

    Builds ``embedding_dim // 2`` geometrically spaced frequencies
    ``exp(-log(max_period) * i / (half_dim - downscale_freq_shift))``. It
    multiplies each (scaled) timestep by them and concatenates the sine and
    cosine of the result.

    Args:
        timesteps: 1D tensor of timesteps (cast to float32).
        embedding_dim: Size of the output embedding for each timestep.
        flip_sin_to_cos: If True, output ``[cos, sin]`` instead of ``[sin, cos]``.
        downscale_freq_shift: Subtracted from ``half_dim`` in the frequency
            exponent's denominator. With 1.0, the lowest frequency is exactly
            ``1 / max_period``.
        scale: Multiplier applied to the timesteps before the frequencies.
        max_period: Controls the lowest frequency of the sinusoids.

    Returns:
        Float32 tensor of shape ``(len(timesteps), embedding_dim)``. For an odd
        ``embedding_dim``, the last column is zero padding.

    Raises:
        ValueError: If ``timesteps`` is not 1D.
    """
    # Explicit check rather than `assert`, so validation survives `python -O`.
    if timesteps.ndim != 1:
        raise ValueError(f"Timesteps should be 1D, got shape {timesteps.shape}")

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * mx.arange(0, half_dim, dtype=mx.float32)
    exponent = exponent / (half_dim - downscale_freq_shift)
    freqs = mx.exp(exponent)

    # (N, 1) * (1, half_dim) -> (N, half_dim) phase arguments
    args = (timesteps[:, None].astype(mx.float32) * scale) * freqs[None, :]
    sin_part, cos_part = mx.sin(args), mx.cos(args)
    parts = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = mx.concatenate(parts, axis=-1)

    # The two halves only cover 2 * (embedding_dim // 2) columns; pad the
    # missing one with zeros for odd sizes.
    if embedding_dim % 2 == 1:
        emb = mx.pad(emb, [(0, 0), (0, 1)])
    return emb