feat(wan): Add DPM++ 2M and UniPC schedulers

This commit is contained in:
Daniel
2026-02-27 10:28:33 +01:00
parent e64483a66a
commit 93da550f65
8 changed files with 1792 additions and 89 deletions

View File

@@ -288,21 +288,34 @@ class Resample(nn.Module):
B, T, H, W, C = x.shape
if self.mode == "upsample3d":
# Temporal upsample via time_conv
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
# Split into two interleaved temporal streams
tc_out = tc_out.reshape(B, T, H, W, 2, C)
# Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C]
stream0 = tc_out[:, :, :, :, 0, :] # [B, T, H, W, C]
stream1 = tc_out[:, :, :, :, 1, :] # [B, T, H, W, C]
x = mx.stack([stream0, stream1], axis=2) # [B, T, 2, H, W, C]
x = x.reshape(B, T * 2, H, W, C)
if first_chunk and T > 1:
# Match official chunked behavior: the first frame bypasses
# time_conv entirely (only spatial upsample). Remaining frames
# go through time_conv with causal zero-padding, which
# naturally gives each frame the same limited temporal context
# as the official frame-by-frame decode with caching.
first_frame = x[:, 0:1] # [B, 1, H, W, C]
rest = x[:, 1:] # [B, T-1, H, W, C]
if first_chunk:
# PyTorch skips time_conv for first chunk entirely. In all-at-once
# mode, we trim the first frame to match (the first interleaved
# frame is from zero-padded causal context and shouldn't be kept).
x = x[:, 1:, :, :, :]
# time_conv on remaining frames (causal pad gives zero context
# before rest[0], matching the official "Rep" cache path)
tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C]
tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
stream0 = tc_out[:, :, :, :, 0, :]
stream1 = tc_out[:, :, :, :, 1, :]
interleaved = mx.stack([stream0, stream1], axis=2)
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
# first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames
x = mx.concatenate([first_frame, interleaved], axis=1)
elif self.mode == "upsample3d":
# Non-first-chunk or single frame: time_conv all frames
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
tc_out = tc_out.reshape(B, T, H, W, 2, C)
stream0 = tc_out[:, :, :, :, 0, :]
stream1 = tc_out[:, :, :, :, 1, :]
x = mx.stack([stream0, stream1], axis=2)
x = x.reshape(B, T * 2, H, W, C)
mx.eval(x)
T = x.shape[1]