- Refactor video generation script
- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions. - Updated VAE weight sanitization for compatibility and improved activation function handling in text projection. - Added support for saving individual frames and refined output video generation process.
This commit is contained in:
@@ -381,8 +381,7 @@ def precompute_freqs_cis(
|
||||
if max_pos is None:
|
||||
max_pos = [20, 2048, 2048]
|
||||
|
||||
# For double precision, compute in numpy (float64) then convert back to MLX
|
||||
# MLX GPU doesn't support float64, so we use numpy for high precision computation
|
||||
|
||||
if double_precision:
|
||||
return _precompute_freqs_cis_double_precision(
|
||||
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
|
||||
@@ -418,10 +417,7 @@ def _precompute_freqs_cis_double_precision(
|
||||
num_attention_heads: int,
|
||||
rope_type: LTXRopeType,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
"""Compute RoPE frequencies in double precision using numpy.
|
||||
|
||||
MLX GPU doesn't support float64, so we use numpy for computation then convert back.
|
||||
"""
|
||||
# Convert to numpy float64
|
||||
indices_grid_np = np.array(indices_grid).astype(np.float64)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user