fix tiling, rope precision and weights

This commit is contained in:
Prince Canuma
2026-03-15 22:58:55 +01:00
parent ebcd5dd4e4
commit cecd68197c
5 changed files with 86 additions and 149 deletions

View File

@@ -160,6 +160,9 @@ class TilingConfig:
) -> Optional["TilingConfig"]:
"""Automatically determine tiling config based on video dimensions.
Uses PyTorch's default tiling (512px spatial, 64f temporal) which provides
enough context for CausalConv3d and sufficient overlap for clean blending.
Args:
height: Video height in pixels
width: Video width in pixels
@@ -176,37 +179,17 @@ class TilingConfig:
if not needs_spatial and not needs_temporal:
return None
# Estimate memory requirement (rough heuristic)
# Output size in bytes (float32): B * 3 * F * H * W * 4
estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3)
# For very large videos, use aggressive tiling
if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100):
return cls.aggressive()
# Use the same defaults as PyTorch (512px spatial, 64f temporal).
# Smaller tiles cause quality degradation because CausalConv3d needs
# sufficient temporal context and overlap for clean blending.
spatial_config = None
temporal_config = None
if needs_spatial:
# Choose tile size based on resolution
max_dim = max(height, width)
if max_dim > 1024:
tile_size = 384 # Smaller tiles for very large resolutions
elif max_dim > 768:
tile_size = 512
else:
tile_size = 384
spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64)
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
if needs_temporal:
# Choose tile size based on frame count
if num_frames > 200:
tile_size, overlap = 32, 8 # Aggressive for very long videos
elif num_frames > 100:
tile_size, overlap = 48, 16
else:
tile_size, overlap = 64, 24
temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap)
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
return cls(spatial_config=spatial_config, temporal_config=temporal_config)