feat(wan): Add chunked VAE encoding and TI2V-5B support
This commit is contained in:
@@ -29,13 +29,37 @@ from mlx_video.utils import Colors
|
||||
_build_i2v_mask = build_i2v_mask
|
||||
|
||||
|
||||
def _best_output_size(w, h, dw, dh, max_area):
|
||||
"""Compute the best output resolution that fits within max_area while
|
||||
preserving the input aspect ratio and satisfying alignment constraints.
|
||||
Matches the reference implementation's best_output_size().
|
||||
"""
|
||||
ratio = w / h
|
||||
ow = (max_area * ratio) ** 0.5
|
||||
oh = max_area / ow
|
||||
|
||||
# Option 1: process width first
|
||||
ow1 = int(ow // dw * dw)
|
||||
oh1 = int(max_area / ow1 // dh * dh)
|
||||
ratio1 = ow1 / oh1
|
||||
|
||||
# Option 2: process height first
|
||||
oh2 = int(oh // dh * dh)
|
||||
ow2 = int(max_area / oh2 // dw * dw)
|
||||
ratio2 = ow2 / oh2
|
||||
|
||||
if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
|
||||
return ow1, oh1
|
||||
return ow2, oh2
|
||||
|
||||
|
||||
def generate_video(
|
||||
model_dir: str,
|
||||
prompt: str,
|
||||
negative_prompt: str | None = None,
|
||||
image: str | None = None,
|
||||
width: int = 1280,
|
||||
height: int = 720,
|
||||
height: int = 704,
|
||||
num_frames: int = 81,
|
||||
steps: int = None,
|
||||
guide_scale: str | float | tuple = None,
|
||||
@@ -232,6 +256,15 @@ def generate_video(
|
||||
width = align_w
|
||||
print(f"{Colors.DIM} Aligned {old_w}x{old_h} → {width}x{height} (must be divisible by {align_w}x{align_h}){Colors.RESET}")
|
||||
|
||||
# Enforce max_area constraint (model-specific resolution limit)
|
||||
if config.max_area > 0 and height * width > config.max_area:
|
||||
old_h, old_w = height, width
|
||||
width, height = _best_output_size(width, height, align_w, align_h, config.max_area)
|
||||
print(
|
||||
f"{Colors.YELLOW} ⚠ Resolution {old_w}x{old_h} exceeds model's max area "
|
||||
f"({config.max_area:,}px). Adjusted → {width}x{height}{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Compute target latent shape
|
||||
z_dim = config.vae_z_dim
|
||||
t_latent = (gen_frames - 1) // vae_stride[0] + 1
|
||||
@@ -334,7 +367,7 @@ def generate_video(
|
||||
mx.eval(img_tensor)
|
||||
|
||||
vae_enc = load_vae_encoder(vae_path, config)
|
||||
z_img = vae_enc(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||
z_img = vae_enc.encode(img_tensor) # [1, 1, H_lat, W_lat, z_dim]
|
||||
mx.eval(z_img)
|
||||
z_img = z_img[0].transpose(3, 0, 1, 2) # [z_dim, 1, H_lat, W_lat]
|
||||
i2v_mask, i2v_mask_tokens = build_i2v_mask(target_shape, config.patch_size)
|
||||
@@ -658,8 +691,8 @@ def main():
|
||||
help="Negative prompt for CFG (default: official Chinese prompt from config)")
|
||||
parser.add_argument("--no-negative-prompt", action="store_true",
|
||||
help="Disable negative prompt (use empty string instead of config default)")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width")
|
||||
parser.add_argument("--height", type=int, default=720, help="Video height")
|
||||
parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280)")
|
||||
parser.add_argument("--height", type=int, default=704, help="Video height (default: 704; 720p models use 704)")
|
||||
parser.add_argument("--num-frames", type=int, default=81, help="Number of frames (must be 4n+1)")
|
||||
parser.add_argument("--steps", type=int, default=None, help="Number of diffusion steps (default: from config)")
|
||||
parser.add_argument("--guide-scale", type=str, default=None, help="Guidance scale: single float or low,high pair")
|
||||
|
||||
Reference in New Issue
Block a user