format
This commit is contained in:
@@ -5,15 +5,14 @@ noise injection, ported from the LTX-2 PyTorch implementation.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phi functions and RK coefficients (pure Python math, no MLX needed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def phi(j: int, neg_h: float) -> float:
|
||||
"""Compute phi_j(z) where z = -h (negative step size in log-space).
|
||||
|
||||
@@ -43,6 +42,7 @@ def get_res2s_coefficients(
|
||||
Returns:
|
||||
(a21, b1, b2): RK coefficients.
|
||||
"""
|
||||
|
||||
def get_phi(j: int, neg_h: float) -> float:
|
||||
cache_key = (j, neg_h)
|
||||
if cache_key in phi_cache:
|
||||
@@ -69,6 +69,7 @@ def get_res2s_coefficients(
|
||||
# SDE noise injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_sde_coeff(
|
||||
sigma_next: float,
|
||||
) -> tuple[float, float, float]:
|
||||
@@ -139,7 +140,9 @@ def sde_noise_step(
|
||||
denoised_next = sample_f32 - sigma * eps_next
|
||||
|
||||
# Mix deterministic and stochastic components
|
||||
x_noised = alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||
x_noised = (
|
||||
alpha_ratio * (denoised_next + sigma_down * eps_next) + sigma_up * noise_f32
|
||||
)
|
||||
|
||||
return x_noised
|
||||
|
||||
@@ -148,6 +151,7 @@ def sde_noise_step(
|
||||
# Noise generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def channelwise_normalize(x: mx.array) -> mx.array:
|
||||
"""Normalize each channel to zero mean and unit variance over spatial dims.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user