add vae tiling

This commit is contained in:
Prince Canuma
2026-01-17 07:51:54 +01:00
parent f607112407
commit e4cdbb7eab
6 changed files with 632 additions and 5 deletions

View File

@@ -27,6 +27,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder
from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state
@@ -207,6 +208,7 @@ def generate_video(
image: Optional[str] = None,
image_strength: float = 1.0,
image_frame_idx: int = 0,
tiling: str = "auto",
):
"""Generate video from text prompt, optionally conditioned on an image.
@@ -228,6 +230,14 @@ def generate_video(
image: Path to conditioning image for I2V (Image-to-Video)
image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original)
image_frame_idx: Frame index to condition (0 = first frame)
tiling: Tiling mode for VAE decoding. Options:
- "auto": Automatically determine based on video size (default)
- "none": Disable tiling
- "default": 512px spatial, 64 frame temporal
- "aggressive": 256px spatial, 32 frame temporal (lowest memory)
- "conservative": 768px spatial, 96 frame temporal (faster)
- "spatial": Spatial tiling only
- "temporal": Temporal tiling only
"""
start_time = time.time()
@@ -435,9 +445,36 @@ def generate_video(
del transformer
mx.clear_cache()
# Decode to video
# Decode to video with tiling
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
video = vae_decoder(latents)
# Select tiling configuration
if tiling == "none":
tiling_config = None
elif tiling == "auto":
tiling_config = TilingConfig.auto(height, width, num_frames)
elif tiling == "default":
tiling_config = TilingConfig.default()
elif tiling == "aggressive":
tiling_config = TilingConfig.aggressive()
elif tiling == "conservative":
tiling_config = TilingConfig.conservative()
elif tiling == "spatial":
tiling_config = TilingConfig.spatial_only()
elif tiling == "temporal":
tiling_config = TilingConfig.temporal_only()
else:
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
tiling_config = TilingConfig.auto(height, width, num_frames)
if tiling_config is not None:
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose)
else:
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
video = vae_decoder(latents)
mx.eval(video)
mx.clear_cache()
@@ -594,6 +631,15 @@ Examples:
default=0,
help="Frame index to condition for I2V (0 = first frame, default: 0)"
)
parser.add_argument(
"--tiling",
type=str,
default="auto",
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
help="Tiling mode for VAE decoding (default: auto). "
"auto=based on video size, none=disabled, default=512px/64f, "
"aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only"
)
args = parser.parse_args()
generate_video(