Refactor LTX-2 model structure

This commit is contained in:
Prince Canuma
2026-03-16 14:50:01 +01:00
parent decb3eb9e5
commit 3a0da19adb
50 changed files with 3882 additions and 3365 deletions

View File

@@ -0,0 +1,181 @@
"""Second-order res_2s sampler for diffusion models.
Implements the exponential Rosenbrock-type Runge-Kutta integrator with SDE
noise injection, ported from the LTX-2 PyTorch implementation.
"""
import math
from typing import Optional
import mlx.core as mx
# ---------------------------------------------------------------------------
# Phi functions and RK coefficients (pure Python math, no MLX needed)
# ---------------------------------------------------------------------------
def phi(j: int, neg_h: float) -> float:
    """Compute phi_j(z) where z = -h (negative step size in log-space).

    phi_1(z) = (e^z - 1) / z
    phi_2(z) = (e^z - 1 - z) / z^2
    phi_j(z) = (e^z - sum_{k=0}^{j-1} z^k/k!) / z^j

    The closed form subtracts nearly equal quantities when |z| is small,
    which destroys precision in floating point (e.g. phi_2(1e-8) comes out
    nowhere near 0.5). For |z| < 0.5 the equivalent Taylor series
    phi_j(z) = sum_{k>=0} z^k / (k + j)! is evaluated instead; for larger
    |z| the closed form is accurate and is used unchanged.

    Args:
        j: Order of the phi function (non-negative integer).
        neg_h: The argument z, i.e. the negated log-space step size.

    Returns:
        The value of phi_j(neg_h). At neg_h == 0 this is exactly 1 / j!.

    Raises:
        ValueError: If ``j`` is negative.
    """
    if j < 0:
        raise ValueError(f"phi order must be non-negative, got {j}")

    z = neg_h
    if abs(z) < 0.5:
        # Series branch: term_k = z^k / (j + k)!, built incrementally so no
        # large factorials or powers are formed explicitly.
        term = 1.0 / math.factorial(j)
        total = term
        for k in range(1, 40):
            term *= z / (j + k)
            total += term
            # Stop once the next contribution is below double precision.
            if abs(term) <= 1e-17 * abs(total):
                break
        return total

    # Closed form: safe here because |z| is far from the cancellation regime.
    remainder = sum(z**k / math.factorial(k) for k in range(j))
    return (math.exp(z) - remainder) / (z**j)
def get_res2s_coefficients(
    h: float,
    phi_cache: dict,
    c2: float = 0.5,
) -> tuple[float, float, float]:
    """Return the res_2s Runge-Kutta coefficients for one step.

    Args:
        h: Step size in log-space = log(sigma / sigma_next).
        phi_cache: Mutable mapping keyed by ``(j, z)`` used to memoize phi
            evaluations; entries missing from it are computed and stored.
        c2: Substep position (default 0.5 = midpoint). Used as a divisor,
            so it must be non-zero.

    Returns:
        (a21, b1, b2): RK coefficients, where
            a21 = c2 * phi_1(-h * c2)
            b2  = phi_2(-h) / c2
            b1  = phi_1(-h) - b2
    """

    def cached_phi(order: int, z: float) -> float:
        # Memoize on the exact (order, argument) pair.
        key = (order, z)
        if key not in phi_cache:
            phi_cache[key] = phi(order, z)
        return phi_cache[key]

    # Intermediate-stage weight, evaluated at the substep position.
    a21 = c2 * cached_phi(1, -h * c2)

    # Final-stage weights, evaluated over the full step.
    b2 = cached_phi(2, -h) / c2
    b1 = cached_phi(1, -h) - b2

    return a21, b1, b2
