feat(wan): Add tiled VAE decoding and fix TI2V quality

This commit is contained in:
Daniel
2026-03-04 14:32:45 +01:00
parent 9597b7c9c5
commit 9bdda9f22e
7 changed files with 407 additions and 34 deletions

View File

@@ -48,9 +48,8 @@ class Head(nn.Module):
"""
if e.ndim == 2:
e = e[:, None, :] # [B, 1, dim]
# Compute modulation in float32 for precision, cast to working dtype
w_dtype = _linear_dtype(self.head)
mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype)
# Compute modulation in float32 (matching reference's autocast(float32))
mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
x_norm = self.norm(x)
@@ -120,10 +119,13 @@ class WanModel(nn.Module):
], axis=1)
# Precompute sinusoidal inv_freq for time embedding
# Use numpy float64 for precision (matches reference torch.float64),
# then store as float32 since MLX GPU doesn't support float64.
half = config.freq_dim // 2
self._inv_freq = mx.power(
10000.0, -mx.arange(half).astype(mx.float32) / half
inv_freq_np = np.power(
10000.0, -np.arange(half, dtype=np.float64) / half
)
self._inv_freq = mx.array(inv_freq_np.astype(np.float32))
def _patchify(self, x: mx.array) -> tuple: