feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -34,6 +34,8 @@ class FlowMatchEulerScheduler:
|
||||
sigmas = _compute_sigmas(num_steps, shift)
|
||||
self.sigmas = mx.array(sigmas)
|
||||
self.timesteps = mx.array(sigmas[:-1] * self.num_train_timesteps)
|
||||
# Store as Python floats to avoid .item() sync in step()
|
||||
self._sigmas_float = sigmas.tolist()
|
||||
self._step_index = 0
|
||||
|
||||
def step(
|
||||
@@ -43,9 +45,7 @@ class FlowMatchEulerScheduler:
|
||||
sample: mx.array,
|
||||
) -> mx.array:
|
||||
"""Euler step: x_next = x + (sigma_next - sigma_cur) * v."""
|
||||
dt = float(self.sigmas[self._step_index + 1].item()) - float(
|
||||
self.sigmas[self._step_index].item()
|
||||
)
|
||||
dt = self._sigmas_float[self._step_index + 1] - self._sigmas_float[self._step_index]
|
||||
x_next = sample + dt * model_output
|
||||
self._step_index += 1
|
||||
return x_next
|
||||
|
||||
Reference in New Issue
Block a user