Add VAE tiling

This commit is contained in:
Prince Canuma
2026-01-17 07:51:54 +01:00
parent f607112407
commit e4cdbb7eab
6 changed files with 632 additions and 5 deletions

View File

@@ -30,6 +30,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_w
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
@@ -363,6 +364,7 @@ def generate_video_with_audio(
image: Optional[str] = None,
image_strength: float = 1.0,
image_frame_idx: int = 0,
tiling: str = "auto",
):
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image.
@@ -384,6 +386,7 @@ def generate_video_with_audio(
image: Path to conditioning image for I2V
image_strength: Conditioning strength (1.0 = full denoise)
image_frame_idx: Frame index to condition (0 = first frame)
tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal)
"""
start_time = time.time()
@@ -623,9 +626,36 @@ def generate_video_with_audio(
del transformer
mx.clear_cache()
# Decode video
# Decode video (tiled VAE decoding unless tiling is disabled)
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
video = vae_decoder(video_latents)
# Select tiling configuration
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(video_latents)
mx.eval(video)
# Convert video to uint8 frames
@@ -762,6 +792,11 @@ Examples:
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)")
parser.add_argument("--image-frame-idx", type=int, default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)")
parser.add_argument("--tiling", type=str, default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="Tiling mode for VAE decoding (default: auto). "
"auto=based on size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f")
args = parser.parse_args()
@@ -783,6 +818,7 @@ Examples:
image=args.image,
image_strength=args.image_strength,
image_frame_idx=args.image_frame_idx,
tiling=args.tiling,
)