fix(wan): Fix RoPE frequency construction

This commit is contained in:
Daniel
2026-02-28 11:20:36 +01:00
parent f4195f0118
commit dbab95ec45
3 changed files with 386 additions and 30 deletions

View File

@@ -108,9 +108,15 @@ class WanModel(nn.Module):
# Output head
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
# Precompute RoPE frequencies — single table, split by rope_apply
# Reference computes one rope_params(head_dim) and splits into t/h/w.
self.freqs = rope_params(1024, dim // config.num_heads)
# Precompute RoPE frequencies — three separate tables concatenated along axis 1.
# Reference computes three rope_params tables, one per axis
# (temporal/height/width), each with its own dim normalization, so every axis
# gets its own full frequency range.
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
# Precompute sinusoidal inv_freq for time embedding
half = config.freq_dim // 2