feat(wan): Add I2V-14B dual-model support

This commit is contained in:
Daniel
2026-02-27 23:43:42 +01:00
parent 2bb95c61ed
commit f4195f0118
14 changed files with 1332 additions and 152 deletions

View File

@@ -245,24 +245,71 @@ def generate_video(
z_img = None
i2v_mask = None
i2v_mask_tokens = None
y_i2v = None
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
if is_i2v:
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
t_img = time.time()
img_tensor = preprocess_image(image, width, height)
mx.eval(img_tensor)
vae_path = model_dir / "vae.safetensors"
vae_enc = load_vae_encoder(vae_path, config)
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
mx.eval(z_img)
# Convert to channels-first: [z_dim, 1, H_lat, W_lat]
z_img = z_img[0].transpose(3, 0, 1, 2)
if is_i2v_channel_concat:
# I2V-14B: encode full video (first frame = image, rest = zeros)
# and construct y tensor with mask + encoded latents
from PIL import Image
# Build I2V mask
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
img = Image.open(image).convert("RGB")
scale = max(width / img.width, height / img.height)
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
img = img.crop((x1, y1, x1 + width, y1 + height))
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
video = mx.concatenate([
img_chw[:, None, :, :],
mx.zeros((3, num_frames - 1, height, width)),
], axis=1)
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
vae_enc = load_vae_encoder(vae_path, config)
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
mx.eval(z_video)
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
msk = mx.ones((1, num_frames, h_latent, w_latent))
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
msk = mx.concatenate([
mx.repeat(msk[:, :1], 4, axis=1),
msk[:, 1:],
], axis=1)
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
y_i2v = mx.concatenate([msk, z_video], axis=0)
mx.eval(y_i2v)
del vae_enc, img_arr, img_chw, video, z_video, msk
else:
# TI2V-5B: encode single image, blend with noise via mask
img_tensor = preprocess_image(image, width, height)
mx.eval(img_tensor)
vae_enc = load_vae_encoder(vae_path, config)
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
mx.eval(z_img)
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
del vae_enc, img_tensor
del vae_enc, img_tensor
gc.collect(); mx.clear_cache()
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
@@ -282,23 +329,40 @@ def generate_video(
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
# Precompute text embeddings once (avoids redundant MLP in every step)
ref_model = single_model if not is_dual else low_noise_model
context_emb = ref_model.embed_text([context, context_null])
mx.eval(context_emb)
context_cond = context_emb[0:1] # [1, text_len, dim]
context_uncond = context_emb[1:2] # [1, text_len, dim]
# Stack for batched CFG: [2, text_len, dim]
context_cfg = mx.concatenate([context_cond, context_uncond], axis=0)
# Each model has its own text_embedding weights, so dual models need separate embeddings
if is_dual:
context_emb_low = low_noise_model.embed_text([context, context_null])
context_emb_high = high_noise_model.embed_text([context, context_null])
mx.eval(context_emb_low, context_emb_high)
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
else:
context_emb = single_model.embed_text([context, context_null])
mx.eval(context_emb)
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
# Precompute cross-attention K/V caches (constant across all steps)
if is_dual:
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg)
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
mx.eval(cross_kv_low, cross_kv_high)
else:
cross_kv = single_model.prepare_cross_kv(context_cfg)
mx.eval(cross_kv)
# Precompute RoPE frequencies (grid sizes are constant across all steps)
f_grid = t_latent // patch_size[0]
h_grid = h_latent // patch_size[1]
w_grid = w_latent // patch_size[2]
cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
if is_dual:
rope_cos_sin_low = low_noise_model.prepare_rope(cfg_grid_sizes)
rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes)
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
else:
rope_cos_sin = ref_model.prepare_rope(cfg_grid_sizes)
mx.eval(rope_cos_sin)
# Setup scheduler
_schedulers = {
"euler": FlowMatchEulerScheduler,
@@ -312,9 +376,8 @@ def generate_video(
# Generate initial noise
noise = mx.random.normal(target_shape)
# I2V: blend first-frame latent into noise
if is_i2v:
# Broadcast z_img [z_dim, 1, H, W] across T for first-frame conditioning
# I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
else:
latents = noise
@@ -326,26 +389,32 @@ def generate_video(
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
t3 = time.time()
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = sched.timesteps[i].item()
# Pre-convert timesteps to Python list to avoid .item() sync each step
timestep_list = sched.timesteps.tolist()
# Select model, guide scale, and cached K/V
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
timestep_val = timestep_list[i]
# Select model, guide scale, cached K/V, and precomputed RoPE
if is_dual:
if timestep_val >= boundary:
model = high_noise_model
gs = guide_scale[1]
kv = cross_kv_high
rcs = rope_cos_sin_high
else:
model = low_noise_model
gs = guide_scale[0]
kv = cross_kv_low
rcs = rope_cos_sin_low
else:
model = single_model
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
kv = cross_kv
rcs = rope_cos_sin
# Build per-token timesteps for I2V (first-frame patches get t=0)
if is_i2v:
# Build per-token timesteps for TI2V-5B (first-frame patches get t=0)
if is_i2v_mask_blend:
t_tokens = i2v_mask_tokens * timestep_val # [1, L]
# Pad to seq_len if needed
pad_len = seq_len - t_tokens.shape[1]
@@ -358,22 +427,31 @@ def generate_video(
else:
t_batch = mx.array([timestep_val, timestep_val])
# I2V-14B: pass y conditioning to model (same y for cond and uncond)
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
# CFG: batch cond + uncond into single B=2 forward pass
ctx = context_cfg if not is_dual else (
context_cfg_high if timestep_val >= boundary else context_cfg_low
)
preds = model(
[latents, latents],
t=t_batch,
context=context_cfg,
context=ctx,
seq_len=seq_len,
cross_kv_caches=kv,
y=y_arg,
rope_cos_sin=rcs,
)
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
# Classifier-free guidance + scheduler step
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
# I2V: re-apply mask to keep first frame frozen
if is_i2v:
# TI2V-5B: re-apply mask to keep first frame frozen
if is_i2v_mask_blend:
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
# Release temporaries before eval to free memory for graph execution
@@ -385,9 +463,11 @@ def generate_video(
# Free transformer models and text embeddings
if is_dual:
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
del context_cfg_low, context_cfg_high
else:
del single_model, cross_kv
del model, kv, context, context_null, context_cfg
del context_cfg
del model, kv, context, context_null
gc.collect(); mx.clear_cache()
# Load VAE and decode