Refactor LTX2VideoDecoder and ResBlockGroup: Change up_blocks and res_blocks from lists to dictionaries for better parameter tracking in MLX
This commit is contained in:
@@ -201,14 +201,15 @@ class ResBlockGroup(nn.Module):
|
|||||||
embedding_dim=channels * 4
|
embedding_dim=channels * 4
|
||||||
)
|
)
|
||||||
|
|
||||||
self.res_blocks = [
|
# Use dict with int keys for MLX to track parameters properly
|
||||||
ResnetBlock3DSimple(
|
self.res_blocks = {
|
||||||
|
i: ResnetBlock3DSimple(
|
||||||
channels,
|
channels,
|
||||||
spatial_padding_mode,
|
spatial_padding_mode,
|
||||||
timestep_conditioning=timestep_conditioning
|
timestep_conditioning=timestep_conditioning
|
||||||
)
|
)
|
||||||
for _ in range(num_layers)
|
for i in range(num_layers)
|
||||||
]
|
}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -227,7 +228,7 @@ class ResBlockGroup(nn.Module):
|
|||||||
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
# Reshape to (B, 4*C, 1, 1, 1) for broadcasting
|
||||||
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
timestep_embed = timestep_embed.reshape(batch_size, -1, 1, 1, 1)
|
||||||
|
|
||||||
for res_block in self.res_blocks:
|
for res_block in self.res_blocks.values():
|
||||||
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
x = res_block(x, causal=causal, timestep_embed=timestep_embed)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -287,10 +288,10 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
self.conv_in = ConvInWrapper()
|
self.conv_in = ConvInWrapper()
|
||||||
|
|
||||||
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
# Up blocks: alternating ResBlockGroup and DepthToSpaceUpsample
|
||||||
|
# Use dict with int keys for MLX to track parameters properly
|
||||||
self.up_blocks = [
|
self.up_blocks = {
|
||||||
ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
0: ResBlockGroup(1024, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
DepthToSpaceUpsample(
|
1: DepthToSpaceUpsample(
|
||||||
dims=3,
|
dims=3,
|
||||||
in_channels=1024,
|
in_channels=1024,
|
||||||
stride=(2, 2, 2),
|
stride=(2, 2, 2),
|
||||||
@@ -298,8 +299,8 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
out_channels_reduction_factor=2,
|
out_channels_reduction_factor=2,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
),
|
),
|
||||||
ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
2: ResBlockGroup(512, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
DepthToSpaceUpsample(
|
3: DepthToSpaceUpsample(
|
||||||
dims=3,
|
dims=3,
|
||||||
in_channels=512,
|
in_channels=512,
|
||||||
stride=(2, 2, 2),
|
stride=(2, 2, 2),
|
||||||
@@ -307,8 +308,8 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
out_channels_reduction_factor=2,
|
out_channels_reduction_factor=2,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
),
|
),
|
||||||
ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
4: ResBlockGroup(256, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
DepthToSpaceUpsample(
|
5: DepthToSpaceUpsample(
|
||||||
dims=3,
|
dims=3,
|
||||||
in_channels=256,
|
in_channels=256,
|
||||||
stride=(2, 2, 2),
|
stride=(2, 2, 2),
|
||||||
@@ -316,8 +317,8 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
out_channels_reduction_factor=2,
|
out_channels_reduction_factor=2,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
),
|
),
|
||||||
ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
6: ResBlockGroup(128, num_layers_per_block, spatial_padding_mode, timestep_conditioning),
|
||||||
]
|
}
|
||||||
|
|
||||||
final_out_channels = out_channels * patch_size * patch_size
|
final_out_channels = out_channels * patch_size * patch_size
|
||||||
class ConvOutWrapper(nn.Module):
|
class ConvOutWrapper(nn.Module):
|
||||||
@@ -396,10 +397,10 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
if debug:
|
if debug:
|
||||||
debug_stats("After conv_in", x)
|
debug_stats("After conv_in", x)
|
||||||
|
|
||||||
for i, block in enumerate(self.up_blocks):
|
for i, block in self.up_blocks.items():
|
||||||
if isinstance(block, ResBlockGroup):
|
if isinstance(block, ResBlockGroup):
|
||||||
x = block(x, causal=causal, timestep=scaled_timestep)
|
x = block(x, causal=causal, timestep=scaled_timestep)
|
||||||
else:
|
else:
|
||||||
x = block(x, causal=causal)
|
x = block(x, causal=causal)
|
||||||
if debug:
|
if debug:
|
||||||
block_type = type(block).__name__
|
block_type = type(block).__name__
|
||||||
@@ -443,10 +444,10 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX2VideoDecoder:
|
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import json
|
||||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
from safetensors import safe_open
|
||||||
|
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
|
|
||||||
@@ -461,6 +462,25 @@ def load_vae_decoder(model_path: str, timestep_conditioning: bool = True) -> LTX
|
|||||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||||
|
|
||||||
print(f"Loading VAE decoder from {weights_path}...")
|
print(f"Loading VAE decoder from {weights_path}...")
|
||||||
|
|
||||||
|
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||||
|
if timestep_conditioning is None:
|
||||||
|
try:
|
||||||
|
with safe_open(str(weights_path), framework="numpy") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
if metadata and "config" in metadata:
|
||||||
|
configs = json.loads(metadata["config"])
|
||||||
|
vae_config = configs.get("vae", {})
|
||||||
|
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||||
|
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||||
|
else:
|
||||||
|
timestep_conditioning = False
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||||
|
timestep_conditioning = False
|
||||||
|
|
||||||
|
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||||
|
|
||||||
weights = mx.load(str(weights_path))
|
weights = mx.load(str(weights_path))
|
||||||
|
|
||||||
# Determine prefix based on weight keys
|
# Determine prefix based on weight keys
|
||||||
|
|||||||
Reference in New Issue
Block a user