add from pretrained

This commit is contained in:
Prince Canuma
2026-01-23 18:13:51 +01:00
parent ce39e744c3
commit ef76ec0921
2 changed files with 30 additions and 19 deletions

View File

@@ -385,28 +385,38 @@ class LTX2VideoDecoder(nn.Module):
return sanitized
@classmethod
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
from safetensors import safe_open
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
"""Load a pretrained decoder from a directory with config.json and weights.
Args:
model_path: Path to directory containing config.json and safetensors files,
or path to a single safetensors file.
strict: Whether to require all weight keys to match.
Returns:
Loaded LTX2VideoDecoder instance
"""
import json
weights = mx.load(str(model_path))
# Read config from safetensors metadata to auto-detect timestep_conditioning
if timestep_conditioning is None:
try:
with safe_open(str(model_path), framework="numpy") as f:
metadata = f.metadata()
if metadata and "config" in metadata:
configs = json.loads(metadata["config"])
vae_config = configs.get("vae", {})
timestep_conditioning = vae_config.get("timestep_conditioning", False)
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
else:
timestep_conditioning = False
except Exception as e:
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
timestep_conditioning = False
model_path = Path(model_path)
model = cls(timestep_conditioning=timestep_conditioning)
if model_path.is_dir():
# Load config from directory
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
# Load weights from directory
weight_files = sorted(model_path.glob("*.safetensors"))
if not weight_files:
raise FileNotFoundError(f"No safetensors files found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(str(wf)))
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
weights = model.sanitize(weights)
model.load_weights(list(weights.items()), strict=strict)
return model