Refactor LTX2VideoDecoder and ResBlockGroup: Change up_blocks and res_blocks from lists to dictionaries for better parameter tracking in MLX

This commit is contained in:
Prince Canuma
2026-01-15 03:48:16 +01:00
parent 3fcd8f90be
commit 09c2b460a7

View File

@@ -201,14 +201,15 @@ class ResBlockGroup(nn.Module):
embedding_dim=channels * 4
)
self.res_blocks = [
ResnetBlock3DSimple(
# Use dict with int keys for MLX to track parameters properly
self.res_blocks = {
i: ResnetBlock3DSimple(
channels,
spatial_padding_mode,
timestep_conditioning=timestep_conditioning
)
for _ in range(num_layers)
]
for i in range(num_layers)
}
def __call__(
self,
@@ -227,7 +228,7 @@ class ResBlockGroup(nn.Module):
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
for res_block in self.res_blocks:
for res_block in self.res_blocks.values():
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
return x
@@ -287,10 +288,10 @@ class LTX2VideoDecoder(nn.Module):
self.conv_in = ConvInWrapper()
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
self.up_blocks = [
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
# Use dict with int keys for MLX to track parameters properly
self.up_blocks = {
0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
1: DepthToSpaceUpsample(
dims=3,
in_channels=1024,
stride=(2, 2, 2),
@@ -298,8 +299,8 @@ class LTX2VideoDecoder(nn.Module):
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
3: DepthToSpaceUpsample(
dims=3,
in_channels=512,
stride=(2, 2, 2),
@@ -307,8 +308,8 @@ class LTX2VideoDecoder(nn.Module):
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
DepthToSpaceUpsample(
4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
5: DepthToSpaceUpsample(
dims=3,
in_channels=256,
stride=(2, 2, 2),
@@ -316,8 +317,8 @@ class LTX2VideoDecoder(nn.Module):
out_channels_reduction_factor=2,
spatial_padding_mode=spatial_padding_mode,
),
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
]
6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
}
final_out_channels = out_channels * patch_size * patch_size
class ConvOutWrapper(nn.Module):
@@ -396,7 +397,7 @@ class LTX2VideoDecoder(nn.Module):
if debug:
debug_stats("After conv_in", x)
for i, block in enumerate(self.up_blocks):
for i, block in self.up_blocks.items():
if isinstance(block, ResBlockGroup):
x = block(x, causal=causal, timestep=scaled_timestep)
else:
@@ -443,10 +444,10 @@ class LTX2VideoDecoder(nn.Module):
return x
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
from pathlib import Path
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
import json
from safetensors import safe_open
model_path = Path(model_path)
@@ -461,6 +462,25 @@ def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX
raise FileNotFoundError(f"VAE weights not found at {model_path}")
print(f"Loading VAE decoder from {weights_path}...")
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(weights_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
weights = mx.load(str(weights_path))
# Determine prefix based on weight keys