Refactor weight loading and sanitization processes for audio models
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import load_vae_encoder, encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import load_vae_decoder, LTX2VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.encoder import encode_image
|
||||
from mlx_video.models.ltx.video_vae.decoder import LTX2VideoDecoder
|
||||
from mlx_video.models.ltx.video_vae.tiling import (
|
||||
TilingConfig,
|
||||
SpatialTilingConfig,
|
||||
|
||||
@@ -15,13 +15,14 @@ Architecture (from PyTorch weights):
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from mlx_video.models.ltx.video_vae.convolution import CausalConv3d, PaddingModeType
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify
|
||||
from mlx_video.models.ltx.video_vae.ops import unpatchify, PerChannelStatistics
|
||||
from mlx_video.models.ltx.video_vae.sampling import DepthToSpaceUpsample
|
||||
from mlx_video.models.ltx.video_vae.tiling import TilingConfig, decode_with_tiling
|
||||
|
||||
@@ -269,8 +270,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
self.decode_timestep = 0.05
|
||||
|
||||
# Per-channel statistics for denormalization (loaded from weights)
|
||||
self.latents_mean = mx.zeros((in_channels,))
|
||||
self.latents_std = mx.ones((in_channels,))
|
||||
self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels)
|
||||
|
||||
# Initial conv: 128 -> 1024
|
||||
class ConvInWrapper(nn.Module):
|
||||
@@ -346,13 +346,72 @@ class LTX2VideoDecoder(nn.Module):
|
||||
)
|
||||
self.last_scale_shift_table = mx.zeros((2, 128))
|
||||
|
||||
def denormalize(self, x: mx.array) -> mx.array:
|
||||
"""Denormalize latents using per-channel statistics."""
|
||||
dtype = x.dtype
|
||||
# Cast to float32 for precision (statistics may be in bfloat16)
|
||||
mean = self.latents_mean.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
std = self.latents_std.astype(mx.float32).reshape(1, -1, 1, 1, 1)
|
||||
return (x * std + mean).astype(dtype)
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
# Build decoder weights dict with key remapping
|
||||
sanitized = {}
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
if not key.startswith("vae.") or key.startswith("vae.encoder."):
|
||||
continue
|
||||
|
||||
if key.startswith("vae.per_channel_statistics."):
|
||||
# Map per-channel statistics (use exact key matching)
|
||||
if key == "vae.per_channel_statistics.mean-of-means":
|
||||
new_key = "per_channel_statistics.mean"
|
||||
elif key == "vae.per_channel_statistics.std-of-means":
|
||||
new_key = "per_channel_statistics.std"
|
||||
else:
|
||||
continue # Skip other statistics keys
|
||||
|
||||
if key.startswith("vae.decoder."):
|
||||
new_key = key.replace("vae.decoder.", "")
|
||||
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
sanitized[new_key] = value
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
weights = mx.load(str(model_path))
|
||||
|
||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||
if timestep_conditioning is None:
|
||||
try:
|
||||
with safe_open(str(model_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||
else:
|
||||
timestep_conditioning = False
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||
timestep_conditioning = False
|
||||
|
||||
model = cls(timestep_conditioning=timestep_conditioning)
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def pixel_norm(self, x: mx.array, eps: float = 1e-8) -> mx.array:
|
||||
"""Apply pixel normalization."""
|
||||
@@ -367,28 +426,19 @@ class LTX2VideoDecoder(nn.Module):
|
||||
chunked_conv: bool = False,
|
||||
) -> mx.array:
|
||||
|
||||
def debug_stats(name, t):
|
||||
if debug:
|
||||
mx.eval(t)
|
||||
print(f" [VAE] {name}: shape={t.shape}, min={t.min().item():.4f}, max={t.max().item():.4f}, mean={t.mean().item():.4f}")
|
||||
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
if debug:
|
||||
debug_stats("Input", sample)
|
||||
|
||||
|
||||
# Add noise if timestep conditioning is enabled
|
||||
if self.timestep_conditioning:
|
||||
noise = mx.random.normal(sample.shape) * self.decode_noise_scale
|
||||
sample = noise + (1.0 - self.decode_noise_scale) * sample
|
||||
if debug:
|
||||
debug_stats("After noise", sample)
|
||||
|
||||
|
||||
if debug:
|
||||
print(f" [VAE] Denorm stats - mean: [{self.latents_mean.min().item():.4f}, {self.latents_mean.max().item():.4f}], std: [{self.latents_std.min().item():.4f}, {self.latents_std.max().item():.4f}]")
|
||||
sample = self.denormalize(sample)
|
||||
if debug:
|
||||
debug_stats("After denormalize", sample)
|
||||
sample = self.per_channel_statistics.un_normalize(sample)
|
||||
|
||||
|
||||
if timestep is None and self.timestep_conditioning:
|
||||
timestep = mx.full((batch_size,), self.decode_timestep)
|
||||
@@ -398,8 +448,7 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
x = self.conv_in(sample, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_in", x)
|
||||
|
||||
|
||||
for i, block in self.up_blocks.items():
|
||||
if isinstance(block, ResBlockGroup):
|
||||
@@ -408,13 +457,10 @@ class LTX2VideoDecoder(nn.Module):
|
||||
x = block(x, causal=causal, chunked_conv=chunked_conv)
|
||||
else:
|
||||
x = block(x, causal=causal)
|
||||
if debug:
|
||||
block_type = type(block).__name__
|
||||
debug_stats(f"After up_blocks[{i}] ({block_type})", x)
|
||||
|
||||
|
||||
x = self.pixel_norm(x)
|
||||
if debug:
|
||||
debug_stats("After pixel_norm", x)
|
||||
|
||||
|
||||
if self.timestep_conditioning and scaled_timestep is not None:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
@@ -431,21 +477,16 @@ class LTX2VideoDecoder(nn.Module):
|
||||
scale = ada_values[:, 1]
|
||||
|
||||
x = x * (1 + scale) + shift
|
||||
if debug:
|
||||
debug_stats("After timestep modulation", x)
|
||||
|
||||
|
||||
x = self.act(x)
|
||||
if debug:
|
||||
debug_stats("After activation", x)
|
||||
|
||||
|
||||
x = self.conv_out(x, causal=causal)
|
||||
if debug:
|
||||
debug_stats("After conv_out", x)
|
||||
|
||||
|
||||
# Unpatchify: (B, 48, F', H', W') -> (B, 3, F, H*4, W*4)
|
||||
x = unpatchify(x, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||
if debug:
|
||||
debug_stats("After unpatchify", x)
|
||||
|
||||
|
||||
return x
|
||||
|
||||
@@ -519,103 +560,3 @@ class LTX2VideoDecoder(nn.Module):
|
||||
chunked_conv=use_chunked_conv,
|
||||
on_frames_ready=on_frames_ready,
|
||||
)
|
||||
|
||||
|
||||
def load_vae_decoder(model_path: str, timestep_conditioning: Optional[bool] = None) -> LTX2VideoDecoder:
|
||||
from pathlib import Path
|
||||
import json
|
||||
from safetensors import safe_open
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE decoder from {weights_path}...")
|
||||
|
||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||
if timestep_conditioning is None:
|
||||
try:
|
||||
with safe_open(str(weights_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||
else:
|
||||
timestep_conditioning = False
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||
timestep_conditioning = False
|
||||
|
||||
decoder = LTX2VideoDecoder(timestep_conditioning=timestep_conditioning)
|
||||
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
has_decoder_prefix = any(k.startswith("decoder.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.decoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
elif has_decoder_prefix:
|
||||
prefix = "decoder."
|
||||
stats_prefix = ""
|
||||
else:
|
||||
prefix = ""
|
||||
stats_prefix = ""
|
||||
|
||||
# Load per-channel statistics for denormalization
|
||||
# Note: use std-of-means (not mean-of-stds) for proper denormalization
|
||||
mean_key = f"{stats_prefix}mean-of-means" if stats_prefix else "latents_mean"
|
||||
std_key = f"{stats_prefix}std-of-means" if stats_prefix else "latents_std"
|
||||
|
||||
if mean_key in weights:
|
||||
decoder.latents_mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {decoder.latents_mean.shape}")
|
||||
if std_key in weights:
|
||||
decoder.latents_std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {decoder.latents_std.shape}")
|
||||
|
||||
# Build decoder weights dict with key remapping
|
||||
decoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".conv.weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
if ".conv.bias" in key:
|
||||
pass # bias doesn't need transpose
|
||||
|
||||
|
||||
if ".conv.weight" in new_key or ".conv.bias" in new_key:
|
||||
if ".conv.conv.weight" not in new_key and ".conv.conv.bias" not in new_key:
|
||||
new_key = new_key.replace(".conv.weight", ".conv.conv.weight")
|
||||
new_key = new_key.replace(".conv.bias", ".conv.conv.bias")
|
||||
|
||||
decoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(decoder_weights)} decoder weights")
|
||||
|
||||
ts_keys = [k for k in decoder_weights.keys() if "scale_shift" in k or "time_embedder" in k or "timestep_scale" in k]
|
||||
print(f" Found {len(ts_keys)} timestep conditioning weights")
|
||||
|
||||
# Load weights
|
||||
decoder.load_weights(list(decoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE decoder loaded successfully")
|
||||
return decoder
|
||||
|
||||
@@ -5,152 +5,9 @@ Used for I2V (image-to-video) conditioning by encoding the input image
|
||||
to latent space, which can then be used to condition video generation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Any, Optional
|
||||
import json
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder
|
||||
|
||||
from mlx_video.models.ltx.video_vae.video_vae import VideoEncoder, LogVarianceType, NormLayerType, PaddingModeType
|
||||
|
||||
|
||||
def load_vae_encoder(model_path: str) -> VideoEncoder:
|
||||
"""Load VAE encoder from safetensors file.
|
||||
|
||||
Args:
|
||||
model_path: Path to the model weights (safetensors file or directory)
|
||||
|
||||
Returns:
|
||||
Loaded VideoEncoder instance
|
||||
"""
|
||||
from safetensors import safe_open
|
||||
|
||||
model_path = Path(model_path)
|
||||
|
||||
# Try to find the weights file
|
||||
if model_path.is_file() and model_path.suffix == ".safetensors":
|
||||
weights_path = model_path
|
||||
elif (model_path / "ltx-2-19b-distilled.safetensors").exists():
|
||||
weights_path = model_path / "ltx-2-19b-distilled.safetensors"
|
||||
elif (model_path / "vae" / "diffusion_pytorch_model.safetensors").exists():
|
||||
weights_path = model_path / "vae" / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise FileNotFoundError(f"VAE weights not found at {model_path}")
|
||||
|
||||
print(f"Loading VAE encoder from {weights_path}...")
|
||||
|
||||
# Read config from safetensors metadata
|
||||
encoder_blocks = []
|
||||
norm_layer = NormLayerType.PIXEL_NORM
|
||||
latent_log_var = LogVarianceType.UNIFORM
|
||||
patch_size = 4
|
||||
|
||||
try:
|
||||
with safe_open(str(weights_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
|
||||
# Parse encoder blocks
|
||||
raw_blocks = vae_config.get("encoder_blocks", [])
|
||||
for block in raw_blocks:
|
||||
if isinstance(block, list) and len(block) == 2:
|
||||
name, params = block
|
||||
encoder_blocks.append((name, params))
|
||||
|
||||
# Parse other config
|
||||
norm_str = vae_config.get("norm_layer", "pixel_norm")
|
||||
norm_layer = NormLayerType.PIXEL_NORM if norm_str == "pixel_norm" else NormLayerType.GROUP_NORM
|
||||
|
||||
var_str = vae_config.get("latent_log_var", "uniform")
|
||||
if var_str == "uniform":
|
||||
latent_log_var = LogVarianceType.UNIFORM
|
||||
elif var_str == "per_channel":
|
||||
latent_log_var = LogVarianceType.PER_CHANNEL
|
||||
elif var_str == "constant":
|
||||
latent_log_var = LogVarianceType.CONSTANT
|
||||
else:
|
||||
latent_log_var = LogVarianceType.NONE
|
||||
|
||||
patch_size = vae_config.get("patch_size", 4)
|
||||
|
||||
print(f" Loaded config: {len(encoder_blocks)} encoder blocks, norm={norm_str}, patch_size={patch_size}")
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}")
|
||||
# Use default config
|
||||
encoder_blocks = [
|
||||
("res_x", {"num_layers": 4}),
|
||||
("compress_space_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_time_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 6}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
("compress_all_res", {"multiplier": 2}),
|
||||
("res_x", {"num_layers": 2}),
|
||||
]
|
||||
print(f" Using default encoder config with {len(encoder_blocks)} blocks")
|
||||
|
||||
# Create encoder
|
||||
encoder = VideoEncoder(
|
||||
convolution_dimensions=3,
|
||||
in_channels=3,
|
||||
out_channels=128,
|
||||
encoder_blocks=encoder_blocks,
|
||||
patch_size=patch_size,
|
||||
norm_layer=norm_layer,
|
||||
latent_log_var=latent_log_var,
|
||||
encoder_spatial_padding_mode=PaddingModeType.ZEROS,
|
||||
)
|
||||
|
||||
# Load weights
|
||||
weights = mx.load(str(weights_path))
|
||||
|
||||
# Determine prefix based on weight keys
|
||||
has_vae_prefix = any(k.startswith("vae.") for k in weights.keys())
|
||||
|
||||
if has_vae_prefix:
|
||||
prefix = "vae.encoder."
|
||||
stats_prefix = "vae.per_channel_statistics."
|
||||
else:
|
||||
prefix = "encoder."
|
||||
stats_prefix = "per_channel_statistics."
|
||||
|
||||
# Load per-channel statistics for normalization
|
||||
mean_key = f"{stats_prefix}mean-of-means"
|
||||
std_key = f"{stats_prefix}std-of-means"
|
||||
|
||||
if mean_key in weights:
|
||||
encoder.per_channel_statistics.mean = weights[mean_key]
|
||||
print(f" Loaded latent mean: shape {weights[mean_key].shape}")
|
||||
if std_key in weights:
|
||||
encoder.per_channel_statistics.std = weights[std_key]
|
||||
print(f" Loaded latent std: shape {weights[std_key].shape}")
|
||||
|
||||
# Build encoder weights dict with key remapping
|
||||
encoder_weights = {}
|
||||
for key, value in weights.items():
|
||||
if not key.startswith(prefix):
|
||||
continue
|
||||
|
||||
# Remove prefix
|
||||
new_key = key[len(prefix):]
|
||||
|
||||
# Handle Conv3d weight transpose: (O, I, D, H, W) -> (O, D, H, W, I)
|
||||
if ".weight" in key and value.ndim == 5:
|
||||
value = mx.transpose(value, (0, 2, 3, 4, 1))
|
||||
|
||||
encoder_weights[new_key] = value
|
||||
|
||||
print(f" Found {len(encoder_weights)} encoder weights")
|
||||
|
||||
# Load weights
|
||||
encoder.load_weights(list(encoder_weights.items()), strict=False)
|
||||
|
||||
print("VAE encoder loaded successfully")
|
||||
return encoder
|
||||
|
||||
|
||||
def encode_image(
|
||||
|
||||
@@ -273,9 +273,10 @@ class VideoEncoder(nn.Module):
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
|
||||
# Build encoder blocks - use dict with int keys for MLX parameter tracking
|
||||
# Build encoder blocks
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.down_blocks = {}
|
||||
for i, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
for idx, (block_name, block_params) in enumerate(encoder_blocks):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_encoder_block(
|
||||
@@ -287,7 +288,7 @@ class VideoEncoder(nn.Module):
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=encoder_spatial_padding_mode,
|
||||
)
|
||||
self.down_blocks[i] = block
|
||||
self.down_blocks[idx] = block
|
||||
|
||||
# Output normalization and convolution
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
@@ -341,7 +342,8 @@ class VideoEncoder(nn.Module):
|
||||
sample = self.conv_in(sample, causal=True)
|
||||
|
||||
# Process through encoder blocks
|
||||
for down_block in self.down_blocks.values():
|
||||
for i in range(len(self.down_blocks)):
|
||||
down_block = self.down_blocks[i]
|
||||
if isinstance(down_block, (UNetMidBlock3D, ResnetBlock3D)):
|
||||
sample = down_block(sample, causal=True)
|
||||
else:
|
||||
@@ -440,8 +442,9 @@ class VideoDecoder(nn.Module):
|
||||
)
|
||||
|
||||
# Build decoder blocks (reversed order)
|
||||
self.up_blocks = []
|
||||
for block_name, block_params in list(reversed(decoder_blocks)):
|
||||
# Use dict with int keys for MLX to track parameters (lists are NOT tracked)
|
||||
self.up_blocks = {}
|
||||
for idx, (block_name, block_params) in enumerate(reversed(decoder_blocks)):
|
||||
block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params
|
||||
|
||||
block, feature_channels = _make_decoder_block(
|
||||
@@ -454,7 +457,7 @@ class VideoDecoder(nn.Module):
|
||||
norm_num_groups=self._norm_num_groups,
|
||||
spatial_padding_mode=decoder_spatial_padding_mode,
|
||||
)
|
||||
self.up_blocks.append(block)
|
||||
self.up_blocks[idx] = block
|
||||
|
||||
# Output normalization
|
||||
if norm_layer == NormLayerType.GROUP_NORM:
|
||||
@@ -509,7 +512,8 @@ class VideoDecoder(nn.Module):
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
# Process through decoder blocks
|
||||
for up_block in self.up_blocks:
|
||||
for i in range(len(self.up_blocks)):
|
||||
up_block = self.up_blocks[i]
|
||||
if isinstance(up_block, UNetMidBlock3D):
|
||||
sample = up_block(sample, causal=self.causal)
|
||||
elif isinstance(up_block, ResnetBlock3D):
|
||||
|
||||
Reference in New Issue
Block a user