format
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
@@ -37,7 +38,9 @@ class Head(nn.Module):
|
||||
proj_dim = math.prod(patch_size) * out_dim
|
||||
self.norm = WanLayerNorm(dim, eps)
|
||||
self.head = nn.Linear(dim, proj_dim)
|
||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(mx.float32)
|
||||
self.modulation = (mx.random.normal((1, 2, dim)) * (dim**-0.5)).astype(
|
||||
mx.float32
|
||||
)
|
||||
|
||||
def __call__(self, x: mx.array, e: mx.array) -> mx.array:
|
||||
"""
|
||||
@@ -111,20 +114,23 @@ class WanModel(nn.Module):
|
||||
# Reference computes three rope_params with different dim normalizations
|
||||
# so each axis (temporal/height/width) gets its own full frequency range.
|
||||
d = dim // config.num_heads
|
||||
self.freqs = mx.concatenate([
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
], axis=1)
|
||||
self.freqs = mx.concatenate(
|
||||
[
|
||||
rope_params(1024, d - 4 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
rope_params(1024, 2 * (d // 6)),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
|
||||
# Precompute sinusoidal inv_freq for time embedding.
|
||||
half = config.freq_dim // 2
|
||||
self._inv_freq = mx.array(
|
||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half
|
||||
).astype(np.float32)
|
||||
np.power(10000.0, -np.arange(half, dtype=np.float64) / half).astype(
|
||||
np.float32
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _patchify(self, x: mx.array) -> tuple:
|
||||
"""Convert video tensor to patch embeddings.
|
||||
|
||||
@@ -297,12 +303,19 @@ class WanModel(nn.Module):
|
||||
seq_lens_list.append(p.shape[1])
|
||||
x = mx.concatenate(
|
||||
[
|
||||
mx.concatenate(
|
||||
[p, mx.zeros((1, seq_len - p.shape[1], self.dim), dtype=p.dtype)],
|
||||
axis=1,
|
||||
(
|
||||
mx.concatenate(
|
||||
[
|
||||
p,
|
||||
mx.zeros(
|
||||
(1, seq_len - p.shape[1], self.dim), dtype=p.dtype
|
||||
),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
)
|
||||
if p.shape[1] < seq_len
|
||||
else p
|
||||
for p in patches
|
||||
],
|
||||
axis=0,
|
||||
@@ -315,9 +328,7 @@ class WanModel(nn.Module):
|
||||
t = t[None]
|
||||
|
||||
sinusoid = t[..., None].astype(mx.float32) * self._inv_freq
|
||||
sin_emb = mx.concatenate(
|
||||
[mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1
|
||||
)
|
||||
sin_emb = mx.concatenate([mx.cos(sinusoid), mx.sin(sinusoid)], axis=-1)
|
||||
|
||||
if t.ndim == 1:
|
||||
# Standard T2V: scalar timestep per batch element [B]
|
||||
|
||||
Reference in New Issue
Block a user