feat(wan): Add DPM++ 2M and UniPC schedulers
This commit is contained in:
@@ -288,21 +288,34 @@ class Resample(nn.Module):
|
||||
B, T, H, W, C = x.shape
|
||||
|
||||
if self.mode == "upsample3d":
|
||||
# Temporal upsample via time_conv
|
||||
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
|
||||
# Split into two interleaved temporal streams
|
||||
tc_out = tc_out.reshape(B, T, H, W, 2, C)
|
||||
# Interleave: [B, T, 2, H, W, C] → [B, T*2, H, W, C]
|
||||
stream0 = tc_out[:, :, :, :, 0, :] # [B, T, H, W, C]
|
||||
stream1 = tc_out[:, :, :, :, 1, :] # [B, T, H, W, C]
|
||||
x = mx.stack([stream0, stream1], axis=2) # [B, T, 2, H, W, C]
|
||||
x = x.reshape(B, T * 2, H, W, C)
|
||||
if first_chunk and T > 1:
|
||||
# Match official chunked behavior: the first frame bypasses
|
||||
# time_conv entirely (only spatial upsample). Remaining frames
|
||||
# go through time_conv with causal zero-padding, which
|
||||
# naturally gives each frame the same limited temporal context
|
||||
# as the official frame-by-frame decode with caching.
|
||||
first_frame = x[:, 0:1] # [B, 1, H, W, C]
|
||||
rest = x[:, 1:] # [B, T-1, H, W, C]
|
||||
|
||||
if first_chunk:
|
||||
# PyTorch skips time_conv for first chunk entirely. In all-at-once
|
||||
# mode, we trim the first frame to match (the first interleaved
|
||||
# frame is from zero-padded causal context and shouldn't be kept).
|
||||
x = x[:, 1:, :, :, :]
|
||||
# time_conv on remaining frames (causal pad gives zero context
|
||||
# before rest[0], matching the official "Rep" cache path)
|
||||
tc_out = self.time_conv(rest) # [B, T-1, H, W, 2C]
|
||||
tc_out = tc_out.reshape(B, T - 1, H, W, 2, C)
|
||||
stream0 = tc_out[:, :, :, :, 0, :]
|
||||
stream1 = tc_out[:, :, :, :, 1, :]
|
||||
interleaved = mx.stack([stream0, stream1], axis=2)
|
||||
interleaved = interleaved.reshape(B, (T - 1) * 2, H, W, C)
|
||||
|
||||
# first_frame (1) + interleaved (2*(T-1)) = 2T-1 frames
|
||||
x = mx.concatenate([first_frame, interleaved], axis=1)
|
||||
elif self.mode == "upsample3d":
|
||||
# Non-first-chunk or single frame: time_conv all frames
|
||||
tc_out = self.time_conv(x) # [B, T, H, W, 2C]
|
||||
tc_out = tc_out.reshape(B, T, H, W, 2, C)
|
||||
stream0 = tc_out[:, :, :, :, 0, :]
|
||||
stream1 = tc_out[:, :, :, :, 1, :]
|
||||
x = mx.stack([stream0, stream1], axis=2)
|
||||
x = x.reshape(B, T * 2, H, W, C)
|
||||
|
||||
mx.eval(x)
|
||||
T = x.shape[1]
|
||||
|
||||
Reference in New Issue
Block a user