Refactor LTX2VideoDecoder and ResBlockGroup: Change up_blocks and res_blocks from lists to dictionaries for better parameter tracking in MLX
This commit is contained in:
@@ -201,14 +201,15 @@ class ResBlockGroup(nn.Module):
|
||||
embedding_dim=channels * 4
|
||||
)
|
||||
|
||||
self.res_blocks = [
|
||||
ResnetBlock3DSimple(
|
||||
# Use dict with int keys for MLX to track parameters properly
|
||||
self.res_blocks = {
|
||||
i: ResnetBlock3DSimple(
|
||||
channels,
|
||||
spatial_padding_mode,
|
||||
timestep_conditioning=timestep_conditioning
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
for i in range(num_layers)
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -227,7 +228,7 @@ class ResBlockGroup(nn.Module):
|
||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||
|
||||
for res_block in self.res_blocks:
|
||||
for res_block in self.res_blocks.values():
|
||||
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
||||
return x
|
||||
|
||||
@@ -287,10 +288,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
self.conv_in = ConvInWrapper()
|
||||
|
||||
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
||||
|
||||
self.up_blocks = [
|
||||
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
# Use dict with int keys for MLX to track parameters properly
|
||||
self.up_blocks = {
|
||||
0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
1: DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=1024,
|
||||
stride=(2, 2, 2),
|
||||
@@ -298,8 +299,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
3: DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=512,
|
||||
stride=(2, 2, 2),
|
||||
@@ -307,8 +308,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
DepthToSpaceUpsample(
|
||||
4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
5: DepthToSpaceUpsample(
|
||||
dims=3,
|
||||
in_channels=256,
|
||||
stride=(2, 2, 2),
|
||||
@@ -316,8 +317,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
out_channels_reduction_factor=2,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
),
|
||||
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
]
|
||||
6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||
}
|
||||
|
||||
final_out_channels = out_channels * patch_size * patch_size
|
||||
class ConvOutWrapper(nn.Module):
|
||||
@@ -396,7 +397,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
if debug:
|
||||
debug_stats("After conv_in", x)
|
||||
|
||||
for i, block in enumerate(self.up_blocks):
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||
else:
|
||||
@@ -443,10 +444,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
||||
from pathlib import Path
|
||||
|
||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||
import json
|
||||
from safetensors import safe_open
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
@@ -461,6 +462,25 @@ def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE decoder from {weights_path}...")
|
||||
|
||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||
if timestep_conditioning is None:
|
||||
try:
|
||||
with safe_open(str(weights_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||
else:
|
||||
timestep_conditioning = False
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||
timestep_conditioning = False
|
||||
|
||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
|
||||
Reference in New Issue
Block a user