format
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
@@ -36,11 +37,20 @@ class Conv3d(nn.Module):
|
||||
self.groups = groups
|
||||
|
||||
# Weight shape: (C_out, KD, KH, KW, C_in)
|
||||
scale = 1.0 / (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||
scale = (
|
||||
1.0
|
||||
/ (in_channels * kernel_size[0] * kernel_size[1] * kernel_size[2]) ** 0.5
|
||||
)
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(out_channels, kernel_size[0], kernel_size[1], kernel_size[2], in_channels),
|
||||
shape=(
|
||||
out_channels,
|
||||
kernel_size[0],
|
||||
kernel_size[1],
|
||||
kernel_size[2],
|
||||
in_channels,
|
||||
),
|
||||
)
|
||||
|
||||
if bias:
|
||||
@@ -87,7 +97,6 @@ class GroupNorm3d(nn.Module):
|
||||
n, d, h, w, c = x.shape
|
||||
input_dtype = x.dtype
|
||||
|
||||
|
||||
x = x.astype(mx.float32)
|
||||
|
||||
# Reshape to (N, D*H*W, num_groups, C//num_groups)
|
||||
@@ -219,7 +228,9 @@ class SpatialRationalResampler(nn.Module):
|
||||
self.den = den
|
||||
|
||||
# Conv2d: mid_channels -> num^2 * mid_channels for PixelShuffle(num)
|
||||
self.conv = nn.Conv2d(mid_channels, num * num * mid_channels, kernel_size=3, padding=1)
|
||||
self.conv = nn.Conv2d(
|
||||
mid_channels, num * num * mid_channels, kernel_size=3, padding=1
|
||||
)
|
||||
self.pixel_shuffle = PixelShuffle2D(num, num)
|
||||
self.blur_down = BlurDownsample(stride=den)
|
||||
|
||||
@@ -230,7 +241,7 @@ class SpatialRationalResampler(nn.Module):
|
||||
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x) # H*num, W*num
|
||||
x = self.blur_down(x) # H*num/den, W*num/den
|
||||
x = self.blur_down(x) # H*num/den, W*num/den
|
||||
|
||||
_, h_out, w_out, _ = x.shape
|
||||
x = mx.reshape(x, (n, d, h_out, w_out, c))
|
||||
@@ -240,6 +251,7 @@ class SpatialRationalResampler(nn.Module):
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
"""Convert a float scale to a rational fraction (numerator, denominator)."""
|
||||
from fractions import Fraction
|
||||
|
||||
frac = Fraction(scale).limit_denominator(10)
|
||||
return frac.numerator, frac.denominator
|
||||
|
||||
@@ -290,16 +302,22 @@ class LatentUpsampler(nn.Module):
|
||||
self.initial_norm = GroupNorm3d(32, mid_channels)
|
||||
|
||||
# Pre-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
self.res_blocks = {
|
||||
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
|
||||
}
|
||||
|
||||
# Upsampler: 2D spatial upsampling (frame-by-frame)
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=spatial_scale)
|
||||
self.upsampler = SpatialRationalResampler(
|
||||
mid_channels=mid_channels, scale=spatial_scale
|
||||
)
|
||||
else:
|
||||
self.upsampler = SpatialUpsampler2x(mid_channels=mid_channels)
|
||||
|
||||
# Post-upsample ResBlocks - use dict with int keys for MLX parameter tracking
|
||||
self.post_upsample_res_blocks = {i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)}
|
||||
self.post_upsample_res_blocks = {
|
||||
i: ResBlock3D(mid_channels) for i in range(num_blocks_per_stage)
|
||||
}
|
||||
|
||||
# Final projection
|
||||
self.final_conv = Conv3d(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
@@ -314,10 +332,13 @@ class LatentUpsampler(nn.Module):
|
||||
Returns:
|
||||
Upsampled tensor of shape (B, C, F, H*scale, W*scale) - channels first
|
||||
"""
|
||||
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
print(
|
||||
f" {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}"
|
||||
)
|
||||
|
||||
if debug:
|
||||
print(" [DEBUG] LatentUpsampler forward pass:")
|
||||
@@ -404,7 +425,11 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
||||
# x2: conv out = 4 * mid (2^2 * mid for PixelShuffle(2))
|
||||
# x1.5: conv out = 9 * mid (3^2 * mid for PixelShuffle(3)) + blur downsample
|
||||
# Both formats may have upsampler.blur_down.kernel, so use channel count
|
||||
conv_key = "upsampler.conv.weight" if "upsampler.conv.weight" in raw_weights else "upsampler.0.weight"
|
||||
conv_key = (
|
||||
"upsampler.conv.weight"
|
||||
if "upsampler.conv.weight" in raw_weights
|
||||
else "upsampler.0.weight"
|
||||
)
|
||||
if conv_key in raw_weights:
|
||||
out_channels = raw_weights[conv_key].shape[0]
|
||||
ratio = out_channels // mid_channels
|
||||
@@ -414,7 +439,9 @@ def load_upsampler(weights_path: str) -> Tuple[LatentUpsampler, float]:
|
||||
rational_resampler = False
|
||||
spatial_scale = 2.0
|
||||
|
||||
print(f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}")
|
||||
print(
|
||||
f" Detected: mid_channels={mid_channels}, scale={spatial_scale}x, rational={rational_resampler}"
|
||||
)
|
||||
|
||||
# Create model
|
||||
upsampler = LatentUpsampler(
|
||||
|
||||
Reference in New Issue
Block a user