fix(wan): Fix scheduler sigma schedule and add debug flags
This commit is contained in:
@@ -72,6 +72,8 @@ def generate_video(
|
||||
loras_low: list | None = None,
|
||||
tiling: str = "auto",
|
||||
no_compile: bool = False,
|
||||
trim_first_frames: int = 0,
|
||||
debug_latents: bool = False,
|
||||
):
|
||||
|
||||
"""Generate video using Wan pipeline (supports T2V and I2V).
|
||||
@@ -100,6 +102,12 @@ def generate_video(
|
||||
- "spatial": Spatial tiling only
|
||||
- "temporal": Temporal tiling only
|
||||
no_compile: If True, skip mx.compile on models (useful for debugging)
|
||||
trim_first_frames: Number of extra temporal latent positions to generate
|
||||
and discard from the start. Each position = 4 pixel frames. Use 1
|
||||
to fix first-frame artifacts on 14B models (generates 4 extra frames,
|
||||
discards first 4). Use 2 for more aggressive trimming. Default: 0.
|
||||
debug_latents: If True, print per-temporal-position latent statistics
|
||||
after denoising for diagnosing first-frame artifacts.
|
||||
|
||||
"""
|
||||
import json
|
||||
@@ -207,6 +215,9 @@ def generate_video(
|
||||
assert (num_frames - 1) % 4 == 0, f"num_frames must be 4n+1, got {num_frames}"
|
||||
|
||||
gen_frames = num_frames
|
||||
if trim_first_frames > 0:
|
||||
gen_frames = num_frames + trim_first_frames * 4
|
||||
print(f"{Colors.DIM} Trim: generating {gen_frames} frames, will discard first {trim_first_frames * 4}{Colors.RESET}")
|
||||
|
||||
version_str = f"Wan{config.model_version}"
|
||||
mode_str = "dual-model" if is_dual else "single-model"
|
||||
@@ -595,6 +606,22 @@ def generate_video(
|
||||
|
||||
print(f"{Colors.DIM} Denoising: {time.time() - t3:.1f}s{Colors.RESET}")
|
||||
|
||||
# Diagnostic: per-temporal-position latent statistics
|
||||
if debug_latents:
|
||||
lat_np = np.array(latents) # [C, T, H, W]
|
||||
n_t = lat_np.shape[1]
|
||||
print(f"\n{Colors.CYAN} Latent diagnostics (shape {lat_np.shape}):{Colors.RESET}")
|
||||
print(f" {'Pos':>4s} {'Mean':>8s} {'Std':>8s} {'Min':>8s} {'Max':>8s} {'AbsMean':>8s}")
|
||||
for t_pos in range(min(n_t, 8)):
|
||||
frame = lat_np[:, t_pos, :, :]
|
||||
print(f" {t_pos:4d} {frame.mean():8.4f} {frame.std():8.4f} "
|
||||
f"{frame.min():8.4f} {frame.max():8.4f} {np.abs(frame).mean():8.4f}")
|
||||
if n_t > 8:
|
||||
interior = lat_np[:, 4:, :, :]
|
||||
print(f" {'4+':>4s} {interior.mean():8.4f} {interior.std():8.4f} "
|
||||
f"{interior.min():8.4f} {interior.max():8.4f} {np.abs(interior).mean():8.4f}")
|
||||
print()
|
||||
|
||||
# Free transformer models and text embeddings
|
||||
if is_dual:
|
||||
del low_noise_model, high_noise_model, cross_kv_low, cross_kv_high
|
||||
@@ -621,6 +648,9 @@ def generate_video(
|
||||
|
||||
is_wan22_vae = config.vae_z_dim == 48
|
||||
|
||||
# Temporal extend: prepend reflected latent frames to the VAE input so that
|
||||
# the CausalConv3d zero-padding artifacts fall on the prefix (which we crop).
|
||||
# This gives the first real frame a full temporal receptive field of real data.
|
||||
# Select tiling configuration
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||
|
||||
@@ -676,6 +706,12 @@ def generate_video(
|
||||
video = np.clip(video * 255.0, 0, 255).astype(np.uint8)
|
||||
video = video.transpose(1, 2, 3, 0) # [T, H, W, 3]
|
||||
|
||||
# Trim first N temporal chunks if requested (avoids first-frame artifacts)
|
||||
if trim_first_frames > 0:
|
||||
trim_pixels = trim_first_frames * 4
|
||||
video = video[trim_pixels:]
|
||||
print(f"{Colors.DIM} Trimmed first {trim_pixels} frames ({video.shape[0]} remaining){Colors.RESET}")
|
||||
|
||||
save_video(video, output_path, fps=config.sample_fps)
|
||||
print(f"\n{Colors.GREEN}✓ Video saved to {output_path}{Colors.RESET}")
|
||||
print(f"{Colors.DIM} Total time: {time.time() - t1:.1f}s{Colors.RESET}")
|
||||
@@ -727,6 +763,16 @@ def main():
|
||||
"--no-compile", action="store_true",
|
||||
help="Disable mx.compile on models (for debugging)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trim-first-frames", type=int, default=0, metavar="N",
|
||||
help="Generate N extra temporal chunks (N×4 frames) and discard them from the start. "
|
||||
"Fixes first-frame color/lighting artifacts on 14B models. Try 1 first (4 frames). "
|
||||
"Default: 0 (disabled)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug-latents", action="store_true",
|
||||
help="Print per-temporal-position latent statistics after denoising (diagnostic)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -766,6 +812,8 @@ def main():
|
||||
loras_low=_parse_lora_args(args.lora_low),
|
||||
tiling=args.tiling,
|
||||
no_compile=args.no_compile,
|
||||
trim_first_frames=args.trim_first_frames,
|
||||
debug_latents=args.debug_latents,
|
||||
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user