fix loading
This commit is contained in:
@@ -349,6 +349,8 @@ class LTX2VideoDecoder(nn.Module):
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||
# Build decoder weights dict with key remapping
|
||||
sanitized = {}
|
||||
if "per_channel_statistics.mean" in weights:
|
||||
return weights
|
||||
for key, value in weights.items():
|
||||
new_key = key
|
||||
|
||||
@@ -399,21 +401,21 @@ class LTX2VideoDecoder(nn.Module):
|
||||
import json
|
||||
|
||||
model_path = Path(model_path)
|
||||
config_dict = {}
|
||||
|
||||
# Load config from directory
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
if model_path.is_dir():
|
||||
# Load config from directory
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Load weights from directory
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
# Load weights from directory
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
|
||||
|
||||
Reference in New Issue
Block a user