initial commit (LTX-2)
This commit is contained in:
127
mlx_video/utils.py
Normal file
127
mlx_video/utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Utility functions for MLX Video."""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from functools import partial
|
||||
|
||||
@mx.compile
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
    """Apply RMS normalization to ``x`` using an all-ones weight.

    The weight vector is built from ``x.shape[-1]``, i.e. the graph depends
    on the input shape. For that reason this function is compiled with the
    default (shape-aware) ``mx.compile`` rather than ``shapeless=True``: a
    shapeless graph would keep the weight created during the first trace and
    reuse it, with the wrong length, for inputs whose last axis differs.

    Args:
        x: Input tensor.
        eps: Epsilon forwarded to ``mx.fast.rms_norm``.

    Returns:
        The result of ``mx.fast.rms_norm`` with a unit weight.
    """
    # Unit weight sized to the last axis of the input.
    unit_weight = mx.ones((x.shape[-1],))
    return mx.fast.rms_norm(x, unit_weight, eps)
|
||||
|
||||
|
||||
|
||||
@partial(mx.compile, shapeless=True)
def to_denoised(
    noisy: mx.array,
    velocity: mx.array,
    sigma: mx.array | float
) -> mx.array:
    """Recover the denoised sample from a velocity prediction.

    Computes ``x_0 = x_t - sigma * v`` where ``x_t`` is the noisy input and
    ``v`` is the predicted velocity.

    Args:
        noisy: Noisy input tensor x_t
        velocity: Velocity prediction v
        sigma: Noise level, either a Python scalar or an array. An array
            with fewer dimensions than ``velocity`` gets trailing singleton
            axes appended so that it broadcasts against ``velocity``.

    Returns:
        Denoised tensor x_0
    """
    if not isinstance(sigma, (int, float)):
        # Array sigma: append trailing axes until its rank matches velocity.
        for _ in range(velocity.ndim - sigma.ndim):
            sigma = mx.expand_dims(sigma, axis=-1)
    return noisy - sigma * velocity
|
||||
|
||||
|
||||
def repeat_interleave(x: mx.array, repeats: int, axis: int = -1) -> mx.array:
    """Repeat each element of ``x`` along ``axis``, like torch.repeat_interleave.

    Implemented by inserting a new axis right after ``axis``, tiling along
    that new axis, and then folding it back into ``axis`` with a reshape, so
    the copies of each element end up adjacent to one another.

    Args:
        x: Input tensor
        repeats: Number of repetitions for each element
        axis: The axis along which to repeat values (may be negative)

    Returns:
        Tensor whose size along ``axis`` is multiplied by ``repeats``
    """
    # Convert a negative axis into its non-negative equivalent.
    if axis < 0:
        axis += x.ndim

    # Remember the input shape before any axes are added.
    source_shape = list(x.shape)

    # Insert a length-1 axis immediately after the target axis.
    slot = axis + 1
    expanded = mx.expand_dims(x, axis=slot)

    # Tile only along the freshly inserted axis.
    reps = [1] * expanded.ndim
    reps[slot] = repeats
    tiled = mx.tile(expanded, reps)

    # Fold the inserted axis back into the target axis.
    merged_shape = source_shape.copy()
    merged_shape[axis] *= repeats
    return mx.reshape(tiled, merged_shape)
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
    """Divide the input by its root-mean-square computed along axis 1.

    NOTE(review): axis 1 is presumably the channel axis of the tensors this
    module receives -- confirm against callers.
    """

    def __init__(self, eps: float = 1e-6):
        """Store ``eps``, the constant added to the mean square before the sqrt."""
        super().__init__()
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``x / sqrt(mean(x * x, axis=1, keepdims=True) + eps)``."""
        mean_square = mx.mean(x * x, axis=1, keepdims=True)
        rms = mx.sqrt(mean_square + self.eps)
        return x / rms
|
||||
|
||||
|
||||
def get_timestep_embedding(
    timesteps: mx.array,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1.0,
    scale: float = 1.0,
    max_period: int = 10000,
) -> mx.array:
    """Build sinusoidal embeddings for a batch of timesteps.

    Frequencies are ``exp(-log(max_period) * i / (half_dim - shift))`` for
    ``i`` in ``[0, half_dim)``, where ``half_dim = embedding_dim // 2``. Each
    scaled timestep is multiplied by every frequency, and the sine and
    cosine of the products are concatenated along the last axis.

    Args:
        timesteps: 1D tensor of timesteps
        embedding_dim: Dimension of the embeddings to create
        flip_sin_to_cos: If True, emit ``[cos, sin]`` instead of ``[sin, cos]``
        downscale_freq_shift: Value subtracted from ``half_dim`` in the
            frequency denominator
        scale: Multiplier applied to the timesteps before the sinusoids
        max_period: Maximum period for the sinusoids

    Returns:
        Tensor of shape (len(timesteps), embedding_dim); when
        ``embedding_dim`` is odd the last column is zero padding.
    """
    assert timesteps.ndim == 1, "Timesteps should be 1D"

    half_dim = embedding_dim // 2
    positions = mx.arange(0, half_dim, dtype=mx.float32)
    freqs = mx.exp(
        -math.log(max_period) * positions / (half_dim - downscale_freq_shift)
    )

    # Outer product of the scaled timesteps with the frequency table.
    scaled_steps = timesteps[:, None].astype(mx.float32) * scale
    angles = scaled_steps * freqs[None, :]

    # Order the two halves according to flip_sin_to_cos.
    sin_part, cos_part = mx.sin(angles), mx.cos(angles)
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    emb = mx.concatenate(halves, axis=-1)

    # An odd embedding_dim leaves one column short; pad it with zeros.
    if embedding_dim % 2 == 1:
        emb = mx.pad(emb, [(0, 0), (0, 1)])

    return emb
|
||||
Reference in New Issue
Block a user