feat(wan): Add tiled VAE decoding and fix TI2V quality
This commit is contained in:
@@ -51,10 +51,11 @@ class WanAttentionBlock(nn.Module):
|
||||
rope_cos_sin: tuple | None = None,
|
||||
attn_mask: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
- # Modulation: compute in float32 for precision, cast to working dtype
|
||||
- # to avoid promoting the full hidden state (seq_len × dim) to float32
|
||||
- w_dtype = _linear_dtype(self.self_attn.q)
|
||||
- mod = (self.modulation + e).astype(w_dtype)
|
||||
+ # Modulation: compute in float32 for precision, matching the reference
|
||||
+ # which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
|
||||
+ # By keeping modulation in float32, type promotion ensures the residual
|
||||
+ # stream stays float32 throughout all 30 layers (gate * output + x → float32).
|
||||
+ mod = self.modulation + e # float32
|
||||
e0, e1, e2, e3, e4, e5 = (
|
||||
mod[:, :, 0, :], # shift for self-attn
|
||||
mod[:, :, 1, :], # scale for self-attn
|
||||
|
||||
Reference in New Issue
Block a user