feat(wan): Add I2V-14B dual-model support

This commit is contained in:
Daniel
2026-02-27 23:43:42 +01:00
parent 2bb95c61ed
commit f4195f0118
14 changed files with 1332 additions and 152 deletions

View File

@@ -28,6 +28,7 @@ def rope_apply(
x: mx.array,
grid_sizes: list,
freqs: mx.array,
precomputed_cos_sin: tuple | None = None,
) -> mx.array:
"""Apply 3-way factorized RoPE to Q or K tensor.
@@ -35,10 +36,48 @@ def rope_apply(
x: Shape [B, L, num_heads, head_dim]
grid_sizes: List of (F, H, W) tuples per batch element
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin()
"""
b, s, n, d = x.shape
half_d = d // 2
if precomputed_cos_sin is not None:
cos_f, sin_f = precomputed_cos_sin
# Check if all batch elements have the same grid (common for CFG B=2)
f0, h0, w0 = grid_sizes[0]
seq_len = f0 * h0 * w0
all_same_grid = all(
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
) if b > 1 else True
if all_same_grid:
# Vectorized path: apply RoPE to all batch elements at once
x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2)
x_real = x_seq[..., 0]
x_imag = x_seq[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
if seq_len < s:
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
return x_rotated
else:
# Per-element path for mixed grid sizes
outputs = []
for i in range(b):
f, h, w = grid_sizes[i]
sl = f * h * w
x_i = x[i, :sl].reshape(sl, n, half_d, 2)
x_real = x_i[..., 0]
x_imag = x_i[..., 1]
out_real = x_real * cos_f - x_imag * sin_f
out_imag = x_real * sin_f + x_imag * cos_f
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d)
if sl < s:
x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0)
outputs.append(x_rotated)
return mx.stack(outputs)
# Cast freqs to input dtype to prevent float32 promotion cascade
if freqs.dtype != x.dtype:
freqs = freqs.astype(x.dtype)
@@ -98,3 +137,42 @@ def rope_apply(
outputs.append(x_rotated)
return mx.stack(outputs)
def rope_precompute_cos_sin(
grid_sizes: list, freqs: mx.array, dtype: type = mx.float32
) -> tuple:
"""Precompute cos/sin frequency tensors for constant grid sizes.
Call once before the diffusion loop. Pass result as precomputed_cos_sin
to rope_apply to skip per-step broadcast/concat.
Args:
grid_sizes: List of (F, H, W) tuples (must be same for all batch elements)
freqs: Precomputed frequencies [1024, d//2, 2]
dtype: Target dtype for the output tensors
Returns:
(cos_f, sin_f) each [seq_len, 1, half_d]
"""
if freqs.dtype != dtype:
freqs = freqs.astype(dtype)
f, h, w = grid_sizes[0]
seq_len = f * h * w
half_d = freqs.shape[1]
d_t = half_d - 2 * (half_d // 3)
d_h = half_d // 3
d_w = half_d // 3
freqs_t = freqs[:, :d_t]
freqs_h = freqs[:, d_t : d_t + d_h]
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w]
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
return freqs_i[..., 0], freqs_i[..., 1]