feat(wan): Add tiled VAE decoding and fix TI2V quality

This commit is contained in:
Daniel
2026-03-04 14:32:45 +01:00
parent 9597b7c9c5
commit 9bdda9f22e
7 changed files with 407 additions and 34 deletions

View File

@@ -51,10 +51,11 @@ class WanAttentionBlock(nn.Module):
rope_cos_sin: tuple | None = None,
attn_mask: mx.array | None = None,
) -> mx.array:
# Modulation: compute in float32 for precision, cast to working dtype
# to avoid promoting the full hidden state (seq_len × dim) to float32
w_dtype = _linear_dtype(self.self_attn.q)
mod = (self.modulation + e).astype(w_dtype)
# Modulation: compute in float32 for precision, matching the reference
# which keeps residual x in float32 via torch.amp.autocast(dtype=float32).
# By keeping modulation in float32, type promotion ensures the residual
# stream stays float32 throughout all 30 layers (gate * output + x → float32).
mod = self.modulation + e # float32
e0, e1, e2, e3, e4, e5 = (
mod[:, :, 0, :], # shift for self-attn
mod[:, :, 1, :], # scale for self-attn