add from pretrained
This commit is contained in:
@@ -273,6 +273,7 @@ class VideoDecoderModelConfig(BaseModelConfig):
|
|||||||
norm_type: Enum = None
|
norm_type: Enum = None
|
||||||
causality_axis: Enum = None
|
causality_axis: Enum = None
|
||||||
dropout: float = 0.0
|
dropout: float = 0.0
|
||||||
|
timestep_conditioning: bool = False
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VideoEncoderModelConfig(BaseModelConfig):
|
class VideoEncoderModelConfig(BaseModelConfig):
|
||||||
|
|||||||
@@ -385,28 +385,38 @@ class LTX2VideoDecoder(nn.Module):
|
|||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder":
|
def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder":
|
||||||
from safetensors import safe_open
|
"""Load a pretrained decoder from a directory with config.json and weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to directory containing config.json and safetensors files,
|
||||||
|
or path to a single safetensors file.
|
||||||
|
strict: Whether to require all weight keys to match.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loaded LTX2VideoDecoder instance
|
||||||
|
"""
|
||||||
import json
|
import json
|
||||||
weights = mx.load(str(model_path))
|
|
||||||
|
|
||||||
# Read config from safetensors metadata to auto-detect timestep_conditioning
|
model_path = Path(model_path)
|
||||||
if timestep_conditioning is None:
|
|
||||||
try:
|
|
||||||
with safe_open(str(model_path), framework="numpy") as f:
|
|
||||||
metadata = f.metadata()
|
|
||||||
if metadata and "config" in metadata:
|
|
||||||
configs = json.loads(metadata["config"])
|
|
||||||
vae_config = configs.get("vae", {})
|
|
||||||
timestep_conditioning = vae_config.get("timestep_conditioning", False)
|
|
||||||
print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights")
|
|
||||||
else:
|
|
||||||
timestep_conditioning = False
|
|
||||||
except Exception as e:
|
|
||||||
print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False")
|
|
||||||
timestep_conditioning = False
|
|
||||||
|
|
||||||
model = cls(timestep_conditioning=timestep_conditioning)
|
if model_path.is_dir():
|
||||||
|
# Load config from directory
|
||||||
|
config_path = model_path / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
with open(config_path) as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
|
||||||
|
# Load weights from directory
|
||||||
|
weight_files = sorted(model_path.glob("*.safetensors"))
|
||||||
|
if not weight_files:
|
||||||
|
raise FileNotFoundError(f"No safetensors files found in {model_path}")
|
||||||
|
weights = {}
|
||||||
|
for wf in weight_files:
|
||||||
|
weights.update(mx.load(str(wf)))
|
||||||
|
|
||||||
|
|
||||||
|
model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False))
|
||||||
weights = model.sanitize(weights)
|
weights = model.sanitize(weights)
|
||||||
model.load_weights(list(weights.items()), strict=strict)
|
model.load_weights(list(weights.items()), strict=strict)
|
||||||
return model
|
return model
|
||||||
|
|||||||
Reference in New Issue
Block a user