This commit is contained in:
Prince Canuma
2026-03-18 17:40:05 +01:00
parent 78bcfba31b
commit 17397da70c
77 changed files with 4125 additions and 1655 deletions

View File

@@ -1,4 +1,5 @@
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -36,11 +37,20 @@ class Conv3d(nn.Module):
self.groups = groups
# Weight shape: (C_out, KD, KH, KW, C_in)
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
scale = (
1.0
/ (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
shape=(
out_channels,
kernel_size[0],
kernel_size[1],
kernel_size[2],
in_channels,
),
)
if bias:
@@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module):
n, d, h, w, c = x.shape
input_dtype = x.dtype
x = x.astype(mx.float32)
# Reshape to (N, D*H*W, num_groups, C//num_groups)
@@ -219,7 +228,9 @@ class SpatialRationalResampler(nn.Module):
self.den = den
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1)
self.conv = nn.Conv2d(
mid_channels, num * num * mid_channels, kernel_size=3, padding=1
)
self.pixel_shuffle = PixelShuffle2D(num, num)
self.blur_down = BlurDownsample(stride=den)
@@ -230,7 +241,7 @@ class SpatialRationalResampler(nn.Module):
x = self.conv(x)
x = self.pixel_shuffle(x) # H*num, W*num
x = self.blur_down(x) # H*num/den, W*num/den
x = self.blur_down(x) # H*num/den, W*num/den
_, h_out, w_out, _ = x.shape
x = mx.reshape(x, (n, d, h_out, w_out, c))
@@ -240,6 +251,7 @@ class SpatialRationalResampler(nn.Module):
def _rational_for_scale(scale: float) -> Tuple[int, int]:
"""Convert a float scale to a rational fraction (numerator, denominator)."""
from fractions import Fraction
frac = Fraction(scale).limit_denominator(10)
return frac.numerator, frac.denominator
@@ -290,16 +302,22 @@ class LatentUpsampler(nn.Module):
self.initial_norm = GroupNorm3d(32, mid_channels)
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
self.res_blocks = {
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
}
# Upsampler: 2D spatial upsampling (frame-by-frame)
if rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale)
self.upsampler = SpatialRationalResampler(
mid_channels=mid_channels, scale=spatial_scale
)
else:
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
self.post_upsample_res_blocks = {
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
}
# Final projection
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
@@ -314,10 +332,13 @@ class LatentUpsampler(nn.Module):
Returns:
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
"""
def debug_stats(name, t):
if debug:
mx.eval(t)
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
print(
f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}"
)
if debug:
print(" [DEBUG] LatentUpsampler forward pass:")
@@ -404,7 +425,11 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
# Both formats may have upsampler.blur_down.kernel, so use channel count
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
conv_key = (
"upsampler.conv.weight"
if "upsampler.conv.weight" in raw_weights
else "upsampler.0.weight"
)
if conv_key in raw_weights:
out_channels = raw_weights[conv_key].shape[0]
ratio = out_channels // mid_channels
@@ -414,7 +439,9 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
rational_resampler = False
spatial_scale = 2.0
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
print(
f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}"
)
# Create model
upsampler = LatentUpsampler(