ensure dtype cast
This commit is contained in:
@@ -95,6 +95,7 @@ def apply_conditioning(
|
|||||||
Updated LatentState with conditioning applied
|
Updated LatentState with conditioning applied
|
||||||
"""
|
"""
|
||||||
state = state.clone()
|
state = state.clone()
|
||||||
|
dtype = state.latent.dtype
|
||||||
b, c, f, h, w = state.latent.shape
|
b, c, f, h, w = state.latent.shape
|
||||||
|
|
||||||
for cond in conditionings:
|
for cond in conditionings:
|
||||||
@@ -132,7 +133,7 @@ def apply_conditioning(
|
|||||||
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
latent_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||||
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
clean_list.append(cond_latent[:, :, cond_idx:cond_idx+1])
|
||||||
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
# Set mask: 1.0 - strength means less denoising for conditioned frames
|
||||||
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength))
|
mask_list.append(mx.full((b, 1, 1, 1, 1), 1.0 - strength, dtype=dtype))
|
||||||
else:
|
else:
|
||||||
# Keep original
|
# Keep original
|
||||||
latent_list.append(state.latent[:, :, i:i+1])
|
latent_list.append(state.latent[:, :, i:i+1])
|
||||||
@@ -161,7 +162,8 @@ def apply_denoise_mask(
|
|||||||
Returns:
|
Returns:
|
||||||
Blended latent
|
Blended latent
|
||||||
"""
|
"""
|
||||||
return denoised * denoise_mask + clean * (1.0 - denoise_mask)
|
one = mx.array(1.0, dtype=denoised.dtype)
|
||||||
|
return denoised * denoise_mask + clean * (one - denoise_mask)
|
||||||
|
|
||||||
|
|
||||||
def add_noise_with_state(
|
def add_noise_with_state(
|
||||||
@@ -191,6 +193,7 @@ def add_noise_with_state(
|
|||||||
# But we scale sigma by the mask for conditioned regions
|
# But we scale sigma by the mask for conditioned regions
|
||||||
|
|
||||||
effective_scale = noise_scale * state.denoise_mask
|
effective_scale = noise_scale * state.denoise_mask
|
||||||
state.latent = noise * effective_scale + state.latent * (1.0 - effective_scale)
|
one = mx.array(1.0, dtype=state.latent.dtype)
|
||||||
|
state.latent = noise * effective_scale + state.latent * (one - effective_scale)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|||||||
@@ -52,10 +52,11 @@ class TransformerArgsPreprocessor:
|
|||||||
self,
|
self,
|
||||||
timestep: mx.array,
|
timestep: mx.array,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
hidden_dtype: mx.Dtype = None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
timestep = timestep * self.timestep_scale_multiplier
|
timestep = timestep * self.timestep_scale_multiplier
|
||||||
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1))
|
timestep_emb, embedded_timestep = self.adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||||
|
|
||||||
# Reshape to (batch, tokens, dim)
|
# Reshape to (batch, tokens, dim)
|
||||||
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
timestep_emb = mx.reshape(timestep_emb, (batch_size, -1, timestep_emb.shape[-1]))
|
||||||
@@ -117,7 +118,7 @@ class TransformerArgsPreprocessor:
|
|||||||
|
|
||||||
def prepare(self, modality: Modality) -> TransformerArgs:
|
def prepare(self, modality: Modality) -> TransformerArgs:
|
||||||
x = self.patchify_proj(modality.latent)
|
x = self.patchify_proj(modality.latent)
|
||||||
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0])
|
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
|
||||||
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
|
||||||
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
|
||||||
pe = self._prepare_positional_embeddings(
|
pe = self._prepare_positional_embeddings(
|
||||||
@@ -201,6 +202,7 @@ class MultiModalTransformerArgsPreprocessor:
|
|||||||
timestep=modality.timesteps,
|
timestep=modality.timesteps,
|
||||||
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
timestep_scale_multiplier=self.simple_preprocessor.timestep_scale_multiplier,
|
||||||
batch_size=transformer_args.x.shape[0],
|
batch_size=transformer_args.x.shape[0],
|
||||||
|
hidden_dtype=transformer_args.x.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
return replace(
|
return replace(
|
||||||
@@ -215,15 +217,16 @@ class MultiModalTransformerArgsPreprocessor:
|
|||||||
timestep: mx.array,
|
timestep: mx.array,
|
||||||
timestep_scale_multiplier: int,
|
timestep_scale_multiplier: int,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
hidden_dtype: mx.Dtype = None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
timestep = timestep * timestep_scale_multiplier
|
timestep = timestep * timestep_scale_multiplier
|
||||||
|
|
||||||
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
av_ca_factor = self.av_ca_timestep_scale_multiplier / timestep_scale_multiplier
|
||||||
|
|
||||||
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1))
|
scale_shift_timestep, _ = self.cross_scale_shift_adaln(timestep.reshape(-1), hidden_dtype=hidden_dtype)
|
||||||
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
scale_shift_timestep = mx.reshape(scale_shift_timestep, (batch_size, -1, scale_shift_timestep.shape[-1]))
|
||||||
|
|
||||||
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor)
|
gate_timestep, _ = self.cross_gate_adaln(timestep.reshape(-1) * av_ca_factor, hidden_dtype=hidden_dtype)
|
||||||
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
gate_timestep = mx.reshape(gate_timestep, (batch_size, -1, gate_timestep.shape[-1]))
|
||||||
|
|
||||||
return scale_shift_timestep, gate_timestep
|
return scale_shift_timestep, gate_timestep
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ def apply_split_rotary_emb(
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor with split rotary embeddings applied
|
Tensor with split rotary embeddings applied
|
||||||
"""
|
"""
|
||||||
|
input_dtype = input_tensor.dtype
|
||||||
needs_reshape = False
|
needs_reshape = False
|
||||||
original_shape = input_tensor.shape
|
original_shape = input_tensor.shape
|
||||||
|
|
||||||
@@ -139,6 +140,11 @@ def apply_split_rotary_emb(
|
|||||||
input_tensor = mx.swapaxes(input_tensor, 1, 2)
|
input_tensor = mx.swapaxes(input_tensor, 1, 2)
|
||||||
needs_reshape = True
|
needs_reshape = True
|
||||||
|
|
||||||
|
# Cast to float32 for computation precision
|
||||||
|
input_tensor = input_tensor.astype(mx.float32)
|
||||||
|
cos_freqs = cos_freqs.astype(mx.float32)
|
||||||
|
sin_freqs = sin_freqs.astype(mx.float32)
|
||||||
|
|
||||||
# Split into two halves: (..., dim) -> (..., 2, dim//2)
|
# Split into two halves: (..., dim) -> (..., 2, dim//2)
|
||||||
dim = input_tensor.shape[-1]
|
dim = input_tensor.shape[-1]
|
||||||
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
|
split_input = mx.reshape(input_tensor, input_tensor.shape[:-1] + (2, dim // 2))
|
||||||
@@ -167,7 +173,7 @@ def apply_split_rotary_emb(
|
|||||||
output = mx.swapaxes(output, 1, 2)
|
output = mx.swapaxes(output, 1, 2)
|
||||||
output = mx.reshape(output, (b, t, h * d))
|
output = mx.reshape(output, (b, t, h * d))
|
||||||
|
|
||||||
return output
|
return output.astype(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
def generate_freq_grid(
|
def generate_freq_grid(
|
||||||
@@ -424,8 +430,8 @@ def _precompute_freqs_cis_double_precision(
|
|||||||
rope_type: LTXRopeType,
|
rope_type: LTXRopeType,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
# Convert to numpy float64
|
# Convert to numpy float64 (first to float32 for numpy compatibility)
|
||||||
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
indices_grid_np = np.array(indices_grid.astype(mx.float32)).astype(np.float64)
|
||||||
|
|
||||||
# Generate frequency indices in float64
|
# Generate frequency indices in float64
|
||||||
n_pos_dims = indices_grid_np.shape[1]
|
n_pos_dims = indices_grid_np.shape[1]
|
||||||
|
|||||||
@@ -273,6 +273,13 @@ class ConnectorAttention(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Tensor with SPLIT rotary embeddings applied
|
Tensor with SPLIT rotary embeddings applied
|
||||||
"""
|
"""
|
||||||
|
input_dtype = x.dtype
|
||||||
|
|
||||||
|
# Cast to float32 for precision, then cast back
|
||||||
|
x = x.astype(mx.float32)
|
||||||
|
cos_freq = cos_freq.astype(mx.float32)
|
||||||
|
sin_freq = sin_freq.astype(mx.float32)
|
||||||
|
|
||||||
# Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2)
|
# Split x into two halves: (B, H, T, D) -> two tensors of (B, H, T, D//2)
|
||||||
half_dim = x.shape[-1] // 2
|
half_dim = x.shape[-1] // 2
|
||||||
x1 = x[..., :half_dim]
|
x1 = x[..., :half_dim]
|
||||||
@@ -284,7 +291,7 @@ class ConnectorAttention(nn.Module):
|
|||||||
out1 = x1 * cos_freq - x2 * sin_freq
|
out1 = x1 * cos_freq - x2 * sin_freq
|
||||||
out2 = x2 * cos_freq + x1 * sin_freq
|
out2 = x2 * cos_freq + x1 * sin_freq
|
||||||
|
|
||||||
return mx.concatenate([out1, out2], axis=-1)
|
return mx.concatenate([out1, out2], axis=-1).astype(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
class GEGLU(nn.Module):
|
class GEGLU(nn.Module):
|
||||||
@@ -437,14 +444,15 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
attention_mask: mx.array,
|
attention_mask: mx.array,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
batch_size, seq_len, dim = hidden_states.shape
|
batch_size, seq_len, dim = hidden_states.shape
|
||||||
|
dtype = hidden_states.dtype
|
||||||
|
|
||||||
# Binary mask: 1 for valid tokens, 0 for padded
|
# Binary mask: 1 for valid tokens, 0 for padded
|
||||||
# attention_mask is additive: 0 for valid, large negative for padded
|
# attention_mask is additive: 0 for valid, large negative for padded
|
||||||
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
mask_binary = (attention_mask.squeeze(1).squeeze(1) >= -9000.0).astype(mx.int32) # (batch, seq)
|
||||||
|
|
||||||
# Tile registers to match sequence length
|
# Tile registers to match sequence length, cast to hidden_states dtype
|
||||||
num_tiles = seq_len // self.num_learnable_registers
|
num_tiles = seq_len // self.num_learnable_registers
|
||||||
registers = mx.tile(self.learnable_registers, (num_tiles, 1)) # (seq_len, dim)
|
registers = mx.tile(self.learnable_registers, (num_tiles, 1)).astype(dtype) # (seq_len, dim)
|
||||||
|
|
||||||
# Process each batch item (PyTorch uses advanced indexing)
|
# Process each batch item (PyTorch uses advanced indexing)
|
||||||
result_list = []
|
result_list = []
|
||||||
@@ -462,7 +470,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
# Pad with zeros on the right to get back to seq_len
|
# Pad with zeros on the right to get back to seq_len
|
||||||
pad_length = seq_len - num_valid
|
pad_length = seq_len - num_valid
|
||||||
if pad_length > 0:
|
if pad_length > 0:
|
||||||
padding = mx.zeros((pad_length, dim), dtype=hs_b.dtype)
|
padding = mx.zeros((pad_length, dim), dtype=dtype)
|
||||||
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
adjusted = mx.concatenate([valid_tokens, padding], axis=0) # (seq_len, dim)
|
||||||
else:
|
else:
|
||||||
adjusted = valid_tokens
|
adjusted = valid_tokens
|
||||||
@@ -474,9 +482,8 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
], axis=0) # (seq,)
|
], axis=0) # (seq,)
|
||||||
|
|
||||||
# Combine: valid tokens at front, registers at back
|
# Combine: valid tokens at front, registers at back
|
||||||
flipped_mask_expanded = flipped_mask[:, None].astype(hs_b.dtype) # (seq, 1)
|
flipped_mask_expanded = flipped_mask[:, None].astype(dtype) # (seq, 1)
|
||||||
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
combined = flipped_mask_expanded * adjusted + (1 - flipped_mask_expanded) * registers
|
||||||
|
|
||||||
result_list.append(combined)
|
result_list.append(combined)
|
||||||
|
|
||||||
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
hidden_states = mx.stack(result_list, axis=0) # (batch, seq, dim)
|
||||||
@@ -491,7 +498,6 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
hidden_states: mx.array,
|
hidden_states: mx.array,
|
||||||
attention_mask: Optional[mx.array] = None,
|
attention_mask: Optional[mx.array] = None,
|
||||||
) -> Tuple[mx.array, mx.array]:
|
) -> Tuple[mx.array, mx.array]:
|
||||||
|
|
||||||
# Replace padded tokens with learnable registers
|
# Replace padded tokens with learnable registers
|
||||||
if self.num_learnable_registers > 0 and attention_mask is not None:
|
if self.num_learnable_registers > 0 and attention_mask is not None:
|
||||||
hidden_states, attention_mask = self._replace_padded_with_registers(
|
hidden_states, attention_mask = self._replace_padded_with_registers(
|
||||||
@@ -521,6 +527,7 @@ def norm_and_concat_hidden_states(
|
|||||||
|
|
||||||
# Stack hidden states: (batch, seq, dim, num_layers)
|
# Stack hidden states: (batch, seq, dim, num_layers)
|
||||||
stacked = mx.stack(hidden_states, axis=-1)
|
stacked = mx.stack(hidden_states, axis=-1)
|
||||||
|
dtype = stacked.dtype
|
||||||
b, t, d, num_layers = stacked.shape
|
b, t, d, num_layers = stacked.shape
|
||||||
|
|
||||||
# Compute sequence lengths from attention mask
|
# Compute sequence lengths from attention mask
|
||||||
@@ -536,16 +543,16 @@ def norm_and_concat_hidden_states(
|
|||||||
mask = token_indices >= start_indices # (B, T)
|
mask = token_indices >= start_indices # (B, T)
|
||||||
|
|
||||||
mask = mask[:, :, None, None] # (B, T, 1, 1)
|
mask = mask[:, :, None, None] # (B, T, 1, 1)
|
||||||
eps = 1e-6
|
eps = mx.array(1e-6, dtype=dtype)
|
||||||
|
|
||||||
# Compute masked mean per layer
|
# Compute masked mean per layer - ensure dtype consistency
|
||||||
masked = mx.where(mask, stacked, mx.zeros_like(stacked))
|
masked = mx.where(mask, stacked, mx.zeros_like(stacked))
|
||||||
denom = (sequence_lengths * d).reshape(b, 1, 1, 1)
|
denom = (sequence_lengths * d).reshape(b, 1, 1, 1).astype(dtype)
|
||||||
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
mean = mx.sum(masked, axis=(1, 2), keepdims=True) / (denom + eps)
|
||||||
|
|
||||||
# Compute masked min/max per layer
|
# Compute masked min/max per layer
|
||||||
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=stacked.dtype))
|
x_for_min = mx.where(mask, stacked, mx.full(stacked.shape, float('inf'), dtype=dtype))
|
||||||
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=stacked.dtype))
|
x_for_max = mx.where(mask, stacked, mx.full(stacked.shape, float('-inf'), dtype=dtype))
|
||||||
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
x_min = mx.min(x_for_min, axis=(1, 2), keepdims=True)
|
||||||
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
x_max = mx.max(x_for_max, axis=(1, 2), keepdims=True)
|
||||||
range_val = x_max - x_min
|
range_val = x_max - x_min
|
||||||
@@ -749,13 +756,10 @@ class LTX2TextEncoder(nn.Module):
|
|||||||
attention_mask = mx.array(inputs["attention_mask"])
|
attention_mask = mx.array(inputs["attention_mask"])
|
||||||
|
|
||||||
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
_, all_hidden_states = self.language_model(inputs=input_ids, input_embeddings=None, attention_mask=attention_mask, output_hidden_states=True)
|
||||||
|
|
||||||
concat_hidden = norm_and_concat_hidden_states(
|
concat_hidden = norm_and_concat_hidden_states(
|
||||||
all_hidden_states, attention_mask, padding_side="left"
|
all_hidden_states, attention_mask, padding_side="left"
|
||||||
)
|
)
|
||||||
|
|
||||||
features = self.feature_extractor(concat_hidden)
|
features = self.feature_extractor(concat_hidden)
|
||||||
|
|
||||||
additive_mask = (attention_mask - 1).astype(features.dtype)
|
additive_mask = (attention_mask - 1).astype(features.dtype)
|
||||||
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
additive_mask = additive_mask.reshape(attention_mask.shape[0], 1, 1, -1) * 1e9
|
||||||
|
|
||||||
|
|||||||
@@ -348,10 +348,11 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
|
|
||||||
def denormalize(self, x: mx.array) -> mx.array:
|
def denormalize(self, x: mx.array) -> mx.array:
|
||||||
"""Denormalize latents using per-channel statistics."""
|
"""Denormalize latents using per-channel statistics."""
|
||||||
|
dtype = x.dtype
|
||||||
# Cast to float32 for precision (statistics may be in bfloat16)
|
# Cast to float32 for precision (statistics may be in bfloat16)
|
||||||
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||||
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||||
return x * std + mean
|
return (x * std + mean).astype(dtype)
|
||||||
|
|
||||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||||
"""Apply pixel normalization."""
|
"""Apply pixel normalization."""
|
||||||
|
|||||||
@@ -44,10 +44,9 @@ def apply_quantization(model: nn.Module, weights: mx.array, quantization: dict):
|
|||||||
class_predicate=get_class_predicate,
|
class_predicate=get_class_predicate,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, shapeless=True)
|
@partial(mx.compile, shapeless=True)
|
||||||
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
|
||||||
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],)), eps)
|
return mx.fast.rms_norm(x, mx.ones((x.shape[-1],), dtype=x.dtype), eps)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -71,9 +70,12 @@ def to_denoised(
|
|||||||
Denoised tensor x_0
|
Denoised tensor x_0
|
||||||
"""
|
"""
|
||||||
if isinstance(sigma, (int, float)):
|
if isinstance(sigma, (int, float)):
|
||||||
return noisy - sigma * velocity
|
# Convert to array with matching dtype to avoid float32 promotion
|
||||||
|
sigma_arr = mx.array(sigma, dtype=velocity.dtype)
|
||||||
|
return noisy - sigma_arr * velocity
|
||||||
else:
|
else:
|
||||||
# sigma is per-sample
|
# sigma is per-sample - ensure dtype matches
|
||||||
|
sigma = sigma.astype(velocity.dtype)
|
||||||
while sigma.ndim < velocity.ndim:
|
while sigma.ndim < velocity.ndim:
|
||||||
sigma = mx.expand_dims(sigma, axis=-1)
|
sigma = mx.expand_dims(sigma, axis=-1)
|
||||||
return noisy - sigma * velocity
|
return noisy - sigma * velocity
|
||||||
@@ -251,6 +253,7 @@ def prepare_image_for_encoding(
|
|||||||
image: mx.array,
|
image: mx.array,
|
||||||
target_height: int,
|
target_height: int,
|
||||||
target_width: int,
|
target_width: int,
|
||||||
|
dtype: mx.Dtype = mx.float32,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
"""Prepare image for VAE encoding by resizing and normalizing.
|
"""Prepare image for VAE encoding by resizing and normalizing.
|
||||||
|
|
||||||
@@ -281,4 +284,4 @@ def prepare_image_for_encoding(
|
|||||||
image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
|
image = mx.expand_dims(image, axis=0) # (1, 3, H, W)
|
||||||
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
|
image = mx.expand_dims(image, axis=2) # (1, 3, 1, H, W)
|
||||||
|
|
||||||
return image
|
return image.astype(dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user