This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,4 +1,5 @@
import math
import mlx.core as mx
import mlx.nn as nn
import numpy as np
@@ -37,7 +38,9 @@ class Head(nn.Module):
proj_dim = math.prod(patch_size) * out_dim
self.norm = WanLayerNorm(dim, eps)
self.head = nn.Linear(dim, proj_dim)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
mx.float32
)
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
"""
@@ -111,20 +114,23 @@ class WanModel(nn.Module):
# Reference computes three rope_params with different dim normalizations
# so each axis (temporal/height/width) gets its own full frequency range.
d = dim // config.num_heads
self.freqs = mx.concatenate([
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
], axis=1)
self.freqs = mx.concatenate(
[
rope_params(1024, d - 4 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
rope_params(1024, 2 * (d // 6)),
],
axis=1,
)
# Precompute sinusoidal inv_freq for time embedding.
half = config.freq_dim // 2
self._inv_freq = mx.array(
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
).astype(np.float32)
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
np.float32
)
)
def _patchify(self, x: mx.array) -> tuple:
"""Convert video tensor to patch embeddings.
@@ -297,12 +303,19 @@ class WanModel(nn.Module):
seq_lens_list.append(p.shape[1])
x = mx.concatenate(
[
mx.concatenate(
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
axis=1,
(
mx.concatenate(
[
p,
mx.zeros(
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
),
],
axis=1,
)
if p.shape[1] < seq_len
else p
)
if p.shape[1] < seq_len
else p
for p in patches
],
axis=0,
@@ -315,9 +328,7 @@ class WanModel(nn.Module):
t = t[None]
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
sin_emb = mx.concatenate(
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
)
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
if t.ndim == 1:
# Standard T2V: scalar timestep per batch element [B]