format
This commit is contained in:
@@ -55,7 +55,9 @@ def compute_trapezoidal_mask_1d(
|
||||
# Apply right ramp (fade out)
|
||||
if ramp_right > 0:
|
||||
# Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1]
|
||||
fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)]
|
||||
fade_out = [
|
||||
(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)
|
||||
]
|
||||
for i in range(ramp_right):
|
||||
mask[length - ramp_right + i] *= fade_out[i]
|
||||
|
||||
@@ -71,11 +73,17 @@ class SpatialTilingConfig:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_pixels < 64:
|
||||
raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}"
|
||||
)
|
||||
if self.tile_size_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}"
|
||||
)
|
||||
if self.tile_overlap_in_pixels % 32 != 0:
|
||||
raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}")
|
||||
raise ValueError(
|
||||
f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}"
|
||||
)
|
||||
if self.tile_overlap_in_pixels >= self.tile_size_in_pixels:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}"
|
||||
@@ -91,11 +99,17 @@ class TemporalTilingConfig:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.tile_size_in_frames < 16:
|
||||
raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}"
|
||||
)
|
||||
if self.tile_size_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}"
|
||||
)
|
||||
if self.tile_overlap_in_frames % 8 != 0:
|
||||
raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}")
|
||||
raise ValueError(
|
||||
f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}"
|
||||
)
|
||||
if self.tile_overlap_in_frames >= self.tile_size_in_frames:
|
||||
raise ValueError(
|
||||
f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}"
|
||||
@@ -113,15 +127,21 @@ class TilingConfig:
|
||||
def default(cls) -> "TilingConfig":
|
||||
"""Default tiling: 512px spatial, 64 frame temporal."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig":
|
||||
"""Spatial tiling only (for short videos with large resolution)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap
|
||||
),
|
||||
temporal_config=None,
|
||||
)
|
||||
|
||||
@@ -130,23 +150,33 @@ class TilingConfig:
|
||||
"""Temporal tiling only (for long videos with small resolution)."""
|
||||
return cls(
|
||||
spatial_config=None,
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def aggressive(cls) -> "TilingConfig":
|
||||
"""Aggressive tiling for very large videos (smaller tiles, much lower memory)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=256, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=32, tile_overlap_in_frames=8
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def conservative(cls) -> "TilingConfig":
|
||||
"""Conservative tiling (larger tiles, less memory savings but faster)."""
|
||||
return cls(
|
||||
spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64),
|
||||
temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24),
|
||||
spatial_config=SpatialTilingConfig(
|
||||
tile_size_in_pixels=768, tile_overlap_in_pixels=64
|
||||
),
|
||||
temporal_config=TemporalTilingConfig(
|
||||
tile_size_in_frames=96, tile_overlap_in_frames=24
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -186,10 +216,14 @@ class TilingConfig:
|
||||
temporal_config = None
|
||||
|
||||
if needs_spatial:
|
||||
spatial_config = SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64)
|
||||
spatial_config = SpatialTilingConfig(
|
||||
tile_size_in_pixels=512, tile_overlap_in_pixels=64
|
||||
)
|
||||
|
||||
if needs_temporal:
|
||||
temporal_config = TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24)
|
||||
temporal_config = TemporalTilingConfig(
|
||||
tile_size_in_frames=64, tile_overlap_in_frames=24
|
||||
)
|
||||
|
||||
return cls(spatial_config=spatial_config, temporal_config=temporal_config)
|
||||
|
||||
@@ -197,16 +231,21 @@ class TilingConfig:
|
||||
@dataclass
|
||||
class DimensionIntervals:
|
||||
"""Intervals for splitting a single dimension."""
|
||||
|
||||
starts: List[int]
|
||||
ends: List[int]
|
||||
left_ramps: List[int]
|
||||
right_ramps: List[int]
|
||||
|
||||
|
||||
def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
def split_in_spatial(
|
||||
size: int, overlap: int, dimension_size: int
|
||||
) -> DimensionIntervals:
|
||||
"""Split a spatial dimension into intervals."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
return DimensionIntervals(
|
||||
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||
)
|
||||
|
||||
amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap)
|
||||
starts = [i * (size - overlap) for i in range(amount)]
|
||||
@@ -215,13 +254,19 @@ def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionI
|
||||
left_ramps = [0] + [overlap] * (amount - 1)
|
||||
right_ramps = [overlap] * (amount - 1) + [0]
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps)
|
||||
return DimensionIntervals(
|
||||
starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps
|
||||
)
|
||||
|
||||
|
||||
def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals:
|
||||
def split_in_temporal(
|
||||
size: int, overlap: int, dimension_size: int
|
||||
) -> DimensionIntervals:
|
||||
"""Split a temporal dimension into intervals with causal adjustment."""
|
||||
if dimension_size <= size:
|
||||
return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0])
|
||||
return DimensionIntervals(
|
||||
starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]
|
||||
)
|
||||
|
||||
# Start with spatial split
|
||||
intervals = split_in_spatial(size, overlap, dimension_size)
|
||||
@@ -234,28 +279,41 @@ def split_in_temporal(size: int, overlap: int, dimension_size: int) -> Dimension
|
||||
starts[i] = starts[i] - 1
|
||||
left_ramps[i] = left_ramps[i] + 1
|
||||
|
||||
return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps)
|
||||
return DimensionIntervals(
|
||||
starts=starts,
|
||||
ends=intervals.ends,
|
||||
left_ramps=left_ramps,
|
||||
right_ramps=intervals.right_ramps,
|
||||
)
|
||||
|
||||
|
||||
def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
def map_temporal_slice(
|
||||
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||
) -> Tuple[slice, mx.array]:
|
||||
"""Map temporal latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = 1 + (end - 1) * scale
|
||||
left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True)
|
||||
mask = compute_trapezoidal_mask_1d(
|
||||
stop - start, left_ramp_scaled, right_ramp_scaled, True
|
||||
)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]:
|
||||
def map_spatial_slice(
|
||||
begin: int, end: int, left_ramp: int, right_ramp: int, scale: int
|
||||
) -> Tuple[slice, mx.array]:
|
||||
"""Map spatial latent interval to output coordinates and mask."""
|
||||
start = begin * scale
|
||||
stop = end * scale
|
||||
left_ramp_scaled = left_ramp * scale
|
||||
right_ramp_scaled = right_ramp * scale
|
||||
|
||||
mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False)
|
||||
mask = compute_trapezoidal_mask_1d(
|
||||
stop - start, left_ramp_scaled, right_ramp_scaled, False
|
||||
)
|
||||
return slice(start, stop), mask
|
||||
|
||||
|
||||
@@ -315,7 +373,9 @@ def decode_with_tiling(
|
||||
temporal_overlap = 0
|
||||
|
||||
# Compute intervals for each dimension
|
||||
temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent)
|
||||
temporal_intervals = split_in_temporal(
|
||||
temporal_tile_size, temporal_overlap, f_latent
|
||||
)
|
||||
height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent)
|
||||
width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent)
|
||||
|
||||
@@ -338,7 +398,9 @@ def decode_with_tiling(
|
||||
t_right = temporal_intervals.right_ramps[t_idx]
|
||||
|
||||
# Map temporal coordinates
|
||||
out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale)
|
||||
out_t_slice, t_mask = map_temporal_slice(
|
||||
t_start, t_end, t_left, t_right, temporal_scale
|
||||
)
|
||||
|
||||
for h_idx in range(num_h_tiles):
|
||||
h_start = height_intervals.starts[h_idx]
|
||||
@@ -347,7 +409,9 @@ def decode_with_tiling(
|
||||
h_right = height_intervals.right_ramps[h_idx]
|
||||
|
||||
# Map height coordinates
|
||||
out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale)
|
||||
out_h_slice, h_mask = map_spatial_slice(
|
||||
h_start, h_end, h_left, h_right, spatial_scale
|
||||
)
|
||||
|
||||
for w_idx in range(num_w_tiles):
|
||||
w_start = width_intervals.starts[w_idx]
|
||||
@@ -356,13 +420,23 @@ def decode_with_tiling(
|
||||
w_right = width_intervals.right_ramps[w_idx]
|
||||
|
||||
# Map width coordinates
|
||||
out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale)
|
||||
out_w_slice, w_mask = map_spatial_slice(
|
||||
w_start, w_end, w_left, w_right, spatial_scale
|
||||
)
|
||||
|
||||
# Extract tile latents (small slice)
|
||||
tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end]
|
||||
tile_latents = latents[
|
||||
:, :, t_start:t_end, h_start:h_end, w_start:w_end
|
||||
]
|
||||
|
||||
# Decode tile
|
||||
tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False, chunked_conv=chunked_conv)
|
||||
tile_output = decoder_fn(
|
||||
tile_latents,
|
||||
causal=causal,
|
||||
timestep=timestep,
|
||||
debug=False,
|
||||
chunked_conv=chunked_conv,
|
||||
)
|
||||
mx.eval(tile_output)
|
||||
|
||||
# Clear tile_latents reference
|
||||
@@ -385,13 +459,15 @@ def decode_with_tiling(
|
||||
w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask
|
||||
|
||||
blend_mask = (
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1) *
|
||||
h_mask_slice.reshape(1, 1, 1, -1, 1) *
|
||||
w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
t_mask_slice.reshape(1, 1, -1, 1, 1)
|
||||
* h_mask_slice.reshape(1, 1, 1, -1, 1)
|
||||
* w_mask_slice.reshape(1, 1, 1, 1, -1)
|
||||
)
|
||||
|
||||
# Slice tile output to match
|
||||
tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32)
|
||||
tile_output_slice = tile_output[
|
||||
:, :, :actual_t, :actual_h, :actual_w
|
||||
].astype(mx.float32)
|
||||
|
||||
# Clear full tile_output
|
||||
del tile_output
|
||||
@@ -409,11 +485,37 @@ def decode_with_tiling(
|
||||
weighted_tile = tile_output_slice * blend_mask
|
||||
|
||||
# Update output using slice assignment
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
output[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ weighted_tile
|
||||
)
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = (
|
||||
weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
] = (
|
||||
weights[
|
||||
:,
|
||||
:,
|
||||
t_out_start:t_out_end,
|
||||
h_out_start:h_out_end,
|
||||
w_out_start:w_out_end,
|
||||
]
|
||||
+ blend_mask
|
||||
)
|
||||
|
||||
# Force evaluation to free memory
|
||||
@@ -445,10 +547,12 @@ def decode_with_tiling(
|
||||
if next_tile_start_latent == 0:
|
||||
next_tile_start_out = 0
|
||||
else:
|
||||
next_tile_start_out = 1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
next_tile_start_out = (
|
||||
1 + (next_tile_start_latent - 1) * temporal_scale
|
||||
)
|
||||
|
||||
# We need to track how many frames we've already emitted
|
||||
if not hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if not hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
decode_with_tiling._emitted_frames = 0
|
||||
emitted = decode_with_tiling._emitted_frames
|
||||
|
||||
@@ -456,7 +560,10 @@ def decode_with_tiling(
|
||||
# Normalize and emit frames [emitted, next_tile_start_out)
|
||||
finalized_weights = weights[:, :, emitted:next_tile_start_out, :, :]
|
||||
finalized_weights = mx.maximum(finalized_weights, 1e-8)
|
||||
finalized_output = output[:, :, emitted:next_tile_start_out, :, :] / finalized_weights
|
||||
finalized_output = (
|
||||
output[:, :, emitted:next_tile_start_out, :, :]
|
||||
/ finalized_weights
|
||||
)
|
||||
finalized_output = finalized_output.astype(latents.dtype)
|
||||
mx.eval(finalized_output)
|
||||
|
||||
@@ -473,7 +580,7 @@ def decode_with_tiling(
|
||||
|
||||
# Emit remaining frames if callback provided
|
||||
if on_frames_ready is not None:
|
||||
emitted = getattr(decode_with_tiling, '_emitted_frames', 0)
|
||||
emitted = getattr(decode_with_tiling, "_emitted_frames", 0)
|
||||
if emitted < out_f:
|
||||
remaining_output = output[:, :, emitted:, :, :].astype(latents.dtype)
|
||||
mx.eval(remaining_output)
|
||||
@@ -481,7 +588,7 @@ def decode_with_tiling(
|
||||
del remaining_output
|
||||
|
||||
# Reset emitted frames counter for next call
|
||||
if hasattr(decode_with_tiling, '_emitted_frames'):
|
||||
if hasattr(decode_with_tiling, "_emitted_frames"):
|
||||
del decode_with_tiling._emitted_frames
|
||||
|
||||
# Clean up weights
|
||||
|
||||
Reference in New Issue
Block a user