This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
# Apply right ramp (fade out)
if ramp_right > 0:
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
fade_out = [
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
]
for i in range(ramp_right):
mask[length - ramp_right + i] *= fade_out[i]
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
def __post_init__(self) -> None:
if self.tile_size_in_pixels < 64:
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
raise ValueError(
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
)
if self.tile_size_in_pixels % 32 != 0:
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
raise ValueError(
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
)
if self.tile_overlap_in_pixels % 32 != 0:
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
raise ValueError(
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
)
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
def __post_init__(self) -> None:
if self.tile_size_in_frames < 16:
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
raise ValueError(
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
)
if self.tile_size_in_frames % 8 != 0:
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
raise ValueError(
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
)
if self.tile_overlap_in_frames % 8 != 0:
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
raise ValueError(
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
)
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
raise ValueError(
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
@@ -113,15 +127,21 @@ class TilingConfig:
def default(cls) -> "TilingConfig":
"""Default tiling: 512px spatial, 64 frame temporal."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
),
)
@classmethod
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
"""Spatial tiling only (for short videos with large resolution)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
),
temporal_config=None,
)
@@ -130,23 +150,33 @@ class TilingConfig:
"""Temporal tiling only (for long videos with small resolution)."""
return cls(
spatial_config=None,
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
),
)
@classmethod
def aggressive(cls) -> "TilingConfig":
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=256, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=32, tile_overlap_in_frames=8
),
)
@classmethod
def conservative(cls) -> "TilingConfig":
"""Conservative tiling (larger tiles, less memory savings but faster)."""
return cls(
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
spatial_config=SpatialTilingConfig(
tile_size_in_pixels=768, tile_overlap_in_pixels=64
),
temporal_config=TemporalTilingConfig(
tile_size_in_frames=96, tile_overlap_in_frames=24
),
)
@classmethod
@@ -186,10 +216,14 @@ class TilingConfig:
temporal_config = None
if needs_spatial:
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
spatial_config = SpatialTilingConfig(
tile_size_in_pixels=512, tile_overlap_in_pixels=64
)
if needs_temporal:
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
temporal_config = TemporalTilingConfig(
tile_size_in_frames=64, tile_overlap_in_frames=24
)
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
@@ -197,16 +231,21 @@ class TilingConfig:
@dataclass
class DimensionIntervals:
"""Intervals for splitting a single dimension."""
starts: List[int]
ends: List[int]
left_ramps: List[int]
right_ramps: List[int]
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
def split_in_spatial(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a spatial dimension into intervals."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
starts = [i * (size - overlap) for i in range(amount)]
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
left_ramps = [0] + [overlap] * (amount - 1)
right_ramps = [overlap] * (amount - 1) + [0]
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
return DimensionIntervals(
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
)
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
def split_in_temporal(
size: int, overlap: int, dimension_size: int
) -> DimensionIntervals:
"""Split a temporal dimension into intervals with causal adjustment."""
if dimension_size <= size:
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
return DimensionIntervals(
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
)
# Start with spatial split
intervals = split_in_spatial(size, overlap, dimension_size)
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
starts[i] = starts[i] - 1
left_ramps[i] = left_ramps[i] + 1
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
return DimensionIntervals(
starts=starts,
ends=intervals.ends,
left_ramps=left_ramps,
right_ramps=intervals.right_ramps,
)
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
def map_temporal_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map temporal latent interval to output coordinates and mask."""
start = begin * scale
stop = 1 + (end - 1) * scale
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, True
)
return slice(start, stop), mask
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
def map_spatial_slice(
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
) -> Tuple[slice, mx.array]:
"""Map spatial latent interval to output coordinates and mask."""
start = begin * scale
stop = end * scale
left_ramp_scaled = left_ramp * scale
right_ramp_scaled = right_ramp * scale
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
mask = compute_trapezoidal_mask_1d(
stop - start, left_ramp_scaled, right_ramp_scaled, False
)
return slice(start, stop), mask
@@ -315,7 +373,9 @@ def decode_with_tiling(
temporal_overlap = 0
# Compute intervals for each dimension
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
temporal_intervals = split_in_temporal(
temporal_tile_size, temporal_overlap, f_latent
)
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
@@ -338,7 +398,9 @@ def decode_with_tiling(
t_right = temporal_intervals.right_ramps[t_idx]
# Map temporal coordinates
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
out_t_slice, t_mask = map_temporal_slice(
t_start, t_end, t_left, t_right, temporal_scale
)
for h_idx in range(num_h_tiles):
h_start = height_intervals.starts[h_idx]
@@ -347,7 +409,9 @@ def decode_with_tiling(
h_right = height_intervals.right_ramps[h_idx]
# Map height coordinates
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
out_h_slice, h_mask = map_spatial_slice(
h_start, h_end, h_left, h_right, spatial_scale
)
for w_idx in range(num_w_tiles):
w_start = width_intervals.starts[w_idx]
@@ -356,13 +420,23 @@ def decode_with_tiling(
w_right = width_intervals.right_ramps[w_idx]
# Map width coordinates
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
out_w_slice, w_mask = map_spatial_slice(
w_start, w_end, w_left, w_right, spatial_scale
)
# Extract tile latents (small slice)
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
tile_latents = latents[
:, :, t_start:t_end, h_start:h_end, w_start:w_end
]
# Decode tile
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
tile_output = decoder_fn(
tile_latents,
causal=causal,
timestep=timestep,
debug=False,
chunked_conv=chunked_conv,
)
mx.eval(tile_output)
# Clear tile_latents reference
@@ -385,13 +459,15 @@ def decode_with_tiling(
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
blend_mask = (
t_mask_slice.reshape(1, 1, -1, 1, 1) *
h_mask_slice.reshape(1, 1, 1, -1, 1) *
w_mask_slice.reshape(1, 1, 1, 1, -1)
t_mask_slice.reshape(1, 1, -1, 1, 1)
* h_mask_slice.reshape(1, 1, 1, -1, 1)
* w_mask_slice.reshape(1, 1, 1, 1, -1)
)
# Slice tile output to match
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
tile_output_slice = tile_output[
:, :, :actual_t, :actual_h, :actual_w
].astype(mx.float32)
# Clear full tile_output
del tile_output
@@ -409,11 +485,37 @@ def decode_with_tiling(
weighted_tile = tile_output_slice * blend_mask
# Update output using slice assignment
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
output[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ weighted_tile
)
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
] = (
weights[
:,
:,
t_out_start:t_out_end,
h_out_start:h_out_end,
w_out_start:w_out_end,
]
+ blend_mask
)
# Force evaluation to free memory
@@ -445,10 +547,12 @@ def decode_with_tiling(
if next_tile_start_latent == 0:
next_tile_start_out = 0
else:
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
next_tile_start_out = (
1 + (next_tile_start_latent - 1) * temporal_scale
)
# We need to track how many frames we've already emitted
if not hasattr(decode_with_tiling, '_emitted_frames'):
if not hasattr(decode_with_tiling, "_emitted_frames"):
decode_with_tiling._emitted_frames = 0
emitted = decode_with_tiling._emitted_frames
@@ -456,7 +560,10 @@ def decode_with_tiling(
# Normalize and emit frames [emitted, next_tile_start_out)
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
finalized_weights = mx.maximum(finalized_weights, 1e-8)
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
finalized_output = (
output[:, :, emitted:next_tile_start_out, :, :]
/ finalized_weights
)
finalized_output = finalized_output.astype(latents.dtype)
mx.eval(finalized_output)
@@ -473,7 +580,7 @@ def decode_with_tiling(
# Emit remaining frames if callback provided
if on_frames_ready is not None:
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
if emitted < out_f:
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
mx.eval(remaining_output)
@@ -481,7 +588,7 @@ def decode_with_tiling(
del remaining_output
# Reset emitted frames counter for next call
if hasattr(decode_with_tiling, '_emitted_frames'):
if hasattr(decode_with_tiling, "_emitted_frames"):
del decode_with_tiling._emitted_frames
# Clean up weights