feat(wan): Add DPM++ 2M and UniPC schedulers

This commit is contained in:
Daniel
2026-02-27 10:28:33 +01:00
parent e64483a66a
commit 93da550f65
8 changed files with 1792 additions and 89 deletions

View File

@@ -106,29 +106,35 @@ class T5Attention(nn.Module):
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
v = self.v(context).reshape(b, -1, n, c)
# T5 does not use scaling
# attn = einsum('binc,bjnc->bnij', q, k)
# T5 uses no scaling — compute attention manually with float32 softmax
# to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
# Using SDPA with bfloat16 inputs causes precision loss in softmax
# since unscaled logits can be very large (no 1/sqrt(d) division).
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)
# Combine position bias and attention mask for SDPA
attn_mask = None
# QK^T (no scaling) — compute in float32 for precision
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
# Add position bias
if pos_bias is not None:
attn_mask = pos_bias.astype(q.dtype)
attn = attn + pos_bias.astype(mx.float32)
# Apply attention mask (use dtype min like official, not -1e9)
if mask is not None:
if mask.ndim == 2:
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
elif mask.ndim == 3:
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype)
attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask
additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
attn = attn + additive_mask
# T5 uses no scaling (scale=1.0)
out = mx.fast.scaled_dot_product_attention(
q, k, v, scale=1.0, mask=attn_mask
)
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c)
# Softmax in float32 (matches official), then cast back
attn = mx.softmax(attn, axis=-1).astype(q.dtype)
# Attention @ V
out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
return self.o(out)