format
This commit is contained in:
@@ -7,9 +7,8 @@ for the same quality as Euler.
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _compute_sigmas(
|
||||
@@ -25,9 +24,7 @@ def _compute_sigmas(
|
||||
Returns num_steps+1 values (the last being 0.0 for the terminal state).
|
||||
"""
|
||||
# sigma bounds from unshifted training schedule (constructor uses shift=1)
|
||||
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[
|
||||
::-1
|
||||
]
|
||||
alphas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps)[::-1]
|
||||
sigmas_unshifted = 1.0 - alphas
|
||||
sigma_max = float(sigmas_unshifted[0]) # (N-1)/N
|
||||
sigma_min = float(sigmas_unshifted[-1]) # 0.0
|
||||
@@ -65,7 +62,10 @@ class FlowMatchEulerScheduler:
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
||||
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
|
||||
dt = (
|
||||
self._sigmas_float[self._step_index + 1]
|
||||
- self._sigmas_float[self._step_index]
|
||||
)
|
||||
x_next = sample + dt * model_output
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
@@ -139,13 +139,8 @@ class FlowDPMPP2MScheduler:
|
||||
|
||||
# Decide order: 1st for first step, last step (if lower_order_final
|
||||
# and few steps), otherwise 2nd
|
||||
use_first_order = (
|
||||
self._prev_x0 is None
|
||||
or (
|
||||
self.lower_order_final
|
||||
and i == self._num_steps - 1
|
||||
and self._num_steps < 15
|
||||
)
|
||||
use_first_order = self._prev_x0 is None or (
|
||||
self.lower_order_final and i == self._num_steps - 1 and self._num_steps < 15
|
||||
)
|
||||
|
||||
if use_first_order or sigma_next == 0.0:
|
||||
|
||||
Reference in New Issue
Block a user