From 883c6b0ad8e615a6922fc3455d3f482235c55ec6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 17 Jan 2026 13:03:48 +0100 Subject: [PATCH] ensure dtype cast --- mlx_video/conditioning/latent.py | 9 ++++-- mlx_video/models/ltx/ltx.py | 11 +++++--- mlx_video/models/ltx/rope.py | 12 ++++++-- mlx_video/models/ltx/text_encoder.py | 34 +++++++++++++---------- mlx_video/models/ltx/video_vae/decoder.py | 3 +- mlx_video/utils.py | 15 ++++++---- 6 files changed, 52 insertions(+), 32 deletions(-) diff --git a/mlx_video/conditioning/latent.py b/mlx_video/conditioning/latent.py index 1825e3d..acf3d99 100644 --- a/mlx_video/conditioning/latent.py +++ b/mlx_video/conditioning/latent.py @@ -95,6 +95,7 @@ def apply_conditioning( Updated LatentState with conditioning applied """ state = state.clone() + dtype = state.latent.dtype b, c, f, h, w = state.latent.shape for cond in conditionings: @@ -132,7 +133,7 @@ def apply_conditioning( latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1]) # Set mask: 1.0 - strength means less denoising for conditioned frames - mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength)) + mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype)) else: # Keep original latent_list.append(state.latent[:, :, i:i+1]) @@ -161,7 +162,8 @@ def apply_denoise_mask( Returns: Blended latent """ - return denoised * denoise_mask + clean * (1.0 - denoise_mask) + one = mx.array(1.0, dtype=denoised.dtype) + return denoised * denoise_mask + clean * (one - denoise_mask) def add_noise_with_state( @@ -191,6 +193,7 @@ def add_noise_with_state( # But we scale sigma by the mask for conditioned regions effective_scale = noise_scale * state.denoise_mask - state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale) + one = mx.array(1.0, dtype=state.latent.dtype) + state.latent = noise * effective_scale + state.latent * (one - effective_scale) return state diff --git a/mlx_video/models/ltx/ltx.py b/mlx_video/models/ltx/ltx.py index 75987bc..a3eef42 100644 --- a/mlx_video/models/ltx/ltx.py +++ b/mlx_video/models/ltx/ltx.py @@ -52,10 +52,11 @@ class TransformerArgsPreprocessor: self, timestep: mx.array, batch_size: int, + hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * self.timestep_scale_multiplier - timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1)) + timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) # Reshape to (batch, tokens, dim) timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1])) @@ -117,7 +118,7 @@ class TransformerArgsPreprocessor: def prepare(self, modality: Modality) -> TransformerArgs: x = self.patchify_proj(modality.latent) - timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0]) + timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype) context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask) attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype) pe = self._prepare_positional_embeddings( @@ -201,6 +202,7 @@ class MultiModalTransformerArgsPreprocessor: timestep=modality.timesteps, timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier, batch_size=transformer_args.x.shape[0], + hidden_dtype=transformer_args.x.dtype, ) return replace( @@ -215,15 +217,16 @@ class MultiModalTransformerArgsPreprocessor: timestep: mx.array, timestep_scale_multiplier: int, batch_size: int, + hidden_dtype: mx.Dtype = None, ) -> Tuple[mx.array, mx.array]: timestep = timestep * timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier - scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1)) + scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype) scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1])) - gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor) + gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype) gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1])) return scale_shift_timestep, gate_timestep diff --git a/mlx_video/models/ltx/rope.py b/mlx_video/models/ltx/rope.py index 54b721a..a00d019 100644 --- a/mlx_video/models/ltx/rope.py +++ b/mlx_video/models/ltx/rope.py @@ -128,6 +128,7 @@ def apply_split_rotary_emb( Returns: Tensor with split rotary embeddings applied """ + input_dtype = input_tensor.dtype needs_reshape = False original_shape = input_tensor.shape @@ -139,6 +140,11 @@ def apply_split_rotary_emb( input_tensor = mx.swapaxes(input_tensor, 1, 2) needs_reshape = True + # Cast to float32 for computation precision + input_tensor = input_tensor.astype(mx.float32) + cos_freqs = cos_freqs.astype(mx.float32) + sin_freqs = sin_freqs.astype(mx.float32) + # Split into two halves: (..., dim) -> (..., 2, dim//2) dim = input_tensor.shape[-1] split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2)) @@ -167,7 +173,7 @@ def apply_split_rotary_emb( output = mx.swapaxes(output, 1, 2) output = mx.reshape(output, (b, t, h * d)) - return output + return output.astype(input_dtype) def generate_freq_grid( @@ -424,8 +430,8 @@ def _precompute_freqs_cis_double_precision( rope_type: LTXRopeType, ) -> Tuple[mx.array, mx.array]: - # Convert to numpy float64 - indices_grid_np = np.array(indices_grid).astype(np.float64) + # Convert to numpy float64 (first to float32 for numpy compatibility) + indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64) # Generate frequency indices in float64 n_pos_dims = indices_grid_np.shape[1] diff --git a/mlx_video/models/ltx/text_encoder.py b/mlx_video/models/ltx/text_encoder.py index 29993fb..d6461d5 100644 --- a/mlx_video/models/ltx/text_encoder.py +++ b/mlx_video/models/ltx/text_encoder.py @@ -273,6 +273,13 @@ class ConnectorAttention(nn.Module): Returns: Tensor with SPLIT rotary embeddings applied """ + input_dtype = x.dtype + + # Cast to float32 for precision, then cast back + x = x.astype(mx.float32) + cos_freq = cos_freq.astype(mx.float32) + sin_freq = sin_freq.astype(mx.float32) + # Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2) half_dim = x.shape[-1] // 2 x1 = x[..., :half_dim] @@ -284,7 +291,7 @@ class ConnectorAttention(nn.Module): out1 = x1 * cos_freq - x2 * sin_freq out2 = x2 * cos_freq + x1 * sin_freq - return mx.concatenate([out1, out2], axis=-1) + return mx.concatenate([out1, out2], axis=-1).astype(input_dtype) class GEGLU(nn.Module): @@ -437,14 +444,15 @@ class Embeddings1DConnector(nn.Module): attention_mask: mx.array, ) -> Tuple[mx.array, mx.array]: batch_size, seq_len, dim = hidden_states.shape + dtype = hidden_states.dtype # Binary mask: 1 for valid tokens, 0 for padded # attention_mask is additive: 0 for valid, large negative for padded mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq) - # Tile registers to match sequence length + # Tile registers to match sequence length, cast to hidden_states dtype num_tiles = seq_len // self.num_learnable_registers - registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim) + registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim) # Process each batch item (PyTorch uses advanced indexing) result_list = [] @@ -462,7 +470,7 @@ class Embeddings1DConnector(nn.Module): # Pad with zeros on the right to get back to seq_len pad_length = seq_len - num_valid if pad_length > 0: - padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype) + padding = mx.zeros((pad_length, dim), dtype=dtype) adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim) else: adjusted = valid_tokens @@ -474,9 +482,8 @@ class Embeddings1DConnector(nn.Module): ], axis=0) # (seq,) # Combine: valid tokens at front, registers at back - flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1) + flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1) combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers - result_list.append(combined) hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim) @@ -491,7 +498,6 @@ class Embeddings1DConnector(nn.Module): hidden_states: mx.array, attention_mask: Optional[mx.array] = None, ) -> Tuple[mx.array, mx.array]: - # Replace padded tokens with learnable registers if self.num_learnable_registers > 0 and attention_mask is not None: hidden_states, attention_mask = self._replace_padded_with_registers( @@ -521,6 +527,7 @@ def norm_and_concat_hidden_states( # Stack hidden states: (batch, seq, dim, num_layers) stacked = mx.stack(hidden_states, axis=-1) + dtype = stacked.dtype b, t, d, num_layers = stacked.shape # Compute sequence lengths from attention mask @@ -536,16 +543,16 @@ def norm_and_concat_hidden_states( mask = token_indices >= start_indices # (B, T) mask = mask[:, :, None, None] # (B, T, 1, 1) - eps = 1e-6 + eps = mx.array(1e-6, dtype=dtype) - # Compute masked mean per layer + # Compute masked mean per layer - ensure dtype consistency masked = mx.where(mask, stacked, mx.zeros_like(stacked)) - denom = (sequence_lengths * d).reshape(b, 1, 1, 1) + denom = (sequence_lengths * d).reshape(b, 1, 1, 1).astype(dtype) mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps) # Compute masked min/max per layer - x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype)) - x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype)) + x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype)) + x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype)) x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True) x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True) range_val = x_max - x_min @@ -749,13 +756,10 @@ class LTX2TextEncoder(nn.Module): attention_mask = mx.array(inputs["attention_mask"]) _, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True) - concat_hidden = norm_and_concat_hidden_states( all_hidden_states, attention_mask, padding_side="left" ) - features = self.feature_extractor(concat_hidden) - additive_mask = (attention_mask - 1).astype(features.dtype) additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9 diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index 390a92b..0cb0d7b 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -348,10 +348,11 @@ class LTX2VideoDecoder(nn.Module): def denormalize(self, x: mx.array) -> mx.array: """Denormalize latents using per-channel statistics.""" + dtype = x.dtype # Cast to float32 for precision (statistics may be in bfloat16) mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1) std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1) - return x * std + mean + return (x * std + mean).astype(dtype) def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array: """Apply pixel normalization.""" diff --git a/mlx_video/utils.py b/mlx_video/utils.py index 4b50536..aff48ed 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -44,10 +44,9 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict): class_predicate=get_class_predicate, ) - -@partial(mx.compile, shapeless=True) +@partial(mx.compile, shapeless=True) def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: - return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps) + return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps) @@ -71,9 +70,12 @@ def to_denoised( Denoised tensor x_0 """ if isinstance(sigma, (int, float)): - return noisy - sigma * velocity + # Convert to array with matching dtype to avoid float32 promotion + sigma_arr = mx.array(sigma, dtype=velocity.dtype) + return noisy - sigma_arr * velocity else: - # sigma is per-sample + # sigma is per-sample - ensure dtype matches + sigma = sigma.astype(velocity.dtype) while sigma.ndim < velocity.ndim: sigma = mx.expand_dims(sigma, axis=-1) return noisy - sigma * velocity @@ -251,6 +253,7 @@ def prepare_image_for_encoding( image: mx.array, target_height: int, target_width: int, + dtype: mx.Dtype = mx.float32, ) -> mx.array: """Prepare image for VAE encoding by resizing and normalizing. @@ -281,4 +284,4 @@ def prepare_image_for_encoding( image = mx.expand_dims(image, axis=0) # (1, 3, H, W) image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W) - return image + return image.astype(dtype)