feat(wan): Add I2V-14B dual-model support
This commit is contained in:
@@ -245,24 +245,71 @@ def generate_video(
|
||||
z_img = None
|
||||
i2v_mask = None
|
||||
i2v_mask_tokens = None
|
||||
y_i2v = None
|
||||
is_i2v_channel_concat = is_i2v and config.model_type == "i2v"
|
||||
is_i2v_mask_blend = is_i2v and config.model_type != "i2v"
|
||||
if is_i2v:
|
||||
print(f"\n{Colors.BLUE}Encoding input image...{Colors.RESET}")
|
||||
t_img = time.time()
|
||||
img_tensor = preprocess_image(image, width, height)
|
||||
mx.eval(img_tensor)
|
||||
|
||||
vae_path = model_dir / "vae.safetensors"
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||
mx.eval(z_img)
|
||||
|
||||
# Convert to channels-first: [z_dim, 1, H_lat, W_lat]
|
||||
z_img = z_img[0].transpose(3, 0, 1, 2)
|
||||
if is_i2v_channel_concat:
|
||||
# I2V-14B: encode full video (first frame = image, rest = zeros)
|
||||
# and construct y tensor with mask + encoded latents
|
||||
from PIL import Image
|
||||
|
||||
# Build I2V mask
|
||||
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||
img = Image.open(image).convert("RGB")
|
||||
scale = max(width / img.width, height / img.height)
|
||||
img = img.resize((round(img.width * scale), round(img.height * scale)), Image.LANCZOS)
|
||||
x1, y1 = (img.width - width) // 2, (img.height - height) // 2
|
||||
img = img.crop((x1, y1, x1 + width, y1 + height))
|
||||
img_arr = mx.array(np.array(img, dtype=np.float32) / 255.0 * 2.0 - 1.0) # [H, W, 3]
|
||||
img_chw = img_arr.transpose(2, 0, 1) # [3, H, W]
|
||||
|
||||
# Build video: first frame = image, rest = zeros -> [3, F, H, W]
|
||||
# Chunked encoding processes 1-frame + 4-frame chunks with temporal caching
|
||||
video = mx.concatenate([
|
||||
img_chw[:, None, :, :],
|
||||
mx.zeros((3, num_frames - 1, height, width)),
|
||||
], axis=1)
|
||||
|
||||
# Encode through Wan2.1 VAE -> [1, z_dim, T_lat, H_lat, W_lat]
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_video = vae_enc.encode(video[None]) # [1, 16, T_lat, H_lat, W_lat]
|
||||
mx.eval(z_video)
|
||||
z_video = z_video[0] # [16, T_lat, H_lat, W_lat]
|
||||
|
||||
# Build mask: 1 for first frame, 0 for rest -> rearrange to [4, T_lat, H, W]
|
||||
msk = mx.ones((1, num_frames, h_latent, w_latent))
|
||||
msk = mx.concatenate([msk[:, :1], mx.zeros((1, num_frames - 1, h_latent, w_latent))], axis=1)
|
||||
# Repeat first frame 4x, concat rest: [1, 4 + (F-1), H_lat, W_lat]
|
||||
msk = mx.concatenate([
|
||||
mx.repeat(msk[:, :1], 4, axis=1),
|
||||
msk[:, 1:],
|
||||
], axis=1)
|
||||
# Reshape to [1, T_lat, 4, H_lat, W_lat] then transpose -> [4, T_lat, H_lat, W_lat]
|
||||
msk = msk.reshape(1, msk.shape[1] // 4, 4, h_latent, w_latent)
|
||||
msk = msk.transpose(0, 2, 1, 3, 4)[0] # [4, T_lat, H_lat, W_lat]
|
||||
|
||||
# y = concat([mask, encoded_video]) -> [20, T_lat, H_lat, W_lat]
|
||||
y_i2v = mx.concatenate([msk, z_video], axis=0)
|
||||
mx.eval(y_i2v)
|
||||
|
||||
del vae_enc, img_arr, img_chw, video, z_video, msk
|
||||
else:
|
||||
# TI2V-5B: encode single image, blend with noise via mask
|
||||
img_tensor = preprocess_image(image, width, height)
|
||||
mx.eval(img_tensor)
|
||||
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||
mx.eval(z_img)
|
||||
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
|
||||
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||
|
||||
del vae_enc, img_tensor
|
||||
|
||||
del vae_enc, img_tensor
|
||||
gc.collect(); mx.clear_cache()
|
||||
print(f"{Colors.DIM} Image encoding: {time.time() - t_img:.1f}s{Colors.RESET}")
|
||||
|
||||
@@ -282,23 +329,40 @@ def generate_video(
|
||||
print(f"{Colors.DIM} Models loaded: {time.time() - t2:.1f}s{Colors.RESET}")
|
||||
|
||||
# Precompute text embeddings once (avoids redundant MLP in every step)
|
||||
ref_model = single_model if not is_dual else low_noise_model
|
||||
context_emb = ref_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
context_cond = context_emb[0:1] # [1, text_len, dim]
|
||||
context_uncond = context_emb[1:2] # [1, text_len, dim]
|
||||
# Stack for batched CFG: [2, text_len, dim]
|
||||
context_cfg = mx.concatenate([context_cond, context_uncond], axis=0)
|
||||
# Each model has its own text_embedding weights, so dual models need separate embeddings
|
||||
if is_dual:
|
||||
context_emb_low = low_noise_model.embed_text([context, context_null])
|
||||
context_emb_high = high_noise_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb_low, context_emb_high)
|
||||
context_cfg_low = mx.concatenate([context_emb_low[0:1], context_emb_low[1:2]], axis=0)
|
||||
context_cfg_high = mx.concatenate([context_emb_high[0:1], context_emb_high[1:2]], axis=0)
|
||||
else:
|
||||
context_emb = single_model.embed_text([context, context_null])
|
||||
mx.eval(context_emb)
|
||||
context_cfg = mx.concatenate([context_emb[0:1], context_emb[1:2]], axis=0)
|
||||
|
||||
# Precompute cross-attention K/V caches (constant across all steps)
|
||||
if is_dual:
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg)
|
||||
cross_kv_low = low_noise_model.prepare_cross_kv(context_cfg_low)
|
||||
cross_kv_high = high_noise_model.prepare_cross_kv(context_cfg_high)
|
||||
mx.eval(cross_kv_low, cross_kv_high)
|
||||
else:
|
||||
cross_kv = single_model.prepare_cross_kv(context_cfg)
|
||||
mx.eval(cross_kv)
|
||||
|
||||
# Precompute RoPE frequencies (grid sizes are constant across all steps)
|
||||
f_grid = t_latent // patch_size[0]
|
||||
h_grid = h_latent // patch_size[1]
|
||||
w_grid = w_latent // patch_size[2]
|
||||
cfg_grid_sizes = [(f_grid, h_grid, w_grid), (f_grid, h_grid, w_grid)]
|
||||
if is_dual:
|
||||
rope_cos_sin_low = low_noise_model.prepare_rope(cfg_grid_sizes)
|
||||
rope_cos_sin_high = high_noise_model.prepare_rope(cfg_grid_sizes)
|
||||
mx.eval(rope_cos_sin_low, rope_cos_sin_high)
|
||||
else:
|
||||
rope_cos_sin = ref_model.prepare_rope(cfg_grid_sizes)
|
||||
mx.eval(rope_cos_sin)
|
||||
|
||||
# Setup scheduler
|
||||
_schedulers = {
|
||||
"euler": FlowMatchEulerScheduler,
|
||||
@@ -312,9 +376,8 @@ def generate_video(
|
||||
# Generate initial noise
|
||||
noise = mx.random.normal(target_shape)
|
||||
|
||||
# I2V: blend first-frame latent into noise
|
||||
if is_i2v:
|
||||
# Broadcast z_img [z_dim, 1, H, W] across T for first-frame conditioning
|
||||
# I2V initialization: TI2V-5B blends image with noise, I2V-14B uses pure noise
|
||||
if is_i2v_mask_blend:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * noise
|
||||
else:
|
||||
latents = noise
|
||||
@@ -326,26 +389,32 @@ def generate_video(
|
||||
print(f"\n{Colors.GREEN}Denoising ({steps} steps)...{Colors.RESET}")
|
||||
t3 = time.time()
|
||||
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
timestep_val = sched.timesteps[i].item()
|
||||
# Pre-convert timesteps to Python list to avoid .item() sync each step
|
||||
timestep_list = sched.timesteps.tolist()
|
||||
|
||||
# Select model, guide scale, and cached K/V
|
||||
for i, t in enumerate(tqdm(range(steps), desc="Diffusion")):
|
||||
timestep_val = timestep_list[i]
|
||||
|
||||
# Select model, guide scale, cached K/V, and precomputed RoPE
|
||||
if is_dual:
|
||||
if timestep_val >= boundary:
|
||||
model = high_noise_model
|
||||
gs = guide_scale[1]
|
||||
kv = cross_kv_high
|
||||
rcs = rope_cos_sin_high
|
||||
else:
|
||||
model = low_noise_model
|
||||
gs = guide_scale[0]
|
||||
kv = cross_kv_low
|
||||
rcs = rope_cos_sin_low
|
||||
else:
|
||||
model = single_model
|
||||
gs = guide_scale if isinstance(guide_scale, (int, float)) else guide_scale[0]
|
||||
kv = cross_kv
|
||||
rcs = rope_cos_sin
|
||||
|
||||
# Build per-token timesteps for I2V (first-frame patches get t=0)
|
||||
if is_i2v:
|
||||
# Build per-token timesteps for TI2V-5B (first-frame patches get t=0)
|
||||
if is_i2v_mask_blend:
|
||||
t_tokens = i2v_mask_tokens * timestep_val # [1, L]
|
||||
# Pad to seq_len if needed
|
||||
pad_len = seq_len - t_tokens.shape[1]
|
||||
@@ -358,22 +427,31 @@ def generate_video(
|
||||
else:
|
||||
t_batch = mx.array([timestep_val, timestep_val])
|
||||
|
||||
# I2V-14B: pass y conditioning to model (same y for cond and uncond)
|
||||
y_arg = [y_i2v, y_i2v] if is_i2v_channel_concat else None
|
||||
|
||||
# CFG: batch cond + uncond into single B=2 forward pass
|
||||
ctx = context_cfg if not is_dual else (
|
||||
context_cfg_high if timestep_val >= boundary else context_cfg_low
|
||||
)
|
||||
preds = model(
|
||||
[latents, latents],
|
||||
t=t_batch,
|
||||
context=context_cfg,
|
||||
context=ctx,
|
||||
seq_len=seq_len,
|
||||
cross_kv_caches=kv,
|
||||
y=y_arg,
|
||||
rope_cos_sin=rcs,
|
||||
)
|
||||
noise_pred_cond, noise_pred_uncond = preds[0], preds[1]
|
||||
|
||||
# Classifier-free guidance + scheduler step
|
||||
noise_pred = noise_pred_uncond + gs * (noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
latents = sched.step(noise_pred[None], timestep_val, latents[None]).squeeze(0)
|
||||
|
||||
# I2V: re-apply mask to keep first frame frozen
|
||||
if is_i2v:
|
||||
# TI2V-5B: re-apply mask to keep first frame frozen
|
||||
if is_i2v_mask_blend:
|
||||
latents = (1.0 - i2v_mask) * z_img + i2v_mask * latents
|
||||
|
||||
# Release temporaries before eval to free memory for graph execution
|
||||
@@ -385,9 +463,11 @@ def generate_video(
|
||||
# Free transformer models and text embeddings
|
||||
if is_dual:
|
||||
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
|
||||
del context_cfg_low, context_cfg_high
|
||||
else:
|
||||
del single_model, cross_kv
|
||||
del model, kv, context, context_null, context_cfg
|
||||
del context_cfg
|
||||
del model, kv, context, context_null
|
||||
gc.collect(); mx.clear_cache()
|
||||
|
||||
# Load VAE and decode
|
||||
|
||||
Reference in New Issue
Block a user