- Refactor video generation script

- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
Prince Canuma
2026-01-12 14:04:53 +01:00
parent d1ca36a315
commit 7114b023bd
6 changed files with 270 additions and 304 deletions

View File

@@ -161,7 +161,7 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
weights: Dictionary of weights with PyTorch naming
Returns:
Dictionary with MLX-compatible naming for VAE
Dictionary with MLX-compatible naming for VAE decoder
"""
sanitized = {}
@@ -172,17 +172,40 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
if "position_ids" in key:
continue
# Only process VAE decoder weights (skip audio_vae, etc.)
if not key.startswith("vae."):
continue
# Handle per-channel statistics key mapping
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
if "vae.per_channel_statistics" in key:
if key == "vae.per_channel_statistics.mean-of-means":
new_key = "per_channel_statistics.mean"
elif key == "vae.per_channel_statistics.std-of-means":
new_key = "per_channel_statistics.std"
else:
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
continue
elif key.startswith("vae.decoder."):
# Strip the vae.decoder. prefix for decoder weights
new_key = key.replace("vae.decoder.", "")
else:
# Skip other vae.* keys that are not decoder weights
continue
# Handle Conv3d weight shape conversion
# PyTorch: (out_channels, in_channels, D, H, W)
# MLX: (out_channels, D, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
value = mx.transpose(value, (0, 2, 3, 4, 1))
# Handle Conv2d weight shape conversion
# PyTorch: (out_channels, in_channels, H, W)
# MLX: (out_channels, H, W, in_channels)
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
value = mx.transpose(value, (0, 2, 3, 1))
sanitized[new_key] = value