ensure dtype cast

This commit is contained in:
Prince Canuma
2026-01-17 13:03:48 +01:00
parent e4cdbb7eab
commit 883c6b0ad8
6 changed files with 52 additions and 32 deletions

View File

@@ -273,6 +273,13 @@ class ConnectorAttention(nn.Module):
Returns:
Tensor with SPLIT rotary embeddings applied
"""
input_dtype = x.dtype
# Cast to float32 for precision, then cast back
x = x.astype(mx.float32)
cos_freq = cos_freq.astype(mx.float32)
sin_freq = sin_freq.astype(mx.float32)
# Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2)
half_dim = x.shape[-1] // 2
x1 = x[..., :half_dim]
@@ -284,7 +291,7 @@ class ConnectorAttention(nn.Module):
out1 = x1 * cos_freq - x2 * sin_freq
out2 = x2 * cos_freq + x1 * sin_freq
return mx.concatenate([out1, out2], axis=-1)
return mx.concatenate([out1, out2], axis=-1).astype(input_dtype)
class GEGLU(nn.Module):
@@ -437,14 +444,15 @@ class Embeddings1DConnector(nn.Module):
attention_mask: mx.array,
) -> Tuple[mx.array, mx.array]:
batch_size, seq_len, dim = hidden_states.shape
dtype = hidden_states.dtype
# Binary mask: 1 for valid tokens, 0 for padded
# attention_mask is additive: 0 for valid, large negative for padded
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
# Tile registers to match sequence length
# Tile registers to match sequence length, cast to hidden_states dtype
num_tiles = seq_len // self.num_learnable_registers
registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim)
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
# Process each batch item (PyTorch uses advanced indexing)
result_list = []
@@ -462,7 +470,7 @@ class Embeddings1DConnector(nn.Module):
# Pad with zeros on the right to get back to seq_len
pad_length = seq_len - num_valid
if pad_length > 0:
padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype)
padding = mx.zeros((pad_length, dim), dtype=dtype)
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
else:
adjusted = valid_tokens
@@ -474,9 +482,8 @@ class Embeddings1DConnector(nn.Module):
], axis=0) # (seq,)
# Combine: valid tokens at front, registers at back
flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1)
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
result_list.append(combined)
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
@@ -491,7 +498,6 @@ class Embeddings1DConnector(nn.Module):
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
) -> Tuple[mx.array, mx.array]:
# Replace padded tokens with learnable registers
if self.num_learnable_registers > 0 and attention_mask is not None:
hidden_states, attention_mask = self._replace_padded_with_registers(
@@ -521,6 +527,7 @@ def norm_and_concat_hidden_states(
# Stack hidden states: (batch, seq, dim, num_layers)
stacked = mx.stack(hidden_states, axis=-1)
dtype = stacked.dtype
b, t, d, num_layers = stacked.shape
# Compute sequence lengths from attention mask
@@ -536,16 +543,16 @@ def norm_and_concat_hidden_states(
mask = token_indices >= start_indices # (B, T)
mask = mask[:, :, None, None] # (B, T, 1, 1)
eps = 1e-6
eps = mx.array(1e-6, dtype=dtype)
# Compute masked mean per layer
# Compute masked mean per layer - ensure dtype consistency
masked = mx.where(mask, stacked, mx.zeros_like(stacked))
denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
denom = (sequence_lengths * d).reshape(b, 1, 1, 1).astype(dtype)
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
# Compute masked min/max per layer
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
range_val = x_max - x_min
@@ -749,13 +756,10 @@ class LTX2TextEncoder(nn.Module):
attention_mask = mx.array(inputs["attention_mask"])
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
concat_hidden = norm_and_concat_hidden_states(
all_hidden_states, attention_mask, padding_side="left"
)
features = self.feature_extractor(concat_hidden)
additive_mask = (attention_mask - 1).astype(features.dtype)
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9