Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -273,9 +273,10 @@ class VideoEncoder(nn.Module):
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build encoder blocks - use dict with int keys for MLX parameter tracking
|
||||
# Build encoder blocks
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.down_blocks = {}
|
||||
for i, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
@@ -287,7 +288,7 @@ class VideoEncoder(nn.Module):
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
self.down_blocks[i] = block
|
||||
self.down_blocks[idx] = block
|
||||
|
||||
# Output normalization and convolution
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
@@ -341,7 +342,8 @@ class VideoEncoder(nn.Module):
|
||||
sample = self.conv_in(sample, causal=True)
|
||||
|
||||
# Process through encoder blocks
|
||||
for down_block in self.down_blocks.values():
|
||||
for i in range(len(self.down_blocks)):
|
||||
down_block = self.down_blocks[i]
|
||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||
sample = down_block(sample, causal=True)
|
||||
else:
|
||||
@@ -440,8 +442,9 @@ class VideoDecoder(nn.Module):
|
||||
)
|
||||
|
||||
# Build decoder blocks (reversed order)
|
||||
self.up_blocks = []
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.up_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
@@ -454,7 +457,7 @@ class VideoDecoder(nn.Module):
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
self.up_blocks.append(block)
|
||||
self.up_blocks[idx] = block
|
||||
|
||||
# Output normalization
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
@@ -509,7 +512,8 @@ class VideoDecoder(nn.Module):
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
# Process through decoder blocks
|
||||
for up_block in self.up_blocks:
|
||||
for i in range(len(self.up_blocks)):
|
||||
up_block = self.up_blocks[i]
|
||||
if isinstance(up_block, UNetMidBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
elif isinstance(up_block, ResnetBlock3D):
|
||||
|
||||
Reference in New Issue
Block a user