add vae tiling
This commit is contained in:
@@ -30,6 +30,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_w
|
||||
from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig
|
||||
from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents
|
||||
from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning
|
||||
from mlx_video.conditioning.latent import LatentState, apply_denoise_mask
|
||||
@@ -363,6 +364,7 @@ def generate_video_with_audio(
|
||||
image: Optional[str] = None,
|
||||
image_strength: float = 1.0,
|
||||
image_frame_idx: int = 0,
|
||||
tiling: str = "auto",
|
||||
):
|
||||
"""Generate video with synchronized audio from text prompt, optionally conditioned on an image.
|
||||
|
||||
@@ -384,6 +386,7 @@ def generate_video_with_audio(
|
||||
image: Path to conditioning image for I2V
|
||||
image_strength: Conditioning strength (1.0 = full denoise)
|
||||
image_frame_idx: Frame index to condition (0 = first frame)
|
||||
tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
@@ -623,9 +626,36 @@ def generate_video_with_audio(
|
||||
del transformer
|
||||
mx.clear_cache()
|
||||
|
||||
# Decode video
|
||||
# Decode video with tiling
|
||||
print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}")
|
||||
video = vae_decoder(video_latents)
|
||||
|
||||
# Select tiling configuration
|
||||
if tiling == "none":
|
||||
tiling_config = None
|
||||
elif tiling == "auto":
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
elif tiling == "default":
|
||||
tiling_config = TilingConfig.default()
|
||||
elif tiling == "aggressive":
|
||||
tiling_config = TilingConfig.aggressive()
|
||||
elif tiling == "conservative":
|
||||
tiling_config = TilingConfig.conservative()
|
||||
elif tiling == "spatial":
|
||||
tiling_config = TilingConfig.spatial_only()
|
||||
elif tiling == "temporal":
|
||||
tiling_config = TilingConfig.temporal_only()
|
||||
else:
|
||||
print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}")
|
||||
tiling_config = TilingConfig.auto(height, width, num_frames)
|
||||
|
||||
if tiling_config is not None:
|
||||
spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none"
|
||||
temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none"
|
||||
print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}")
|
||||
video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose)
|
||||
else:
|
||||
print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}")
|
||||
video = vae_decoder(video_latents)
|
||||
mx.eval(video)
|
||||
|
||||
# Convert video to uint8 frames
|
||||
@@ -762,6 +792,11 @@ Examples:
|
||||
help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)")
|
||||
parser.add_argument("--image-frame-idx", type=int, default=0,
|
||||
help="Frame index to condition for I2V (0 = first frame, default: 0)")
|
||||
parser.add_argument("--tiling", type=str, default="auto",
|
||||
choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"],
|
||||
help="Tiling mode for VAE decoding (default: auto). "
|
||||
"auto=based on size, none=disabled, default=512px/64f, "
|
||||
"aggressive=256px/32f (lowest memory), conservative=768px/96f")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -783,6 +818,7 @@ Examples:
|
||||
image=args.image,
|
||||
image_strength=args.image_strength,
|
||||
image_frame_idx=args.image_frame_idx,
|
||||
tiling=args.tiling,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user