# ---------------------------------------------------------------------------
# SDE noise injection
# ---------------------------------------------------------------------------
def get_sde_coeff(
    sigma_next: float,
) -> tuple[float, float, float]:
    """Derive the coefficients used to re-inject noise at ``sigma_next``.

    The injected-noise level is fixed at half of ``sigma_next`` (the
    original docstring attributes this constant to the PyTorch
    Res2sDiffusionStep) and the maximum sigma is taken to be 1.

    Args:
        sigma_next: Noise level the sampler is stepping to.

    Returns:
        (alpha_ratio, sigma_down, sigma_up)
    """
    # Half of sigma_next, capped just below sigma_next so the radicand
    # computed further down cannot go negative.
    sigma_up = min(sigma_next * 0.5, sigma_next * 0.9999)

    signal_level = 1.0 - sigma_next  # sigma_max = 1
    residual = math.sqrt(max(sigma_next**2 - sigma_up**2, 0.0))
    alpha_ratio = signal_level + residual

    # Avoid dividing by a zero alpha_ratio.
    sigma_down = sigma_next if alpha_ratio == 0 else residual / alpha_ratio

    # Replace any NaN result with a neutral fallback.
    sigma_up = 0.0 if math.isnan(sigma_up) else sigma_up
    sigma_down = sigma_next if math.isnan(sigma_down) else sigma_down
    alpha_ratio = 1.0 if math.isnan(alpha_ratio) else alpha_ratio

    return alpha_ratio, sigma_down, sigma_up
def sde_noise_step(
    sample: mx.array,
    denoised_sample: mx.array,
    sigma: float,
    sigma_next: float,
    noise: mx.array,
) -> mx.array:
    """Move ``sample`` from ``sigma`` to ``sigma_next`` with noise injection.

    Args:
        sample: Current sample (anchor point).
        denoised_sample: Denoised prediction at this step.
        sigma: Current noise level.
        sigma_next: Next noise level.
        noise: Pre-generated noise tensor (channel-wise normalized).

    Returns:
        The noised sample at ``sigma_next``, computed from float32 casts of
        the inputs. When ``sigma_next`` (or the derived ``sigma_up``) is
        zero, ``denoised_sample`` is returned as-is, without any cast.

    Note:
        ``sigma - sigma_next`` is used as a divisor, so the two levels are
        expected to differ.
    """
    alpha_ratio, sigma_down, sigma_up = get_sde_coeff(sigma_next)

    # Final step (or no noise to add): the denoised prediction is the answer.
    if sigma_up == 0 or sigma_next == 0:
        return denoised_sample

    # Work in float32 regardless of the input dtypes.
    x = sample.astype(mx.float32)
    x0 = denoised_sample.astype(mx.float32)

    # Epsilon estimate and the denoised point it implies.
    eps = (x - x0) / (sigma - sigma_next)
    x0_next = x - sigma * eps

    # Deterministic part scaled by alpha_ratio, plus the fresh noise term.
    deterministic = alpha_ratio * (x0_next + sigma_down * eps)
    return deterministic + sigma_up * noise.astype(mx.float32)
# ---------------------------------------------------------------------------
# Noise generation
# ---------------------------------------------------------------------------
def channelwise_normalize(x: mx.array) -> mx.array:
    """Center and rescale ``x`` over its last two axes.

    For every slice spanned by the trailing two dimensions, the mean is
    subtracted and the result is divided by ``sqrt(mean(centered^2) + 1e-8)``,
    giving approximately zero mean and unit variance per slice.

    Args:
        x: Input array; statistics are taken over axes (-2, -1).

    Returns:
        The normalized array, same shape as ``x``.
    """
    trailing_axes = (-2, -1)
    centered = x - mx.mean(x, axis=trailing_axes, keepdims=True)
    second_moment = mx.mean(centered * centered, axis=trailing_axes, keepdims=True)
    # The epsilon sits inside the square root to keep the divisor non-zero.
    return centered / mx.sqrt(second_moment + 1e-8)
def get_new_noise(shape: tuple, key: mx.array) -> mx.array:
    """Draw Gaussian noise and normalize it globally, then channel-wise.

    The original docstring notes that the PyTorch reference uses float64
    while this port uses float32, and that the channel-wise normalization
    is the quality-relevant step.

    Args:
        shape: Shape of the noise tensor.
        key: MLX random key, passed to ``mx.random.normal`` so the draw is
            reproducible.

    Returns:
        Channel-wise normalized noise in float32.
    """
    raw = mx.random.normal(shape, dtype=mx.float32, key=key)

    # Global pass: subtract the overall mean and divide by the RMS of the
    # raw (un-centered) draw, with a small epsilon added to the divisor.
    overall_mean = mx.mean(raw)
    overall_rms = mx.sqrt(mx.mean(raw * raw))
    globally_scaled = (raw - overall_mean) / (overall_rms + 1e-8)

    # Channel-wise pass over the last two axes.
    return channelwise_normalize(globally_scaled)