feat(wan): Add DPM++ 2M and UniPC schedulers
This commit is contained in:
@@ -106,29 +106,35 @@ class T5Attention(nn.Module):
|
||||
k = self.k(context).reshape(b, -1, n, c) # [B, Lk, N, C]
|
||||
v = self.v(context).reshape(b, -1, n, c)
|
||||
|
||||
# T5 does not use scaling
|
||||
# attn = einsum('binc,bjnc->bnij', q, k)
|
||||
# T5 uses no scaling — compute attention manually with float32 softmax
|
||||
# to match official: F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
# Using SDPA with bfloat16 inputs causes precision loss in softmax
|
||||
# since unscaled logits can be very large (no 1/sqrt(d) division).
|
||||
q = q.transpose(0, 2, 1, 3) # [B, N, Lq, C]
|
||||
k = k.transpose(0, 2, 1, 3)
|
||||
v = v.transpose(0, 2, 1, 3)
|
||||
|
||||
# Combine position bias and attention mask for SDPA
|
||||
attn_mask = None
|
||||
# QK^T (no scaling) — compute in float32 for precision
|
||||
attn = (q.astype(mx.float32) @ k.astype(mx.float32).transpose(0, 1, 3, 2))
|
||||
|
||||
# Add position bias
|
||||
if pos_bias is not None:
|
||||
attn_mask = pos_bias.astype(q.dtype)
|
||||
attn = attn + pos_bias.astype(mx.float32)
|
||||
|
||||
# Apply attention mask (use dtype min like official, not -1e9)
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, 1, 1, Lk]
|
||||
elif mask.ndim == 3:
|
||||
mask = mask[:, None, :, :] # [B, 1, Lq, Lk]
|
||||
additive_mask = mx.where(mask == 0, -1e9, 0.0).astype(q.dtype)
|
||||
attn_mask = (attn_mask + additive_mask) if attn_mask is not None else additive_mask
|
||||
additive_mask = mx.where(mask == 0, -3.389e38, 0.0).astype(mx.float32)
|
||||
attn = attn + additive_mask
|
||||
|
||||
# T5 uses no scaling (scale=1.0)
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=1.0, mask=attn_mask
|
||||
)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * c)
|
||||
# Softmax in float32 (matches official), then cast back
|
||||
attn = mx.softmax(attn, axis=-1).astype(q.dtype)
|
||||
|
||||
# Attention @ V
|
||||
out = (attn @ v).transpose(0, 2, 1, 3).reshape(b, -1, n * c)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user