add from pretrained
This commit is contained in:
@@ -273,6 +273,7 @@ class VideoDecoderModelConfig(BaseModelConfig):
|
||||
norm_type: Enum = None
|
||||
causality_axis: Enum = None
|
||||
dropout: float = 0.0
|
||||
timestep_conditioning: bool = False
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderModelConfig(BaseModelConfig):
|
||||
|
||||
@@ -385,28 +385,38 @@ class LTX2VideoDecoder(nn.Module):
|
||||
return sanitized
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
from safetensors import safe_open
|
||||
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
|
||||
"""Load a pretrained decoder from a directory with config.json and weights.
|
||||
|
||||
Args:
|
||||
model_path: Path to directory containing config.json and safetensors files,
|
||||
or path to a single safetensors file.
|
||||
strict: Whether to require all weight keys to match.
|
||||
|
||||
Returns:
|
||||
Loaded LTX2VideoDecoder instance
|
||||
"""
|
||||
import json
|
||||
weights = mx.load(str(model_path))
|
||||
|
||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
||||
if timestep_conditioning is None:
|
||||
try:
|
||||
with safe_open(str(model_path), framework="numpy") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata and "config" in metadata:
|
||||
configs = json.loads(metadata["config"])
|
||||
vae_config = configs.get("vae", {})
|
||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
||||
else:
|
||||
timestep_conditioning = False
|
||||
except Exception as e:
|
||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
||||
timestep_conditioning = False
|
||||
model_path = Path(model_path)
|
||||
|
||||
model = cls(timestep_conditioning=timestep_conditioning)
|
||||
if model_path.is_dir():
|
||||
# Load config from directory
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Load weights from directory
|
||||
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(str(wf)))
|
||||
|
||||
|
||||
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
|
||||
weights = model.sanitize(weights)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user