format
This commit is contained in:
@@ -98,8 +98,12 @@ class WanSelfAttention(nn.Module):
|
||||
v = self.v(x_w).reshape(b, s, n, d)
|
||||
|
||||
# RoPE in float32 for precision (official uses float64)
|
||||
q = rope_apply(q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
|
||||
k = rope_apply(k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin)
|
||||
q = rope_apply(
|
||||
q.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
k = rope_apply(
|
||||
k.astype(mx.float32), grid_sizes, freqs, precomputed_cos_sin=rope_cos_sin
|
||||
)
|
||||
|
||||
# Cast back to weight dtype for efficient attention (matching official q.to(v.dtype))
|
||||
q = q.astype(w_dtype).transpose(0, 2, 1, 3)
|
||||
@@ -120,9 +124,7 @@ class WanSelfAttention(nn.Module):
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale
|
||||
)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, s, -1)
|
||||
return self.o(out)
|
||||
@@ -213,9 +215,7 @@ class WanCrossAttention(nn.Module):
|
||||
q, k, v, scale=self.scale, mask=mask
|
||||
)
|
||||
else:
|
||||
out = mx.fast.scaled_dot_product_attention(
|
||||
q, k, v, scale=self.scale
|
||||
)
|
||||
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(b, -1, n * d)
|
||||
return self.o(out)
|
||||
|
||||
Reference in New Issue
Block a user