fix loading

This commit is contained in:
Prince Canuma
2026-01-24 01:37:38 +01:00
parent ef76ec0921
commit cb2d19c84d
2 changed files with 22 additions and 24 deletions

View File

@@ -349,6 +349,8 @@ class LTX2VideoDecoder(nn.Module):
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Build decoder weights dict with key remapping
sanitized = {}
if "per_channel_statistics.mean" in weights:
return weights
for key, value in weights.items():
new_key = key
@@ -399,21 +401,21 @@ class LTX2VideoDecoder(nn.Module):
import json
model_path = Path(model_path)
config_dict = {}
# Load config from directory
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
if model_path.is_dir():
# Load config from directory
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Load weights from directory
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
# Load weights from directory
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))