Optimize positional embedding handling in TransformerArgsPreprocessor and improve RoPE frequency computation in _precompute_freqs_cis_double_precision for enhanced performance and precision.
This commit is contained in:
@@ -12,12 +12,14 @@ from mlx_video.utils import rms_norm
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Modality:
|
||||
latent: mx.array
|
||||
timesteps: mx.array
|
||||
positions: mx.array
|
||||
context: mx.array
|
||||
latent: mx.array
|
||||
timesteps: mx.array
|
||||
positions: mx.array
|
||||
context: mx.array
|
||||
enabled: bool = True
|
||||
context_mask: Optional[mx.array] = None
|
||||
# Optional precomputed positional embeddings (RoPE) to avoid recomputation
|
||||
positional_embeddings: Optional[Tuple[mx.array, mx.array]] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
Reference in New Issue
Block a user