fix(wan): Fix RoPE frequency construction
This commit is contained in:
@@ -108,9 +108,15 @@ class WanModel(nn.Module):
|
||||
# Output head
|
||||
self.head = Head(dim, config.out_dim, config.patch_size, config.eps)
|
||||
|
||||
-# Precompute RoPE frequencies — single table, split by rope_apply
-# Reference computes one rope_params(head_dim) and splits into t/h/w.
-self.freqs = rope_params(1024, dim // config.num_heads)
+# Precompute RoPE frequencies — three separate tables concatenated.
+# Reference computes three rope_params with different dim normalizations
+# so each axis (temporal/height/width) gets its own full frequency range.
+d = dim // config.num_heads
+self.freqs = mx.concatenate([
+    rope_params(1024, d - 4 * (d // 6)),
+    rope_params(1024, 2 * (d // 6)),
+    rope_params(1024, 2 * (d // 6)),
+], axis=1)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding
|
||||
half = config.freq_dim // 2
|
||||
|
||||
Reference in New Issue
Block a user