feat(wan): Add tiled VAE decoding and fix TI2V quality
This commit is contained in:
@@ -48,9 +48,8 @@ class Head(nn.Module):
|
||||
"""
|
||||
if e.ndim == 2:
|
||||
e = e[:, None, :] # [B, 1, dim]
|
||||
- # Compute modulation in float32 for precision, cast to working dtype
|
||||
- w_dtype = _linear_dtype(self.head)
|
||||
- mod = (self.modulation[:, None, :, :] + e[:, :, None, :]).astype(w_dtype)
|
||||
+ # Compute modulation in float32 (matching reference's autocast(float32))
|
||||
+ mod = self.modulation[:, None, :, :] + e[:, :, None, :] # float32
|
||||
e0 = mod[:, :, 0, :] # [B, L_e, dim] shift
|
||||
e1 = mod[:, :, 1, :] # [B, L_e, dim] scale
|
||||
x_norm = self.norm(x)
|
||||
@@ -120,10 +119,13 @@ class WanModel(nn.Module):
|
||||
], axis=1)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding
|
||||
+ # Use numpy float64 for precision (matches reference torch.float64),
|
||||
+ # then store as float32 since MLX GPU doesn't support float64.
|
||||
half = config.freq_dim // 2
|
||||
- self._inv_freq = mx.power(
|
||||
- 10000.0, -mx.arange(half).astype(mx.float32) / half
|
||||
+ inv_freq_np = np.power(
|
||||
+ 10000.0, -np.arange(half, dtype=np.float64) / half
|
||||
)
|
||||
+ self._inv_freq = mx.array(inv_freq_np.astype(np.float32))
|
||||
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
|
||||
Reference in New Issue
Block a user