Optimize positional embedding handling in TransformerArgsPreprocessor and improve RoPE frequency computation in _precompute_freqs_cis_double_precision for enhanced performance and precision.

This commit is contained in:
Prince Canuma
2026-01-18 11:13:32 +01:00
parent 62fc4805a0
commit e483eab039
3 changed files with 72 additions and 61 deletions

View File

@@ -121,13 +121,18 @@ class TransformerArgsPreprocessor:
timestep, embedded_timestep = self._prepare_timestep(modality.timesteps, x.shape[0], hidden_dtype=x.dtype)
context, attention_mask = self._prepare_context(modality.context, x, modality.context_mask)
attention_mask = self._prepare_attention_mask(attention_mask, modality.latent.dtype)
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
# Use precomputed positional embeddings if provided (avoids expensive RoPE recomputation)
if modality.positional_embeddings is not None:
pe = modality.positional_embeddings
else:
pe = self._prepare_positional_embeddings(
positions=modality.positions,
inner_dim=self.inner_dim,
max_pos=self.max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return TransformerArgs(
x=x,