- Refactor video generation script

- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
Prince Canuma
2026-01-12 14:04:53 +01:00
parent d1ca36a315
commit 7114b023bd
6 changed files with 270 additions and 304 deletions

View File

@@ -381,8 +381,7 @@ def precompute_freqs_cis(
if max_pos is None:
max_pos = [20, 2048, 2048]
# For double precision, compute in numpy (float64), then convert the result back to MLX.
# MLX on GPU doesn't support float64, so numpy handles the high-precision computation.
if double_precision:
return _precompute_freqs_cis_double_precision(
indices_grid, dim, theta, max_pos, use_middle_indices_grid,
@@ -418,10 +417,7 @@ def _precompute_freqs_cis_double_precision(
num_attention_heads: int,
rope_type: LTXRopeType,
) -> Tuple[mx.array, mx.array]:
"""Compute RoPE frequencies in double precision using numpy.
MLX on GPU doesn't support float64, so the computation runs in numpy and the results are converted back to MLX arrays.
"""
# Convert to numpy float64
indices_grid_np = np.array(indices_grid).astype(np.float64)