Optimize positional embedding handling in TransformerArgsPreprocessor and improve RoPE frequency computation in _precompute_freqs_cis_double_precision for enhanced performance and precision.

This commit is contained in:
Prince Canuma
2026-01-18 11:13:32 +01:00
parent 62fc4805a0
commit e483eab039
3 changed files with 72 additions and 61 deletions

View File

@@ -12,12 +12,14 @@ from mlx_video.utils import rms_norm
@dataclass(frozen=True)
class Modality:
latent: mx.array
timesteps: mx.array
positions: mx.array
context: mx.array
latent: mx.array
timesteps: mx.array
positions: mx.array
context: mx.array
enabled: bool = True
context_mask: Optional[mx.array] = None
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
@dataclass(frozen=True)