- Refactor video generation script
- Introduced argparse for parameter handling, streamlined model loading, and enhanced denoising functions.
- Updated VAE weight sanitization for compatibility and improved activation function handling in text projection.
- Added support for saving individual frames and refined output video generation process.
This commit is contained in:
@@ -161,7 +161,7 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
weights: Dictionary of weights with PyTorch naming
|
||||
|
||||
Returns:
|
||||
Dictionary with MLX-compatible naming for VAE
|
||||
Dictionary with MLX-compatible naming for VAE decoder
|
||||
"""
|
||||
sanitized = {}
|
||||
|
||||
@@ -172,17 +172,40 @@ def sanitize_vae_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
if "position_ids" in key:
|
||||
continue
|
||||
|
||||
# Only process keys under the "vae." prefix (skip audio_vae, etc.); decoder weights and per-channel statistics are selected below
|
||||
if not key.startswith("vae."):
|
||||
continue
|
||||
|
||||
# Handle per-channel statistics key mapping
|
||||
# PyTorch: vae.per_channel_statistics.mean-of-means -> per_channel_statistics.mean
|
||||
# PyTorch: vae.per_channel_statistics.std-of-means -> per_channel_statistics.std
|
||||
# Be careful: mean-of-stds_over_std-of-means also ends with std-of-means
|
||||
if "vae.per_channel_statistics" in key:
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
# Skip other per_channel_statistics keys (channel, mean-of-stds, etc.)
|
||||
continue
|
||||
elif key.startswith("vae.decoder."):
|
||||
# Strip the vae.decoder. prefix for decoder weights
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
else:
|
||||
# Skip other vae.* keys that are not decoder weights
|
||||
continue
|
||||
|
||||
# Handle Conv3d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, D, H, W)
|
||||
# MLX: (out_channels, D, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 5:
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 5:
|
||||
# Transpose from (O, I, D, H, W) to (O, D, H, W, I)
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
# Handle Conv2d weight shape conversion
|
||||
# PyTorch: (out_channels, in_channels, H, W)
|
||||
# MLX: (out_channels, H, W, in_channels)
|
||||
if "conv" in key.lower() and "weight" in key and value.ndim == 4:
|
||||
if "conv" in new_key.lower() and "weight" in new_key and value.ndim == 4:
|
||||
value = mx.transpose(value, (0, 2, 3, 1))
|
||||
|
||||
sanitized[new_key] = value
|
||||
|
||||
Reference in New Issue
Block a user