This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -5,15 +5,14 @@ noise injection, ported from the LTX-2 PyTorch implementation.
"""
import math
from typing import Optional
import mlx.core as mx
# ---------------------------------------------------------------------------
# Phi functions and RK coefficients (pure Python math, no MLX needed)
# ---------------------------------------------------------------------------
def phi(j: int, neg_h: float) -> float:
"""Compute phi_j(z) where z = -h (negative step size in log-space).
@@ -43,6 +42,7 @@ def get_res2s_coefficients(
Returns:
(a21, b1, b2): RK coefficients.
"""
def get_phi(j: int, neg_h: float) -> float:
cache_key = (j, neg_h)
if cache_key in phi_cache:
@@ -69,6 +69,7 @@ def get_res2s_coefficients(
# SDE noise injection
# ---------------------------------------------------------------------------
def get_sde_coeff(
sigma_next: float,
) -> tuple[float, float, float]:
@@ -139,7 +140,9 @@ def sde_noise_step(
denoised_next = sample_f32 - sigma * eps_next
# Mix deterministic and stochastic components
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
x_noised = (
alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
)
return x_noised
@@ -148,6 +151,7 @@ def sde_noise_step(
# Noise generation
# ---------------------------------------------------------------------------
def channelwise_normalize(x: mx.array) -> mx.array:
"""Normalize each channel to zero mean and unit variance over spatial dims.