feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -28,6 +28,7 @@ def rope_apply(
|
||||
x: mx.array,
|
||||
grid_sizes: list,
|
||||
freqs: mx.array,
|
||||
precomputed_cos_sin: tuple | None = None,
|
||||
) -> mx.array:
|
||||
"""Apply 3-way factorized RoPE to Q or K tensor.
|
||||
|
||||
@@ -35,10 +36,48 @@ def rope_apply(
|
||||
x: Shape [B, L, num_heads, head_dim]
|
||||
grid_sizes: List of (F, H, W) tuples per batch element
|
||||
freqs: Precomputed cos/sin, shape [1024, d//2, 2] split into 3 parts
|
||||
precomputed_cos_sin: Optional (cos, sin) from rope_precompute_cos_sin()
|
||||
"""
|
||||
b, s, n, d = x.shape
|
||||
half_d = d // 2
|
||||
|
||||
if precomputed_cos_sin is not None:
|
||||
cos_f, sin_f = precomputed_cos_sin
|
||||
# Check if all batch elements have the same grid (common for CFG B=2)
|
||||
f0, h0, w0 = grid_sizes[0]
|
||||
seq_len = f0 * h0 * w0
|
||||
all_same_grid = all(
|
||||
grid_sizes[i] == grid_sizes[0] for i in range(1, b)
|
||||
) if b > 1 else True
|
||||
|
||||
if all_same_grid:
|
||||
# Vectorized path: apply RoPE to all batch elements at once
|
||||
x_seq = x[:, :seq_len].reshape(b, seq_len, n, half_d, 2)
|
||||
x_real = x_seq[..., 0]
|
||||
x_imag = x_seq[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(b, seq_len, n, d)
|
||||
if seq_len < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[:, seq_len:]], axis=1)
|
||||
return x_rotated
|
||||
else:
|
||||
# Per-element path for mixed grid sizes
|
||||
outputs = []
|
||||
for i in range(b):
|
||||
f, h, w = grid_sizes[i]
|
||||
sl = f * h * w
|
||||
x_i = x[i, :sl].reshape(sl, n, half_d, 2)
|
||||
x_real = x_i[..., 0]
|
||||
x_imag = x_i[..., 1]
|
||||
out_real = x_real * cos_f - x_imag * sin_f
|
||||
out_imag = x_real * sin_f + x_imag * cos_f
|
||||
x_rotated = mx.stack([out_real, out_imag], axis=-1).reshape(sl, n, d)
|
||||
if sl < s:
|
||||
x_rotated = mx.concatenate([x_rotated, x[i, sl:]], axis=0)
|
||||
outputs.append(x_rotated)
|
||||
return mx.stack(outputs)
|
||||
|
||||
# Cast freqs to input dtype to prevent float32 promotion cascade
|
||||
if freqs.dtype != x.dtype:
|
||||
freqs = freqs.astype(x.dtype)
|
||||
@@ -98,3 +137,42 @@ def rope_apply(
|
||||
outputs.append(x_rotated)
|
||||
|
||||
return mx.stack(outputs)
|
||||
|
||||
|
||||
def rope_precompute_cos_sin(
|
||||
grid_sizes: list, freqs: mx.array, dtype: type = mx.float32
|
||||
) -> tuple:
|
||||
"""Precompute cos/sin frequency tensors for constant grid sizes.
|
||||
|
||||
Call once before the diffusion loop. Pass result as precomputed_cos_sin
|
||||
to rope_apply to skip per-step broadcast/concat.
|
||||
|
||||
Args:
|
||||
grid_sizes: List of (F, H, W) tuples (must be same for all batch elements)
|
||||
freqs: Precomputed frequencies [1024, d//2, 2]
|
||||
dtype: Target dtype for the output tensors
|
||||
|
||||
Returns:
|
||||
(cos_f, sin_f) each [seq_len, 1, half_d]
|
||||
"""
|
||||
if freqs.dtype != dtype:
|
||||
freqs = freqs.astype(dtype)
|
||||
|
||||
f, h, w = grid_sizes[0]
|
||||
seq_len = f * h * w
|
||||
half_d = freqs.shape[1]
|
||||
|
||||
d_t = half_d - 2 * (half_d // 3)
|
||||
d_h = half_d // 3
|
||||
d_w = half_d // 3
|
||||
|
||||
freqs_t = freqs[:, :d_t]
|
||||
freqs_h = freqs[:, d_t : d_t + d_h]
|
||||
freqs_w = freqs[:, d_t + d_h : d_t + d_h + d_w]
|
||||
|
||||
ft = mx.broadcast_to(freqs_t[:f].reshape(f, 1, 1, d_t, 2), (f, h, w, d_t, 2))
|
||||
fh = mx.broadcast_to(freqs_h[:h].reshape(1, h, 1, d_h, 2), (f, h, w, d_h, 2))
|
||||
fw = mx.broadcast_to(freqs_w[:w].reshape(1, 1, w, d_w, 2), (f, h, w, d_w, 2))
|
||||
|
||||
freqs_i = mx.concatenate([ft, fh, fw], axis=3).reshape(seq_len, 1, half_d, 2)
|
||||
return freqs_i[..., 0], freqs_i[..., 1]
|
||||
|
||||
Reference in New Issue
Block a user