From e4cdbb7eab52ec17af0b6245a42b078c7312d59a Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 07:51:54 +0100 Subject: [PATCH 1/4] add vae tiling --- mlx_video/generate.py | 50 ++- mlx_video/generate_av.py | 40 +- mlx_video/models/ltx/text_encoder.py | 2 +- mlx_video/models/ltx/video_vae/__init__.py | 5 + mlx_video/models/ltx/video_vae/decoder.py | 70 +++ mlx_video/models/ltx/video_vae/tiling.py | 470 +++++++++++++++++++++ 6 files changed, 632 insertions(+), 5 deletions(-) create mode 100644 mlx_video/models/ltx/video_vae/tiling.py diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 8a6c5d5..55f0965 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -27,6 +27,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_vae_encoder from mlx_video.utils import to_denoised, load_image, prepare_image_for_encoding from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder +from mlx_video.models.ltx.video_vae.tiling import TilingConfig from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning from mlx_video.conditioning.latent import LatentState, create_initial_state, apply_denoise_mask, add_noise_with_state @@ -207,6 +208,7 @@ def generate_video( image: Optional[str] = None, image_strength: float = 1.0, image_frame_idx: int = 0, + tiling: str = "auto", ): """Generate video from text prompt, optionally conditioned on an image. @@ -228,6 +230,14 @@ def generate_video( image: Path to conditioning image for I2V (Image-to-Video) image_strength: Conditioning strength (1.0 = full denoise, 0.0 = keep original) image_frame_idx: Frame index to condition (0 = first frame) + tiling: Tiling mode for VAE decoding. Options: + - "auto": Automatically determine based on video size (default) + - "none": Disable tiling + - "default": 512px spatial, 64 frame temporal + - "aggressive": 256px spatial, 32 frame temporal (lowest memory) + - "conservative": 768px spatial, 96 frame temporal (faster) + - "spatial": Spatial tiling only + - "temporal": Temporal tiling only """ start_time = time.time() @@ -435,9 +445,36 @@ def generate_video( del transformer mx.clear_cache() - # Decode to video + # Decode to video with tiling print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") - video = vae_decoder(latents) + + # Select tiling configuration + if tiling == "none": + tiling_config = None + elif tiling == "auto": + tiling_config = TilingConfig.auto(height, width, num_frames) + elif tiling == "default": + tiling_config = TilingConfig.default() + elif tiling == "aggressive": + tiling_config = TilingConfig.aggressive() + elif tiling == "conservative": + tiling_config = TilingConfig.conservative() + elif tiling == "spatial": + tiling_config = TilingConfig.spatial_only() + elif tiling == "temporal": + tiling_config = TilingConfig.temporal_only() + else: + print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") + tiling_config = TilingConfig.auto(height, width, num_frames) + + if tiling_config is not None: + spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" + temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" + print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") + video = vae_decoder.decode_tiled(latents, tiling_config=tiling_config, debug=verbose) + else: + print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") + video = vae_decoder(latents) mx.eval(video) mx.clear_cache() @@ -594,6 +631,15 @@ Examples: default=0, help="Frame index to condition for I2V (0 = first frame, default: 0)" ) + parser.add_argument( + "--tiling", + type=str, + default="auto", + choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], + help="Tiling mode for VAE decoding (default: auto). " + "auto=based on video size, none=disabled, default=512px/64f, " + "aggressive=256px/32f (lowest memory), conservative=768px/96f, spatial=spatial only, temporal=temporal only" + ) args = parser.parse_args() generate_video( diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py index a481c23..b7cba4a 100644 --- a/mlx_video/generate_av.py +++ b/mlx_video/generate_av.py @@ -30,6 +30,7 @@ from mlx_video.convert import sanitize_transformer_weights, sanitize_audio_vae_w from mlx_video.utils import to_denoised, get_model_path, load_image, prepare_image_for_encoding from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder +from mlx_video.models.ltx.video_vae.tiling import TilingConfig from mlx_video.models.ltx.upsampler import load_upsampler, upsample_latents from mlx_video.conditioning import VideoConditionByLatentIndex, apply_conditioning from mlx_video.conditioning.latent import LatentState, apply_denoise_mask @@ -363,6 +364,7 @@ def generate_video_with_audio( image: Optional[str] = None, image_strength: float = 1.0, image_frame_idx: int = 0, + tiling: str = "auto", ): """Generate video with synchronized audio from text prompt, optionally conditioned on an image. @@ -384,6 +386,7 @@ def generate_video_with_audio( image: Path to conditioning image for I2V image_strength: Conditioning strength (1.0 = full denoise) image_frame_idx: Frame index to condition (0 = first frame) + tiling: Tiling mode for VAE decoding (auto/none/default/aggressive/conservative/spatial/temporal) """ start_time = time.time() @@ -623,9 +626,36 @@ def generate_video_with_audio( del transformer mx.clear_cache() - # Decode video + # Decode video with tiling print(f"{Colors.BLUE}🎞️ Decoding video...{Colors.RESET}") - video = vae_decoder(video_latents) + + # Select tiling configuration + if tiling == "none": + tiling_config = None + elif tiling == "auto": + tiling_config = TilingConfig.auto(height, width, num_frames) + elif tiling == "default": + tiling_config = TilingConfig.default() + elif tiling == "aggressive": + tiling_config = TilingConfig.aggressive() + elif tiling == "conservative": + tiling_config = TilingConfig.conservative() + elif tiling == "spatial": + tiling_config = TilingConfig.spatial_only() + elif tiling == "temporal": + tiling_config = TilingConfig.temporal_only() + else: + print(f"{Colors.YELLOW} Unknown tiling mode '{tiling}', using auto{Colors.RESET}") + tiling_config = TilingConfig.auto(height, width, num_frames) + + if tiling_config is not None: + spatial_info = f"{tiling_config.spatial_config.tile_size_in_pixels}px" if tiling_config.spatial_config else "none" + temporal_info = f"{tiling_config.temporal_config.tile_size_in_frames}f" if tiling_config.temporal_config else "none" + print(f"{Colors.DIM} Tiling ({tiling}): spatial={spatial_info}, temporal={temporal_info}{Colors.RESET}") + video = vae_decoder.decode_tiled(video_latents, tiling_config=tiling_config, debug=verbose) + else: + print(f"{Colors.DIM} Tiling: disabled{Colors.RESET}") + video = vae_decoder(video_latents) mx.eval(video) # Convert video to uint8 frames @@ -762,6 +792,11 @@ Examples: help="Conditioning strength for I2V (1.0 = full denoise, 0.0 = keep original, default: 1.0)") parser.add_argument("--image-frame-idx", type=int, default=0, help="Frame index to condition for I2V (0 = first frame, default: 0)") + parser.add_argument("--tiling", type=str, default="auto", + choices=["auto", "none", "default", "aggressive", "conservative", "spatial", "temporal"], + help="Tiling mode for VAE decoding (default: auto). " + "auto=based on size, none=disabled, default=512px/64f, " + "aggressive=256px/32f (lowest memory), conservative=768px/96f") args = parser.parse_args() @@ -783,6 +818,7 @@ Examples: image=args.image, image_strength=args.image_strength, image_frame_idx=args.image_frame_idx, + tiling=args.tiling, ) diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index bcd7cf4..29993fb 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -918,7 +918,7 @@ class LTX2TextEncoder(nn.Module): if response.token == 1 or response.token == 107: # EOS tokens break - + mx.clear_cache() # Decode only the new tokens diff --git a/mlx_video/models/ltx/video_vae/__init__.py b/mlx_video/models/ltx/video_vae/__init__.py index d620f17..bac1644 100644 --- a/mlx_video/models/ltx/video_vae/__init__.py +++ b/mlx_video/models/ltx/video_vae/__init__.py @@ -1,3 +1,8 @@ from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder +from mlx_video.models.ltx.video_vae.tiling import ( + TilingConfig, + SpatialTilingConfig, + TemporalTilingConfig, +) diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 1bc0983..390a92b 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -23,6 +23,7 @@ import mlx.nn as nn from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType from mlx_video.models.ltx.video_vae.ops import unpatchify from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample +from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling def get_timestep_embedding( @@ -444,6 +445,75 @@ class LTX2VideoDecoder(nn.Module): return x + def decode_tiled( + self, + sample: mx.array, + tiling_config: Optional[TilingConfig] = None, + causal: bool = False, + timestep: Optional[mx.array] = None, + debug: bool = False, + ) -> mx.array: + """Decode latents using tiling to reduce memory usage. + + This method is useful for decoding large videos that would otherwise + cause out-of-memory errors. It divides the latents into tiles, + decodes each tile separately, and blends them together. + + Args: + sample: Input latents of shape (B, C, F, H, W). + tiling_config: Tiling configuration. If None, uses TilingConfig.default(). + causal: Whether to use causal convolutions. + timestep: Optional timestep for conditioning. + debug: Whether to print debug info. + + Returns: + Decoded video of shape (B, 3, F*8, H*8, W*8). + """ + if tiling_config is None: + tiling_config = TilingConfig.default() + + # Check if tiling is actually needed + _, _, f, h, w = sample.shape + needs_spatial_tiling = False + needs_temporal_tiling = False + + # Spatial scale is 32 (8x VAE upsample + 4x unpatchify) + # Temporal scale is 8 + spatial_scale = 32 + temporal_scale = 8 + + if tiling_config.spatial_config is not None: + s_cfg = tiling_config.spatial_config + tile_size_latent = s_cfg.tile_size_in_pixels // spatial_scale + if h > tile_size_latent or w > tile_size_latent: + needs_spatial_tiling = True + + if tiling_config.temporal_config is not None: + t_cfg = tiling_config.temporal_config + tile_size_latent = t_cfg.tile_size_in_frames // temporal_scale + if f > tile_size_latent: + needs_temporal_tiling = True + + if not needs_spatial_tiling and not needs_temporal_tiling: + # No tiling needed, use regular decode + if debug: + print("[Tiling] Input fits within tile size, using regular decode") + return self(sample, causal=causal, timestep=timestep, debug=debug) + + if debug: + print(f"[Tiling] Using tiled decode (spatial={needs_spatial_tiling}, temporal={needs_temporal_tiling})") + + return decode_with_tiling( + decoder_fn=self, + latents=sample, + tiling_config=tiling_config, + spatial_scale=32, # VAE spatial: 8x upsampling + 4x unpatchify = 32x + temporal_scale=8, # VAE temporal upsampling factor + causal=causal, + timestep=timestep, + debug=debug, + ) + def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder: from pathlib import Path diff --git a/mlx_video/models/ltx/video_vae/tiling.py b/mlx_video/models/ltx/video_vae/tiling.py new file mode 100644 index 0000000..20950fc --- /dev/null +++ b/mlx_video/models/ltx/video_vae/tiling.py @@ -0,0 +1,470 @@ +"""VAE Tiling Configuration for decoding large videos. + +Implements spatial and temporal tiling with trapezoidal blending masks +to decode large videos without running out of memory. + +Default configuration (from PyTorch): +- Spatial: 512px tiles with 64px overlap +- Temporal: 64 frames with 24 frame overlap +""" + +from dataclasses import dataclass, replace +from typing import List, Optional, Tuple + +import mlx.core as mx + + +def compute_trapezoidal_mask_1d( + length: int, + ramp_left: int, + ramp_right: int, + left_starts_from_0: bool = False, +) -> mx.array: + """Generate a 1D trapezoidal blending mask with linear ramps. + + Args: + length: Output length of the mask. + ramp_left: Fade-in length on the left. + ramp_right: Fade-out length on the right. + left_starts_from_0: Whether the ramp starts from 0 or first non-zero value. + Useful for temporal tiles where the first tile is causal. + + Returns: + A 1D array of shape (length,) with values in [0, 1]. + """ + if length <= 0: + raise ValueError("Mask length must be positive.") + + ramp_left = max(0, min(ramp_left, length)) + ramp_right = max(0, min(ramp_right, length)) + + # Start with ones + mask = [1.0] * length + + # Apply left ramp (fade in) + if ramp_left > 0: + interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2 + # Create fade_in values using linspace logic + fade_in_full = [i / (interval_length - 1) for i in range(interval_length)] + fade_in = fade_in_full[:-1] # Remove last element + if not left_starts_from_0: + fade_in = fade_in[1:] # Remove first element too + for i in range(min(ramp_left, len(fade_in))): + mask[i] *= fade_in[i] + + # Apply right ramp (fade out) + if ramp_right > 0: + # Create fade_out: linspace(1, 0, ramp_right + 2)[1:-1] + fade_out = [(ramp_right + 1 - i) / (ramp_right + 1) for i in range(1, ramp_right + 1)] + for i in range(ramp_right): + mask[length - ramp_right + i] *= fade_out[i] + + return mx.clip(mx.array(mask), 0, 1) + + +@dataclass(frozen=True) +class SpatialTilingConfig: + """Configuration for dividing each frame into spatial tiles with optional overlap.""" + + tile_size_in_pixels: int + tile_overlap_in_pixels: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_pixels < 64: + raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") + if self.tile_size_in_pixels % 32 != 0: + raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") + if self.tile_overlap_in_pixels % 32 != 0: + raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") + if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" + ) + + +@dataclass(frozen=True) +class TemporalTilingConfig: + """Configuration for dividing a video into temporal tiles.""" + + tile_size_in_frames: int + tile_overlap_in_frames: int = 0 + + def __post_init__(self) -> None: + if self.tile_size_in_frames < 16: + raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") + if self.tile_size_in_frames % 8 != 0: + raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") + if self.tile_overlap_in_frames % 8 != 0: + raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") + if self.tile_overlap_in_frames >= self.tile_size_in_frames: + raise ValueError( + f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" + ) + + +@dataclass(frozen=True) +class TilingConfig: + """Configuration for splitting video into tiles with optional overlap.""" + + spatial_config: Optional[SpatialTilingConfig] = None + temporal_config: Optional[TemporalTilingConfig] = None + + @classmethod + def default(cls) -> "TilingConfig": + """Default tiling: 512px spatial, 64 frame temporal.""" + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), + ) + + @classmethod + def spatial_only(cls, tile_size: int = 512, overlap: int = 64) -> "TilingConfig": + """Spatial tiling only (for short videos with large resolution).""" + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=overlap), + temporal_config=None, + ) + + @classmethod + def temporal_only(cls, tile_size: int = 64, overlap: int = 24) -> "TilingConfig": + """Temporal tiling only (for long videos with small resolution).""" + return cls( + spatial_config=None, + temporal_config=TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap), + ) + + @classmethod + def aggressive(cls) -> "TilingConfig": + """Aggressive tiling for very large videos (smaller tiles, much lower memory).""" + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=256, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=32, tile_overlap_in_frames=8), + ) + + @classmethod + def conservative(cls) -> "TilingConfig": + """Conservative tiling (larger tiles, less memory savings but faster).""" + return cls( + spatial_config=SpatialTilingConfig(tile_size_in_pixels=768, tile_overlap_in_pixels=64), + temporal_config=TemporalTilingConfig(tile_size_in_frames=96, tile_overlap_in_frames=24), + ) + + @classmethod + def auto( + cls, + height: int, + width: int, + num_frames: int, + spatial_threshold: int = 512, + temporal_threshold: int = 65, + ) -> Optional["TilingConfig"]: + """Automatically determine tiling config based on video dimensions. + + Args: + height: Video height in pixels + width: Video width in pixels + num_frames: Number of video frames + spatial_threshold: Enable spatial tiling if either dimension exceeds this + temporal_threshold: Enable temporal tiling if frames exceed this + + Returns: + TilingConfig if tiling is needed, None otherwise + """ + needs_spatial = height > spatial_threshold or width > spatial_threshold + needs_temporal = num_frames > temporal_threshold + + if not needs_spatial and not needs_temporal: + return None + + # Estimate memory requirement (rough heuristic) + # Output size in bytes (float32): B * 3 * F * H * W * 4 + estimated_output_gb = (3 * num_frames * height * width * 4) / (1024**3) + + # For very large videos, use aggressive tiling + if estimated_output_gb > 2.0 or (height * width > 768 * 1024 and num_frames > 100): + return cls.aggressive() + + spatial_config = None + temporal_config = None + + if needs_spatial: + # Choose tile size based on resolution + max_dim = max(height, width) + if max_dim > 1024: + tile_size = 384 # Smaller tiles for very large resolutions + elif max_dim > 768: + tile_size = 512 + else: + tile_size = 384 + spatial_config = SpatialTilingConfig(tile_size_in_pixels=tile_size, tile_overlap_in_pixels=64) + + if needs_temporal: + # Choose tile size based on frame count + if num_frames > 200: + tile_size, overlap = 32, 8 # Aggressive for very long videos + elif num_frames > 100: + tile_size, overlap = 48, 16 + else: + tile_size, overlap = 64, 24 + temporal_config = TemporalTilingConfig(tile_size_in_frames=tile_size, tile_overlap_in_frames=overlap) + + return cls(spatial_config=spatial_config, temporal_config=temporal_config) + + +@dataclass +class DimensionIntervals: + """Intervals for splitting a single dimension.""" + starts: List[int] + ends: List[int] + left_ramps: List[int] + right_ramps: List[int] + + +def split_in_spatial(size: int, overlap: int, dimension_size: int) -> DimensionIntervals: + """Split a spatial dimension into intervals.""" + if dimension_size <= size: + return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]) + + amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) + starts = [i * (size - overlap) for i in range(amount)] + ends = [start + size for start in starts] + ends[-1] = dimension_size + left_ramps = [0] + [overlap] * (amount - 1) + right_ramps = [overlap] * (amount - 1) + [0] + + return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) + + +def split_in_temporal(size: int, overlap: int, dimension_size: int) -> DimensionIntervals: + """Split a temporal dimension into intervals with causal adjustment.""" + if dimension_size <= size: + return DimensionIntervals(starts=[0], ends=[dimension_size], left_ramps=[0], right_ramps=[0]) + + # Start with spatial split + intervals = split_in_spatial(size, overlap, dimension_size) + + # Adjust for temporal: starts[1:] -= 1, left_ramps[1:] += 1 + starts = intervals.starts.copy() + left_ramps = intervals.left_ramps.copy() + + for i in range(1, len(starts)): + starts[i] = starts[i] - 1 + left_ramps[i] = left_ramps[i] + 1 + + return DimensionIntervals(starts=starts, ends=intervals.ends, left_ramps=left_ramps, right_ramps=intervals.right_ramps) + + +def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]: + """Map temporal latent interval to output coordinates and mask.""" + start = begin * scale + stop = 1 + (end - 1) * scale + left_ramp_scaled = 1 + (left_ramp - 1) * scale if left_ramp > 0 else 0 + right_ramp_scaled = right_ramp * scale + + mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, True) + return slice(start, stop), mask + + +def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, mx.array]: + """Map spatial latent interval to output coordinates and mask.""" + start = begin * scale + stop = end * scale + left_ramp_scaled = left_ramp * scale + right_ramp_scaled = right_ramp * scale + + mask = compute_trapezoidal_mask_1d(stop - start, left_ramp_scaled, right_ramp_scaled, False) + return slice(start, stop), mask + + +def decode_with_tiling( + decoder_fn, + latents: mx.array, + tiling_config: TilingConfig, + spatial_scale: int = 32, + temporal_scale: int = 8, + causal: bool = False, + timestep: Optional[mx.array] = None, + debug: bool = False, +) -> mx.array: + """Decode latents using tiling to reduce memory usage. + + Args: + decoder_fn: Decoder function to call for each tile. + latents: Input latents of shape (B, C, F, H, W). + tiling_config: Tiling configuration. + spatial_scale: Spatial scale factor (32 for LTX VAE: 8x upsample + 4x unpatchify). + temporal_scale: Temporal scale factor (8 for LTX VAE). + causal: Whether to use causal convolutions. + timestep: Optional timestep for conditioning. + debug: Whether to print debug info. + + Returns: + Decoded video. + """ + import gc + + b, c, f_latent, h_latent, w_latent = latents.shape + + # Compute output shape + out_f = 1 + (f_latent - 1) * temporal_scale + out_h = h_latent * spatial_scale + out_w = w_latent * spatial_scale + + # Get tile size and overlap in latent space + if tiling_config.spatial_config is not None: + s_cfg = tiling_config.spatial_config + spatial_tile_size = s_cfg.tile_size_in_pixels // spatial_scale + spatial_overlap = s_cfg.tile_overlap_in_pixels // spatial_scale + else: + spatial_tile_size = max(h_latent, w_latent) + spatial_overlap = 0 + + if tiling_config.temporal_config is not None: + t_cfg = tiling_config.temporal_config + temporal_tile_size = t_cfg.tile_size_in_frames // temporal_scale + temporal_overlap = t_cfg.tile_overlap_in_frames // temporal_scale + else: + temporal_tile_size = f_latent + temporal_overlap = 0 + + # Compute intervals for each dimension + temporal_intervals = split_in_temporal(temporal_tile_size, temporal_overlap, f_latent) + height_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, h_latent) + width_intervals = split_in_spatial(spatial_tile_size, spatial_overlap, w_latent) + + num_t_tiles = len(temporal_intervals.starts) + num_h_tiles = len(height_intervals.starts) + num_w_tiles = len(width_intervals.starts) + total_tiles = num_t_tiles * num_h_tiles * num_w_tiles + + if debug: + print(f"[Tiling] Latent shape: {latents.shape}, Output shape: ({b}, 3, {out_f}, {out_h}, {out_w})") + print(f"[Tiling] Tiles: {num_t_tiles} temporal x {num_h_tiles} height x {num_w_tiles} width = {total_tiles}") + + # Initialize output and weight accumulator + # Use float32 for accumulation to avoid precision issues + output = mx.zeros((b, 3, out_f, out_h, out_w), dtype=mx.float32) + weights = mx.zeros((b, 1, out_f, out_h, out_w), dtype=mx.float32) + mx.eval(output, weights) + + tile_idx = 0 + for t_idx in range(num_t_tiles): + t_start = temporal_intervals.starts[t_idx] + t_end = temporal_intervals.ends[t_idx] + t_left = temporal_intervals.left_ramps[t_idx] + t_right = temporal_intervals.right_ramps[t_idx] + + # Map temporal coordinates + out_t_slice, t_mask = map_temporal_slice(t_start, t_end, t_left, t_right, temporal_scale) + + for h_idx in range(num_h_tiles): + h_start = height_intervals.starts[h_idx] + h_end = height_intervals.ends[h_idx] + h_left = height_intervals.left_ramps[h_idx] + h_right = height_intervals.right_ramps[h_idx] + + # Map height coordinates + out_h_slice, h_mask = map_spatial_slice(h_start, h_end, h_left, h_right, spatial_scale) + + for w_idx in range(num_w_tiles): + w_start = width_intervals.starts[w_idx] + w_end = width_intervals.ends[w_idx] + w_left = width_intervals.left_ramps[w_idx] + w_right = width_intervals.right_ramps[w_idx] + + # Map width coordinates + out_w_slice, w_mask = map_spatial_slice(w_start, w_end, w_left, w_right, spatial_scale) + + if debug: + print(f"[Tiling] Tile {tile_idx + 1}/{total_tiles}: " + f"latent t=[{t_start},{t_end}) h=[{h_start},{h_end}) w=[{w_start},{w_end})") + + # Extract tile latents (small slice) + tile_latents = latents[:, :, t_start:t_end, h_start:h_end, w_start:w_end] + + # Decode tile + tile_output = decoder_fn(tile_latents, causal=causal, timestep=timestep, debug=False) + mx.eval(tile_output) + + # Clear tile_latents reference + del tile_latents + + # Get actual decoded dimensions + _, _, decoded_t, decoded_h, decoded_w = tile_output.shape + expected_t = out_t_slice.stop - out_t_slice.start + expected_h = out_h_slice.stop - out_h_slice.start + expected_w = out_w_slice.stop - out_w_slice.start + + # Handle potential size mismatches (use minimum) + actual_t = min(decoded_t, expected_t) + actual_h = min(decoded_h, expected_h) + actual_w = min(decoded_w, expected_w) + + # Build blend mask + t_mask_slice = t_mask[:actual_t] if len(t_mask) > actual_t else t_mask + h_mask_slice = h_mask[:actual_h] if len(h_mask) > actual_h else h_mask + w_mask_slice = w_mask[:actual_w] if len(w_mask) > actual_w else w_mask + + blend_mask = ( + t_mask_slice.reshape(1, 1, -1, 1, 1) * + h_mask_slice.reshape(1, 1, 1, -1, 1) * + w_mask_slice.reshape(1, 1, 1, 1, -1) + ) + + # Slice tile output to match + tile_output_slice = tile_output[:, :, :actual_t, :actual_h, :actual_w].astype(mx.float32) + + # Clear full tile_output + del tile_output + + # Compute output coordinates + t_out_start = out_t_slice.start + t_out_end = t_out_start + actual_t + h_out_start = out_h_slice.start + h_out_end = h_out_start + actual_h + w_out_start = out_w_slice.start + w_out_end = w_out_start + actual_w + + # Use direct slice assignment (MLX supports this) + # Weighted accumulation + weighted_tile = tile_output_slice * blend_mask + + # Update output using slice assignment + output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( + output[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + weighted_tile + ) + weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] = ( + weights[:, :, t_out_start:t_out_end, h_out_start:h_out_end, w_out_start:w_out_end] + blend_mask + ) + + # Force evaluation to free memory + mx.eval(output, weights) + + # Clean up tile-specific arrays + del tile_output_slice, weighted_tile, blend_mask + del t_mask_slice, h_mask_slice, w_mask_slice + + tile_idx += 1 + + # Periodic garbage collection and cache clearing + if tile_idx % 4 == 0: + gc.collect() + try: + mx.clear_cache() + except Exception: + pass # May not be available on all platforms + + # Normalize by weights + weights = mx.maximum(weights, 1e-8) + output = output / weights + mx.eval(output) + + # Clean up weights + del weights + gc.collect() + + if debug: + print(f"[Tiling] Done. Final shape: {output.shape}") + + # Convert back to original dtype if needed + return output.astype(latents.dtype) From 883c6b0ad8e615a6922fc3455d3f482235c55ec6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 13:03:48 +0100 Subject: [PATCH 2/4] ensure dtype cast --- mlx_video/conditioning/latent.py | 9 ++++-- mlx_video/models/ltx/ltx.py | 11 +++++--- mlx_video/models/ltx/rope.py | 12 ++++++-- mlx_video/models/ltx/text_encoder.py | 34 +++++++++++++---------- mlx_video/models/ltx/video_vae/decoder.py | 3 +- mlx_video/utils.py | 15 ++++++---- 6 files changed, 52 insertions(+), 32 deletions(-) diff --git a/mlx_video/conditioning/latent.py b/mlx_video/conditioning/latent.py index 1825e3d..acf3d99 100644 --- a/mlx_video/conditioning/latent.py +++ b/mlx_video/conditioning/latent.py @@ -95,6 +95,7 @@ def apply_conditioning( Updated LatentState with conditioning applied """ state = state.clone() + dtype = state.latent.dtype b, c, f, h, w = state.latent.shape for cond in conditionings: @@ -132,7 +133,7 @@ def apply_conditioning( latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) # Set mask: 1.0 - strength means less denoising for conditioned frames - mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength)) + mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype)) else: # Keep original latent_list.append(state.latent[:, :, i:i+1]) @@ -161,7 +162,8 @@ def apply_denoise_mask( Returns: Blended latent """ - return denoised * denoise_mask + clean * (1.0 - denoise_mask) + one = mx.array(1.0, dtype=denoised.dtype) + return denoised * denoise_mask + clean * (one - denoise_mask) def add_noise_with_state( @@ -191,6 +193,7 @@ def add_noise_with_state( # But we scale sigma by the mask for conditioned regions effective_scale = noise_scale * state.denoise_mask - state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale) + one = mx.array(1.0, dtype=state.latent.dtype) + state.latent = noise * effective_scale + state.latent * (one - effective_scale) return state diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index 75987bc..a3eef42 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -52,10 +52,11 @@ class TransformerArgsPreprocessor: self, timestep: mx.array, batch_size: int, + hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * self.timestep_scale_multiplier - timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1)) + timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) # Reshape to (batch, tokens, dim) timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) @@ -117,7 +118,7 @@ class TransformerArgsPreprocessor: def prepare(self, modality: Modality) -> TransformerArgs: x = self.patchify_proj(modality.latent) - timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0]) + timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) pe = self._prepare_positional_embeddings( @@ -201,6 +202,7 @@ class MultiModalTransformerArgsPreprocessor: timestep=modality.timesteps, timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, batch_size=transformer_args.x.shape[0], + hidden_dtype=transformer_args.x.dtype, ) return replace( @@ -215,15 +217,16 @@ class MultiModalTransformerArgsPreprocessor: timestep: mx.array, timestep_scale_multiplier: int, batch_size: int, + hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier - scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1)) + scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])) - gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor) + gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype) gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1])) return scale_shift_timestep, gate_timestep diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index 54b721a..a00d019 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -128,6 +128,7 @@ def apply_split_rotary_emb( Returns: Tensor with split rotary embeddings applied """ + input_dtype = input_tensor.dtype needs_reshape = False original_shape = input_tensor.shape @@ -139,6 +140,11 @@ def apply_split_rotary_emb( input_tensor = mx.swapaxes(input_tensor, 1, 2) needs_reshape = True + # Cast to float32 for computation precision + input_tensor = input_tensor.astype(mx.float32) + cos_freqs = cos_freqs.astype(mx.float32) + sin_freqs = sin_freqs.astype(mx.float32) + # Split into two halves: (..., dim) -> (..., 2, dim//2) dim = input_tensor.shape[-1] split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2)) @@ -167,7 +173,7 @@ def apply_split_rotary_emb( output = mx.swapaxes(output, 1, 2) output = mx.reshape(output, (b, t, h * d)) - return output + return output.astype(input_dtype) def generate_freq_grid( @@ -424,8 +430,8 @@ def _precompute_freqs_cis_double_precision( rope_type: LTXRopeType, ) -> Tuple[mx.array, mx.array]: - # Convert to numpy float64 - indices_grid_np = np.array(indices_grid).astype(np.float64) + # Convert to numpy float64 (first to float32 for numpy compatibility) + indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) # Generate frequency indices in float64 n_pos_dims = indices_grid_np.shape[1] diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index 29993fb..d6461d5 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -273,6 +273,13 @@ class ConnectorAttention(nn.Module): Returns: Tensor with SPLIT rotary embeddings applied """ + input_dtype = x.dtype + + # Cast to float32 for precision, then cast back + x = x.astype(mx.float32) + cos_freq = cos_freq.astype(mx.float32) + sin_freq = sin_freq.astype(mx.float32) + # Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2) half_dim = x.shape[-1] // 2 x1 = x[..., :half_dim] @@ -284,7 +291,7 @@ class ConnectorAttention(nn.Module): out1 = x1 * cos_freq - x2 * sin_freq out2 = x2 * cos_freq + x1 * sin_freq - return mx.concatenate([out1, out2], axis=-1) + return mx.concatenate([out1, out2], axis=-1).astype(input_dtype) class GEGLU(nn.Module): @@ -437,14 +444,15 @@ class Embeddings1DConnector(nn.Module): attention_mask: mx.array, ) -> Tuple[mx.array, mx.array]: batch_size, seq_len, dim = hidden_states.shape + dtype = hidden_states.dtype # Binary mask: 1 for valid tokens, 0 for padded # attention_mask is additive: 0 for valid, large negative for padded mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq) - # Tile registers to match sequence length + # Tile registers to match sequence length, cast to hidden_states dtype num_tiles = seq_len // self.num_learnable_registers - registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim) + registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim) # Process each batch item (PyTorch uses advanced indexing) result_list = [] @@ -462,7 +470,7 @@ class Embeddings1DConnector(nn.Module): # Pad with zeros on the right to get back to seq_len pad_length = seq_len - num_valid if pad_length > 0: - padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype) + padding = mx.zeros((pad_length, dim), dtype=dtype) adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim) else: adjusted = valid_tokens @@ -474,9 +482,8 @@ class Embeddings1DConnector(nn.Module): ], axis=0) # (seq,) # Combine: valid tokens at front, registers at back - flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1) + flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1) combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers - result_list.append(combined) hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim) @@ -491,7 +498,6 @@ class Embeddings1DConnector(nn.Module): hidden_states: mx.array, attention_mask: Optional[mx.array] = None, ) -> Tuple[mx.array, mx.array]: - # Replace padded tokens with learnable registers if self.num_learnable_registers > 0 and attention_mask is not None: hidden_states, attention_mask = self._replace_padded_with_registers( @@ -521,6 +527,7 @@ def norm_and_concat_hidden_states( # Stack hidden states: (batch, seq, dim, num_layers) stacked = mx.stack(hidden_states, axis=-1) + dtype = stacked.dtype b, t, d, num_layers = stacked.shape # Compute sequence lengths from attention mask @@ -536,16 +543,16 @@ def norm_and_concat_hidden_states( mask = token_indices >= start_indices # (B, T) mask = mask[:, :, None, None] # (B, T, 1, 1) - eps = 1e-6 + eps = mx.array(1e-6, dtype=dtype) - # Compute masked mean per layer + # Compute masked mean per layer - ensure dtype consistency masked = mx.where(mask, stacked, mx.zeros_like(stacked)) - denom = (sequence_lengths * d).reshape(b, 1, 1, 1) + denom = (sequence_lengths * d).reshape(b, 1, 1, 1).astype(dtype) mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps) # Compute masked min/max per layer - x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype)) - x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype)) + x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype)) + x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype)) x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True) x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True) range_val = x_max - x_min @@ -749,13 +756,10 @@ class LTX2TextEncoder(nn.Module): attention_mask = mx.array(inputs["attention_mask"]) _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True) - concat_hidden = norm_and_concat_hidden_states( all_hidden_states, attention_mask, padding_side="left" ) - features = self.feature_extractor(concat_hidden) - additive_mask = (attention_mask - 1).astype(features.dtype) additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 390a92b..0cb0d7b 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -348,10 +348,11 @@ class LTX2VideoDecoder(nn.Module): def denormalize(self, x: mx.array) -> mx.array: """Denormalize latents using per-channel statistics.""" + dtype = x.dtype # Cast to float32 for precision (statistics may be in bfloat16) mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1) - return x * std + mean + return (x * std + mean).astype(dtype) def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: """Apply pixel normalization.""" diff --git a/mlx_video/utils.py b/mlx_video/utils.py index 4b50536..aff48ed 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -44,10 +44,9 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict): class_predicate=get_class_predicate, ) - -@partial(mx.compile, shapeless=True) +@partial(mx.compile, shapeless=True) def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: - return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps) + return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps) @@ -71,9 +70,12 @@ def to_denoised( Denoised tensor x_0 """ if isinstance(sigma, (int, float)): - return noisy - sigma * velocity + # Convert to array with matching dtype to avoid float32 promotion + sigma_arr = mx.array(sigma, dtype=velocity.dtype) + return noisy - sigma_arr * velocity else: - # sigma is per-sample + # sigma is per-sample - ensure dtype matches + sigma = sigma.astype(velocity.dtype) while sigma.ndim < velocity.ndim: sigma = mx.expand_dims(sigma, axis=-1) return noisy - sigma * velocity @@ -251,6 +253,7 @@ def prepare_image_for_encoding( image: mx.array, target_height: int, target_width: int, + dtype: mx.Dtype = mx.float32, ) -> mx.array: """Prepare image for VAE encoding by resizing and normalizing. @@ -281,4 +284,4 @@ def prepare_image_for_encoding( image = mx.expand_dims(image, axis=0) # (1, 3, H, W) image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W) - return image + return image.astype(dtype) From 78244a2d664e0e9a15cb5af975c8f50da197b1d7 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 17:20:22 +0100 Subject: [PATCH 3/4] Cast dtype to bf16 in video and audio generation processes --- mlx_video/generate.py | 64 ++++++++++++++++----------- mlx_video/generate_av.py | 95 +++++++++++++++++++--------------------- mlx_video/utils.py | 3 +- 3 files changed, 86 insertions(+), 76 deletions(-) diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 55f0965..17f3770 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -1,9 +1,10 @@ import argparse import time from pathlib import Path -from typing import Optional, List, Tuple +from typing import Optional import mlx.core as mx +import mlx.nn as nn import numpy as np from PIL import Image from tqdm import tqdm @@ -110,6 +111,7 @@ def create_position_grid( # Convert temporal to time in seconds by dividing by fps pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / fps + # Always return float32 for RoPE precision - bfloat16 causes quality degradation return mx.array(pixel_coords, dtype=mx.float32) @@ -137,6 +139,7 @@ def denoise( Denoised latent tensor """ # If state is provided, use its latent (which may have conditioning applied) + dtype = latents.dtype if state is not None: latents = state.latent @@ -154,11 +157,11 @@ def denoise( denoise_mask_flat = mx.reshape(state.denoise_mask, (b, 1, f, 1, 1)) denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_tokens)) - # Per-token timesteps: sigma * mask - timesteps = sigma * denoise_mask_flat + # Per-token timesteps: sigma * mask (preserve dtype) + timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat else: - # All tokens get the same timestep - timesteps = mx.full((b, num_tokens), sigma) + # All tokens get the same timestep (use latent dtype) + timesteps = mx.full((b, num_tokens), sigma, dtype=dtype) video_modality = Modality( latent=latents_flat, @@ -181,8 +184,11 @@ def denoise( mx.eval(denoised) + # Euler step (preserve dtype by converting Python floats to arrays) if sigma_next > 0: - latents = denoised + sigma_next * (latents - denoised) / sigma + sigma_next_arr = mx.array(sigma_next, dtype=dtype) + sigma_arr = mx.array(sigma, dtype=dtype) + latents = denoised + sigma_next_arr * (latents - denoised) / sigma_arr else: latents = denoised mx.eval(latents) @@ -283,6 +289,7 @@ def generate_video( print(f"{Colors.DIM}Enhanced: {prompt[:150]}{'...' if len(prompt) > 150 else ''}{Colors.RESET}") text_embeddings, _ = text_encoder(prompt, return_audio_embeddings=False) + model_dtype = text_embeddings.dtype # bfloat16 from text encoder mx.eval(text_embeddings) del text_encoder @@ -292,6 +299,8 @@ def generate_video( print(f"{Colors.BLUE}🤖 Loading transformer...{Colors.RESET}") raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) sanitized = sanitize_transformer_weights(raw_weights) + # Convert transformer weights to bfloat16 for memory efficiency + sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} config = LTXModelConfig( model_type=LTXModelType.VideoOnly, @@ -310,7 +319,7 @@ def generate_video( timestep_scale_multiplier=1000, ) - transformer = LTXModel(config) + transformer = LTXModel(config) transformer.load_weights(list(sanitized.items()), strict=False) mx.eval(transformer.parameters()) @@ -323,15 +332,15 @@ def generate_video( mx.eval(vae_encoder.parameters()) # Load and prepare image for stage 1 (half resolution) - input_image = load_image(image, height=height // 2, width=width // 2) - stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) stage1_image_latent = vae_encoder(stage1_image_tensor) mx.eval(stage1_image_latent) print(f" Stage 1 image latent: {stage1_image_latent.shape}") # Load and prepare image for stage 2 (full resolution) - input_image = load_image(image, height=height, width=width) - stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) stage2_image_latent = vae_encoder(stage2_image_tensor) mx.eval(stage2_image_latent) print(f" Stage 2 image latent: {stage2_image_latent.shape}") @@ -343,6 +352,7 @@ def generate_video( print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) + # Position grids stay float32 for RoPE precision positions = create_position_grid(1, latent_frames, stage1_h, stage1_w) mx.eval(positions) @@ -353,24 +363,26 @@ def generate_video( # Create initial state with zeros latent_shape = (1, 128, latent_frames, stage1_h, stage1_w) state1 = LatentState( - latent=mx.zeros(latent_shape), - clean_latent=mx.zeros(latent_shape), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + latent=mx.zeros(latent_shape, dtype=model_dtype), + clean_latent=mx.zeros(latent_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( latent=stage1_image_latent, frame_idx=image_frame_idx, strength=image_strength, ) + state1 = apply_conditioning(state1, [conditioning]) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # For Stage 1, noise_scale = 1.0 (first sigma) - noise = mx.random.normal(latent_shape) - noise_scale = STAGE_1_SIGMAS[0] # 1.0 + noise = mx.random.normal(latent_shape, dtype=model_dtype) + noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0 scaled_mask = state1.denoise_mask * noise_scale + state1 = LatentState( - latent=noise * scaled_mask + state1.latent * (1.0 - scaled_mask), + latent=noise * scaled_mask + state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state1.clean_latent, denoise_mask=state1.denoise_mask, ) @@ -378,7 +390,7 @@ def generate_video( mx.eval(latents) else: # T2V: just use random noise - latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w)) + latents = mx.random.normal((1, 128, latent_frames, stage1_h, stage1_w), dtype=model_dtype) mx.eval(latents) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_1_SIGMAS, verbose=verbose, state=state1) @@ -401,6 +413,7 @@ def generate_video( # Stage 2: Refine at full resolution print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") + # Position grids stay float32 for RoPE precision positions = create_position_grid(1, latent_frames, stage2_h, stage2_w) mx.eval(positions) @@ -411,7 +424,7 @@ def generate_video( state2 = LatentState( latent=latents, # Start with upscaled latent clean_latent=mx.zeros_like(latents), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( latent=stage2_image_latent, @@ -423,11 +436,11 @@ def generate_video( # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) # For Stage 2, noise_scale = stage_2_sigmas[0] # Conditioned frames (mask=0) keep image latent, unconditioned get partial noise - noise = mx.random.normal(latents.shape) - noise_scale = STAGE_2_SIGMAS[0] + noise = mx.random.normal(latents.shape).astype(model_dtype) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) scaled_mask = state2.denoise_mask * noise_scale state2 = LatentState( - latent=noise * scaled_mask + state2.latent * (1.0 - scaled_mask), + latent=noise * scaled_mask + state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=state2.clean_latent, denoise_mask=state2.denoise_mask, ) @@ -435,9 +448,10 @@ def generate_video( mx.eval(latents) else: # T2V: add noise to all frames for refinement - noise_scale = STAGE_2_SIGMAS[0] - noise = mx.random.normal(latents.shape) - latents = noise * noise_scale + latents * (1 - noise_scale) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + one_minus_scale = mx.array(1.0 - STAGE_2_SIGMAS[0], dtype=model_dtype) + noise = mx.random.normal(latents.shape).astype(model_dtype) + latents = noise * noise_scale + latents * one_minus_scale mx.eval(latents) latents = denoise(latents, positions, text_embeddings, transformer, STAGE_2_SIGMAS, verbose=verbose, state=state2) diff --git a/mlx_video/generate_av.py b/mlx_video/generate_av.py index b7cba4a..e0fb22b 100644 --- a/mlx_video/generate_av.py +++ b/mlx_video/generate_av.py @@ -3,7 +3,7 @@ import argparse import time from pathlib import Path -from typing import Optional, List +from typing import Optional import mlx.core as mx import numpy as np @@ -164,6 +164,7 @@ def denoise_av( Returns: Tuple of (video_latents, audio_latents) """ + dtype = video_latents.dtype # If video state is provided, use its latent if video_state is not None: video_latents = video_state.latent @@ -189,10 +190,10 @@ def denoise_av( denoise_mask_flat = mx.broadcast_to(denoise_mask_flat, (b, 1, f, h, w)) denoise_mask_flat = mx.reshape(denoise_mask_flat, (b, num_video_tokens)) # Per-token timesteps: sigma * mask - video_timesteps = sigma * denoise_mask_flat + video_timesteps = mx.array(sigma, dtype=dtype) * denoise_mask_flat else: # All tokens get the same timestep - video_timesteps = mx.full((b, num_video_tokens), sigma) + video_timesteps = mx.full((b, num_video_tokens), sigma, dtype=dtype) video_modality = Modality( latent=video_flat, @@ -205,7 +206,7 @@ def denoise_av( audio_modality = Modality( latent=audio_flat, - timesteps=mx.full((ab, at), sigma), + timesteps=mx.full((ab, at), sigma, dtype=dtype), positions=audio_positions, context=audio_embeddings, context_mask=None, @@ -230,10 +231,12 @@ def denoise_av( mx.eval(video_denoised, audio_denoised) - # Euler step + # Euler step - use dtype-preserving arrays to avoid float32 promotion if sigma_next > 0: - video_latents = video_denoised + sigma_next * (video_latents - video_denoised) / sigma - audio_latents = audio_denoised + sigma_next * (audio_latents - audio_denoised) / sigma + sigma_next_arr = mx.array(sigma_next, dtype=dtype) + sigma_arr = mx.array(sigma, dtype=dtype) + video_latents = video_denoised + sigma_next_arr * (video_latents - video_denoised) / sigma_arr + audio_latents = audio_denoised + sigma_next_arr * (audio_latents - audio_denoised) / sigma_arr else: video_latents = video_denoised audio_latents = audio_denoised @@ -435,6 +438,7 @@ def generate_video_with_audio( # Get both video and audio embeddings video_embeddings, audio_embeddings = text_encoder(prompt) + model_dtype = video_embeddings.dtype # bfloat16 from text encoder mx.eval(video_embeddings, audio_embeddings) del text_encoder @@ -445,6 +449,9 @@ def generate_video_with_audio( raw_weights = mx.load(str(model_path / 'ltx-2-19b-distilled.safetensors')) sanitized = sanitize_transformer_weights(raw_weights) + # Convert transformer weights to bfloat16 for memory efficiency + sanitized = {k: v.astype(mx.bfloat16) if v.dtype == mx.float32 else v for k, v in sanitized.items()} + config = LTXModelConfig( model_type=LTXModelType.AudioVideo, num_attention_heads=32, @@ -482,18 +489,16 @@ def generate_video_with_audio( mx.eval(vae_encoder.parameters()) # Load and prepare image for stage 1 (half resolution) - input_image = load_image(image, height=height // 2, width=width // 2) - stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2) + input_image = load_image(image, height=height // 2, width=width // 2, dtype=model_dtype) + stage1_image_tensor = prepare_image_for_encoding(input_image, height // 2, width // 2, dtype=model_dtype) stage1_image_latent = vae_encoder(stage1_image_tensor) mx.eval(stage1_image_latent) - print(f" Stage 1 image latent: {stage1_image_latent.shape}") # Load and prepare image for stage 2 (full resolution) - input_image = load_image(image, height=height, width=width) - stage2_image_tensor = prepare_image_for_encoding(input_image, height, width) + input_image = load_image(image, height=height, width=width, dtype=model_dtype) + stage2_image_tensor = prepare_image_for_encoding(input_image, height, width, dtype=model_dtype) stage2_image_latent = vae_encoder(stage2_image_tensor) mx.eval(stage2_image_latent) - print(f" Stage 2 image latent: {stage2_image_latent.shape}") del vae_encoder mx.clear_cache() @@ -502,9 +507,10 @@ def generate_video_with_audio( print(f"{Colors.YELLOW}⚡ Stage 1: Generating at {width//2}x{height//2} (8 steps)...{Colors.RESET}") mx.random.seed(seed) - # Create position grids - video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) - audio_positions = create_audio_position_grid(1, audio_frames) + # Create position grids - MUST stay float32 for RoPE precision + # bfloat16 positions cause quality degradation due to precision loss in sin/cos calculations + video_positions = create_video_position_grid(1, latent_frames, stage1_h, stage1_w) # float32 + audio_positions = create_audio_position_grid(1, audio_frames) # float32 mx.eval(video_positions, audio_positions) # Apply I2V conditioning for stage 1 if provided @@ -513,9 +519,9 @@ def generate_video_with_audio( if is_i2v and stage1_image_latent is not None: # PyTorch flow: create zeros -> apply conditioning -> apply noiser video_state1 = LatentState( - latent=mx.zeros(video_latent_shape), - clean_latent=mx.zeros(video_latent_shape), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + latent=mx.zeros(video_latent_shape, dtype=model_dtype), + clean_latent=mx.zeros(video_latent_shape, dtype=model_dtype), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( latent=stage1_image_latent, @@ -525,11 +531,11 @@ def generate_video_with_audio( video_state1 = apply_conditioning(video_state1, [conditioning]) # Apply noiser: latent = noise * (mask * noise_scale) + latent * (1 - mask * noise_scale) - noise = mx.random.normal(video_latent_shape) - noise_scale = STAGE_1_SIGMAS[0] # 1.0 + noise = mx.random.normal(video_latent_shape).astype(model_dtype) + noise_scale = mx.array(STAGE_1_SIGMAS[0], dtype=model_dtype) # 1.0 scaled_mask = video_state1.denoise_mask * noise_scale video_state1 = LatentState( - latent=noise * scaled_mask + video_state1.latent * (1.0 - scaled_mask), + latent=noise * scaled_mask + video_state1.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=video_state1.clean_latent, denoise_mask=video_state1.denoise_mask, ) @@ -537,11 +543,11 @@ def generate_video_with_audio( mx.eval(video_latents) else: # T2V: just use random noise - video_latents = mx.random.normal(video_latent_shape) + video_latents = mx.random.normal(video_latent_shape).astype(model_dtype) mx.eval(video_latents) # Audio always uses pure noise (no I2V for audio) - audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)) + audio_latents = mx.random.normal((1, AUDIO_LATENT_CHANNELS, audio_frames, AUDIO_MEL_BINS)).astype(model_dtype) mx.eval(audio_latents) # Stage 1 denoising @@ -571,7 +577,8 @@ def generate_video_with_audio( # Stage 2: Refine at full resolution print(f"{Colors.YELLOW}⚡ Stage 2: Refining at {width}x{height} (3 steps)...{Colors.RESET}") - video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) + # Position grids stay float32 for RoPE precision + video_positions = create_video_position_grid(1, latent_frames, stage2_h, stage2_w) # float32 mx.eval(video_positions) # Apply I2V conditioning for stage 2 if provided @@ -581,7 +588,7 @@ def generate_video_with_audio( video_state2 = LatentState( latent=video_latents, # Start with upscaled latent clean_latent=mx.zeros_like(video_latents), - denoise_mask=mx.ones((1, 1, latent_frames, 1, 1)), + denoise_mask=mx.ones((1, 1, latent_frames, 1, 1), dtype=model_dtype), ) conditioning = VideoConditionByLatentIndex( latent=stage2_image_latent, @@ -591,11 +598,11 @@ def generate_video_with_audio( video_state2 = apply_conditioning(video_state2, [conditioning]) # Apply noiser: conditioned frames (mask=0) keep image latent, unconditioned get partial noise - video_noise = mx.random.normal(video_latents.shape) - noise_scale = STAGE_2_SIGMAS[0] + video_noise = mx.random.normal(video_latents.shape).astype(model_dtype) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) scaled_mask = video_state2.denoise_mask * noise_scale video_state2 = LatentState( - latent=video_noise * scaled_mask + video_state2.latent * (1.0 - scaled_mask), + latent=video_noise * scaled_mask + video_state2.latent * (mx.array(1.0, dtype=model_dtype) - scaled_mask), clean_latent=video_state2.clean_latent, denoise_mask=video_state2.denoise_mask, ) @@ -603,16 +610,18 @@ def generate_video_with_audio( mx.eval(video_latents) # Audio still gets noise (no I2V for audio) - audio_noise = mx.random.normal(audio_latents.shape) - audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale mx.eval(audio_latents) else: # T2V: add noise to all frames for refinement - noise_scale = STAGE_2_SIGMAS[0] - video_noise = mx.random.normal(video_latents.shape) - audio_noise = mx.random.normal(audio_latents.shape) - video_latents = video_noise * noise_scale + video_latents * (1 - noise_scale) - audio_latents = audio_noise * noise_scale + audio_latents * (1 - noise_scale) + noise_scale = mx.array(STAGE_2_SIGMAS[0], dtype=model_dtype) + one_minus_scale = mx.array(1.0, dtype=model_dtype) - noise_scale + video_noise = mx.random.normal(video_latents.shape).astype(model_dtype) + audio_noise = mx.random.normal(audio_latents.shape).astype(model_dtype) + video_latents = video_noise * noise_scale + video_latents * one_minus_scale + audio_latents = audio_noise * noise_scale + audio_latents * one_minus_scale mx.eval(video_latents, audio_latents) video_latents, audio_latents = denoise_av( @@ -671,27 +680,13 @@ def generate_video_with_audio( vocoder = load_vocoder(model_path) mx.eval(audio_decoder.parameters(), vocoder.parameters()) - # Debug: check per-channel statistics are loaded - pcs = audio_decoder.per_channel_statistics - print(f"Per-channel stats: mean_of_means range=[{pcs._mean_of_means.min():.4f}, {pcs._mean_of_means.max():.4f}], std_of_means range=[{pcs._std_of_means.min():.4f}, {pcs._std_of_means.max():.4f}]") - - # Debug: check audio latent statistics - print(f"Audio latents shape: {audio_latents.shape}") - print(f"Audio latents stats: min={audio_latents.min():.4f}, max={audio_latents.max():.4f}, mean={audio_latents.mean():.4f}, std={mx.std(audio_latents):.4f}") - mel_spectrogram = audio_decoder(audio_latents) mx.eval(mel_spectrogram) - print(f"Mel spectrogram shape: {mel_spectrogram.shape}") - print(f"Mel spectrogram stats: min={mel_spectrogram.min():.4f}, max={mel_spectrogram.max():.4f}, mean={mel_spectrogram.mean():.4f}") - # Audio decoder output is already in vocoder format (B, C, T, F) audio_waveform = vocoder(mel_spectrogram) mx.eval(audio_waveform) - print(f"Audio waveform shape: {audio_waveform.shape}") - print(f"Audio waveform stats: min={audio_waveform.min():.4f}, max={audio_waveform.max():.4f}, mean={audio_waveform.mean():.4f}") - audio_np = np.array(audio_waveform) if audio_np.ndim == 3: audio_np = audio_np[0] # Remove batch dim diff --git a/mlx_video/utils.py b/mlx_video/utils.py index aff48ed..cebbed7 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -171,6 +171,7 @@ def load_image( image_path: Union[str, Path], height: Optional[int] = None, width: Optional[int] = None, + dtype: mx.Dtype = mx.float32, ) -> mx.array: """Load and preprocess an image for I2V conditioning. @@ -210,7 +211,7 @@ def load_image( # Convert to numpy then MLX image_np = np.array(image).astype(np.float32) / 255.0 - return mx.array(image_np) + return mx.array(image_np, dtype=dtype) def resize_image_aspect_ratio( From 61c56cd98903e154938bcdafc86bdf91af9a3e01 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 19:28:05 +0100 Subject: [PATCH 4/4] Add RoPE tests and warning for bfloat16 precision loss in RoPE calculations --- mlx_video/models/ltx/rope.py | 12 ++ tests/__init__.py | 0 tests/test_rope.py | 280 +++++++++++++++++++++++++++++++++++ 3 files changed, 292 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_rope.py diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index a00d019..4852942 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -430,7 +430,19 @@ def _precompute_freqs_cis_double_precision( rope_type: LTXRopeType, ) -> Tuple[mx.array, mx.array]: + # Warn if positions are bfloat16 - this causes quality degradation + if indices_grid.dtype == mx.bfloat16: + import warnings + warnings.warn( + "Position grid has dtype bfloat16, which causes precision loss in RoPE that causes quality degradation in generated videos/audio. " + "Use float32 for position grids to avoid quality degradation. " + "See tests/test_rope.py::test_bfloat16_positions_cause_precision_loss", + UserWarning, + stacklevel=2 + ) + # Convert to numpy float64 (first to float32 for numpy compatibility) + # Note: If input is bfloat16, precision is already lost at this step indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) # Generate frequency indices in float64 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_rope.py b/tests/test_rope.py new file mode 100644 index 0000000..cef8d6f --- /dev/null +++ b/tests/test_rope.py @@ -0,0 +1,280 @@ +import pytest +import mlx.core as mx +import numpy as np + +from mlx_video.models.ltx.rope import ( + precompute_freqs_cis, +) +from mlx_video.models.ltx.config import LTXRopeType + + +def create_video_position_grid( + batch_size: int, + num_frames: int, + height: int, + width: int, + dtype: mx.Dtype = mx.float32, +) -> mx.array: + """Create a simple video position grid for testing.""" + t_coords = np.arange(0, num_frames) + h_coords = np.arange(0, height) + w_coords = np.arange(0, width) + + t_grid, h_grid, w_grid = np.meshgrid(t_coords, h_coords, w_coords, indexing='ij') + patch_starts = np.stack([t_grid, h_grid, w_grid], axis=0) + patch_ends = patch_starts + 1 + + latent_coords = np.stack([patch_starts, patch_ends], axis=-1) + num_patches = num_frames * height * width + latent_coords = latent_coords.reshape(3, num_patches, 2) + latent_coords = np.tile(latent_coords[np.newaxis, ...], (batch_size, 1, 1, 1)) + + # Scale to pixel space + scale_factors = np.array([8, 32, 32]).reshape(1, 3, 1, 1) + pixel_coords = (latent_coords * scale_factors).astype(np.float32) + pixel_coords[:, 0, :, :] = pixel_coords[:, 0, :, :] / 24.0 # Convert to seconds + + return mx.array(pixel_coords, dtype=dtype) + +class TestRoPEPositionPrecision: + """Test suite for RoPE position precision requirements.""" + + def test_float32_positions_produce_consistent_output(self): + """Float32 position grids should produce stable RoPE frequencies.""" + positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + # Verify output dtype is float32 + assert cos_freq.dtype == mx.float32, f"Expected float32, got {cos_freq.dtype}" + assert sin_freq.dtype == mx.float32, f"Expected float32, got {sin_freq.dtype}" + + # Verify no NaN or Inf values + assert not mx.any(mx.isnan(cos_freq)).item(), "cos_freq contains NaN" + assert not mx.any(mx.isnan(sin_freq)).item(), "sin_freq contains NaN" + assert not mx.any(mx.isinf(cos_freq)).item(), "cos_freq contains Inf" + assert not mx.any(mx.isinf(sin_freq)).item(), "sin_freq contains Inf" + + # Verify cos/sin are in valid range [-1, 1] + assert mx.all(cos_freq >= -1.0).item() and mx.all(cos_freq <= 1.0).item(), \ + "cos_freq values out of [-1, 1] range" + assert mx.all(sin_freq >= -1.0).item() and mx.all(sin_freq <= 1.0).item(), \ + "sin_freq values out of [-1, 1] range" + + def test_bfloat16_positions_cause_precision_loss(self): + """bfloat16 positions should produce different (less precise) results than float32. + + This test documents the known issue: bfloat16 has only 7 bits of mantissa + vs 23 bits for float32, causing quantization errors that get amplified + by sin/cos calculations in RoPE. + """ + # Create identical position grids in different dtypes + positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) + + # Compute RoPE frequencies + cos_f32, sin_f32 = precompute_freqs_cis( + indices_grid=positions_f32, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + cos_bf16, sin_bf16 = precompute_freqs_cis( + indices_grid=positions_bf16, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + # Calculate the difference + cos_diff = mx.abs(cos_f32 - cos_bf16) + sin_diff = mx.abs(sin_f32 - sin_bf16) + + max_cos_diff = mx.max(cos_diff).item() + max_sin_diff = mx.max(sin_diff).item() + + # bfloat16 positions WILL cause measurable differences + # This test documents this known behavior + # The threshold here is intentionally low to catch the issue + precision_threshold = 1e-6 + + has_precision_loss = max_cos_diff > precision_threshold or max_sin_diff > precision_threshold + + # Document the precision loss (this is expected behavior) + if has_precision_loss: + print(f"\nPrecision loss detected (expected):") + print(f" Max cos difference: {max_cos_diff:.6e}") + print(f" Max sin difference: {max_sin_diff:.6e}") + + # This assertion documents the issue - bfloat16 positions cause precision loss + assert has_precision_loss, \ + "Expected precision loss with bfloat16 positions - if this fails, the issue may be fixed" + + def test_double_precision_converts_to_float32_internally(self): + """Verify that double_precision mode converts bfloat16 to float32 first.""" + positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) + + # The double precision path in rope.py line 434: + # indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) + # This means bfloat16 -> float32 -> float64 + # The precision is already lost at the bfloat16 -> float32 step + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions_bf16, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + # Output should still be float32 + assert cos_freq.dtype == mx.float32 + assert sin_freq.dtype == mx.float32 + + def test_position_grid_should_be_float32_recommendation(self): + """Test that validates the recommended practice: positions should be float32. + + This test serves as documentation that position grids MUST be float32 + to avoid quality degradation in generated videos/audio. + """ + # Recommended: create positions in float32 + positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + + assert positions.dtype == mx.float32, \ + "Position grids should be created in float32 for RoPE precision" + + # Verify the position values are reasonable + # Temporal positions should be small (seconds) + temporal_positions = positions[:, 0, :, :] + assert mx.max(temporal_positions).item() < 100, \ + "Temporal positions should be in seconds (small values)" + + # Spatial positions should be larger (pixels) + spatial_h = positions[:, 1, :, :] + spatial_w = positions[:, 2, :, :] + assert mx.max(spatial_h).item() > 0, "Spatial height positions should be positive" + assert mx.max(spatial_w).item() > 0, "Spatial width positions should be positive" + + +class TestRoPEInterleaved: + """Tests for interleaved RoPE mode.""" + + def test_interleaved_rope_with_float32_positions(self): + """Interleaved RoPE should work correctly with float32 positions.""" + positions = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.INTERLEAVED, + double_precision=False, + ) + + assert cos_freq.dtype == mx.float32 + assert sin_freq.dtype == mx.float32 + assert not mx.any(mx.isnan(cos_freq)).item() + assert not mx.any(mx.isnan(sin_freq)).item() + + +class TestRoPEWarnings: + """Tests for RoPE warnings.""" + + def test_bfloat16_positions_trigger_warning(self): + """Verify that bfloat16 positions trigger a UserWarning.""" + positions_bf16 = create_video_position_grid(1, 4, 4, 4, dtype=mx.bfloat16) + + with pytest.warns(UserWarning, match="Position grid has dtype bfloat16"): + precompute_freqs_cis( + indices_grid=positions_bf16, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + def test_float32_positions_no_warning(self): + """Verify that float32 positions do NOT trigger a warning.""" + positions_f32 = create_video_position_grid(1, 4, 4, 4, dtype=mx.float32) + + # This should not raise any warnings + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + precompute_freqs_cis( + indices_grid=positions_f32, + dim=128, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=32, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + +class TestRoPESplit: + """Tests for split RoPE mode (used by LTX-2).""" + + def test_split_rope_output_shape(self): + """Verify split RoPE output has correct shape (B, H, T, dim_per_head//2).""" + batch_size = 1 + num_frames = 4 + height = 4 + width = 4 + num_heads = 32 + dim = 128 + + positions = create_video_position_grid(batch_size, num_frames, height, width) + num_tokens = num_frames * height * width + + cos_freq, sin_freq = precompute_freqs_cis( + indices_grid=positions, + dim=dim, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=True, + num_attention_heads=num_heads, + rope_type=LTXRopeType.SPLIT, + double_precision=True, + ) + + # Shape should be (B, H, T, dim_per_head//2) + # dim=128, num_heads=32, so dim_per_head=4, and split uses half=2 + dim_per_head = dim // num_heads + expected_shape = (batch_size, num_heads, num_tokens, dim_per_head // 2) + assert cos_freq.shape == expected_shape, \ + f"Expected shape {expected_shape}, got {cos_freq.shape}" + assert sin_freq.shape == expected_shape, \ + f"Expected shape {expected_shape}, got {sin_freq.shape}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])