Add Dev Two-Stage HQ pipeline mode
This commit is contained in:
181
mlx_video/samplers.py
Normal file
181
mlx_video/samplers.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Second-order res_2s sampler for diffusion models.
|
||||
|
||||
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
|
||||
noise injection, ported from the LTX-2 PyTorch implementation.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def phi(j: int, neg_h: float) -> float:
    """Compute phi_j(z) where z = -h (negative step size in log-space).

        phi_1(z) = (e^z - 1) / z
        phi_2(z) = (e^z - 1 - z) / z^2
        phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j

    The closed form subtracts nearly equal quantities when |z| is small,
    which destroys precision (e.g. phi_2(1e-8) evaluates to 0 or ~2.2
    instead of ~0.5). For j >= 1 and |z| < 0.5 the mathematically
    equivalent Taylor series

        phi_j(z) = sum_{k>=0} z^k / (k + j)!

    is summed instead; it converges in a few dozen terms in that range.

    Args:
        j: Order of the phi function.
        neg_h: The argument z.

    Returns:
        The value of phi_j(neg_h). For |neg_h| < 1e-10 the limit 1 / j!
        is returned directly.
    """
    if abs(neg_h) < 1e-10:
        return 1.0 / math.factorial(j)

    if j >= 1 and abs(neg_h) < 0.5:
        # Series branch: term_k = z^k / (k + j)!, built incrementally from
        # term_0 = 1 / j! so that no large factorials are formed.
        term = 1.0 / math.factorial(j)
        total = term
        for k in range(1, 64):
            term *= neg_h / (j + k)
            if total + term == total:
                # Further terms are below floating-point resolution.
                break
            total += term
        return total

    # Closed form: well conditioned once |z| is not small.
    remainder = sum(neg_h**k / math.factorial(k) for k in range(j))
    return (math.exp(neg_h) - remainder) / (neg_h**j)
|
||||
|
||||
|
||||
def get_res2s_coefficients(
    h: float,
    phi_cache: dict,
    c2: float = 0.5,
) -> tuple[float, float, float]:
    """Return the (a21, b1, b2) coefficients of the res_2s scheme.

    The coefficients are built from phi functions evaluated at -h * c2
    (substep) and -h (full step):

        a21 = c2 * phi_1(-h * c2)
        b2  = phi_2(-h) / c2
        b1  = phi_1(-h) - b2

    Args:
        h: Step size in log-space = log(sigma / sigma_next).
        phi_cache: Dictionary used to memoize phi results; it is keyed by
            ``(j, z)`` and is updated in place with any newly computed
            value.
        c2: Substep position (default 0.5 = midpoint).

    Returns:
        (a21, b1, b2): RK coefficients.
    """

    def cached_phi(order: int, z: float) -> float:
        # Memoized phi lookup; computes and stores the value on a miss.
        key = (order, z)
        if key not in phi_cache:
            phi_cache[key] = phi(order, z)
        return phi_cache[key]

    z_sub = -h * c2
    z_full = -h

    # Evaluation order (substep phi_1, full phi_2, full phi_1) fixes the
    # order in which entries are added to the cache.
    a21 = c2 * cached_phi(1, z_sub)
    b2 = cached_phi(2, z_full) / c2
    b1 = cached_phi(1, z_full) - b2

    return a21, b1, b2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SDE noise injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_sde_coeff(
    sigma_next: float,
) -> tuple[float, float, float]:
    """Derive the coefficients used by the SDE noise-injection step.

    The injected-noise level is fixed at ``sigma_next * 0.5`` (the source
    notes this is hardcoded in PyTorch Res2sDiffusionStep).

    Args:
        sigma_next: Noise level being stepped to.

    Returns:
        (alpha_ratio, sigma_down, sigma_up)
    """
    # Half of sigma_next, capped at 0.9999 * sigma_next so the radicand
    # below cannot go negative.
    up = min(sigma_next * 0.5, sigma_next * 0.9999)

    signal = 1.0 - sigma_next  # sigma_max=1
    residual = math.sqrt(max(sigma_next**2 - up**2, 0.0))
    alpha = signal + residual

    # Fall back to sigma_next when the ratio would divide by zero.
    down = sigma_next if alpha == 0 else residual / alpha

    # Replace any NaN result with a neutral value.
    if math.isnan(up):
        up = 0.0
    if math.isnan(down):
        down = sigma_next
    if math.isnan(alpha):
        alpha = 1.0

    return alpha, down, up
|
||||
|
||||
|
||||
def sde_noise_step(
    sample: mx.array,
    denoised_sample: mx.array,
    sigma: float,
    sigma_next: float,
    noise: mx.array,
) -> mx.array:
    """Move ``sample`` from ``sigma`` to ``sigma_next`` with noise injection.

    Args:
        sample: Current sample (anchor point).
        denoised_sample: Denoised prediction at this step.
        sigma: Current noise level.
        sigma_next: Next noise level.
        noise: Pre-generated noise tensor (channel-wise normalized).

    Returns:
        Noised sample at sigma_next, computed in float32. When either the
        injected-noise level or ``sigma_next`` is zero, ``denoised_sample``
        is returned as-is (no cast).
    """
    alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)

    # Nothing to inject: hand back the denoised prediction untouched.
    if sigma_up == 0 or sigma_next == 0:
        return denoised_sample

    # All arithmetic below is carried out in float32.
    x, x0, z = (
        t.astype(mx.float32) for t in (sample, denoised_sample, noise)
    )

    # Epsilon estimate over the sigma -> sigma_next interval, and the
    # denoised point it implies.
    eps = (x - x0) / (sigma - sigma_next)
    x0_next = x - sigma * eps

    # Deterministic part rescaled by alpha_ratio, plus the fresh noise.
    deterministic = alpha_ratio * (x0_next + sigma_down * eps)
    return deterministic + sigma_up * z
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Noise generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def channelwise_normalize(x: mx.array) -> mx.array:
    """Standardize ``x`` over its last two axes.

    The mean over axes (-2, -1) is subtracted, then the result is divided
    by the square root of its mean square over the same axes (with 1e-8
    added under the root for numerical safety). The source describes the
    last two axes as spatial (H, W) or (time, freq).
    """
    reduce_axes = (-2, -1)
    centered = x - mx.mean(x, axis=reduce_axes, keepdims=True)
    mean_square = mx.mean(centered * centered, axis=reduce_axes, keepdims=True)
    return centered / mx.sqrt(mean_square + 1e-8)
|
||||
|
||||
|
||||
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
    """Draw Gaussian noise and normalize it globally, then per channel.

    Per the original note, the PyTorch version works in float64 while this
    one uses float32 (MLX doesn't support float64), and the channel-wise
    normalization is the key quality-affecting step.

    Args:
        shape: Shape of the noise tensor.
        key: MLX random key for deterministic generation.

    Returns:
        Channel-wise normalized noise in float32.
    """
    raw = mx.random.normal(shape, dtype=mx.float32, key=key)

    # Global pass: subtract the overall mean and divide by the root mean
    # square of the raw draw (plus 1e-8).
    rms = mx.sqrt(mx.mean(raw * raw))
    globally_scaled = (raw - mx.mean(raw)) / (rms + 1e-8)

    # Per-channel pass over the last two axes.
    return channelwise_normalize(globally_scaled)
|
||||
Reference in New Issue
Block a user