Refactor LTX2VideoDecoder and ResBlockGroup: Change up_blocks and res_blocks from lists to dictionaries for better parameter tracking in MLX

This commit is contained in:
Prince Canuma
2026-01-15 03:48:16 +01:00
parent 3fcd8f90be
commit 09c2b460a7

View File

@@ -201,14 +201,15 @@ class ResBlockGroup(nn.Module):
embedding_dim=channels * 4 embedding_dim=channels * 4
) )
self.res_blocks = [ # Use dict with int keys for MLX to track parameters properly
ResnetBlock3DSimple( self.res_blocks = {
i: ResnetBlock3DSimple(
channels, channels,
spatial_padding_mode, spatial_padding_mode,
timestep_conditioning=timestep_conditioning timestep_conditioning=timestep_conditioning
) )
for _ in range(num_layers) for i in range(num_layers)
] }
def __call__( def __call__(
self, self,
@@ -227,7 +228,7 @@ class ResBlockGroup(nn.Module):
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting # Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1) timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
for res_block in self.res_blocks: for res_block in self.res_blocks.values():
x = res_block(x, causal=causal, timestep_embed=timestep_embed) x = res_block(x, causal=causal, timestep_embed=timestep_embed)
return x return x
@@ -287,10 +288,10 @@ class LTX2VideoDecoder(nn.Module):
self.conv_in = ConvInWrapper() self.conv_in = ConvInWrapper()
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample # Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
# Use dict with int keys for MLX to track parameters properly
self.up_blocks = [ self.up_blocks = {
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning), 0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample( 1: DepthToSpaceUpsample(
dims=3, dims=3,
in_channels=1024, in_channels=1024,
stride=(2, 2, 2), stride=(2, 2, 2),
@@ -298,8 +299,8 @@ class LTX2VideoDecoder(nn.Module):
out_channels_reduction_factor=2, out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
), ),
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning), 2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample( 3: DepthToSpaceUpsample(
dims=3, dims=3,
in_channels=512, in_channels=512,
stride=(2, 2, 2), stride=(2, 2, 2),
@@ -307,8 +308,8 @@ class LTX2VideoDecoder(nn.Module):
out_channels_reduction_factor=2, out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
), ),
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning), 4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample( 5: DepthToSpaceUpsample(
dims=3, dims=3,
in_channels=256, in_channels=256,
stride=(2, 2, 2), stride=(2, 2, 2),
@@ -316,8 +317,8 @@ class LTX2VideoDecoder(nn.Module):
out_channels_reduction_factor=2, out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
), ),
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning), 6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
] }
final_out_channels = out_channels * patch_size * patch_size final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module): class ConvOutWrapper(nn.Module):
@@ -396,10 +397,10 @@ class LTX2VideoDecoder(nn.Module):
if debug: if debug:
debug_stats("After conv_in", x) debug_stats("After conv_in", x)
for i, block in enumerate(self.up_blocks): for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup): if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep) x = block(x, causal=causal, timestep=scaled_timestep)
else: else:
x = block(x, causal=causal) x = block(x, causal=causal)
if debug: if debug:
block_type = type(block).__name__ block_type = type(block).__name__
@@ -443,10 +444,10 @@ class LTX2VideoDecoder(nn.Module):
return x return x
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder: def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path from pathlib import Path
import json
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning) from safetensors import safe_open
model_path = Path(model_path) model_path = Path(model_path)
@@ -461,6 +462,25 @@ def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX
raise FileNotFoundError(f"VAE weights not found at {model_path}") raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...") print(f"Loading VAE decoder from {weights_path}...")
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
weights = mx.load(str(weights_path)) weights = mx.load(str(weights_path))
# Determine prefix based on weight keys # Determine prefix based on weight keys