diff --git a/mlx_video/models/ltx/config.py b/mlx_video/models/ltx/config.py index 400a634..c63fcd7 100644 --- a/mlx_video/models/ltx/config.py +++ b/mlx_video/models/ltx/config.py @@ -273,6 +273,7 @@ class VideoDecoderModelConfig(BaseModelConfig): norm_type: Enum = None causality_axis: Enum = None dropout: float = 0.0 + timestep_conditioning: bool = False @dataclass class VideoEncoderModelConfig(BaseModelConfig): diff --git a/mlx_video/models/ltx/video_vae/decoder.py b/mlx_video/models/ltx/video_vae/decoder.py index f14cca0..4896c22 100644 --- a/mlx_video/models/ltx/video_vae/decoder.py +++ b/mlx_video/models/ltx/video_vae/decoder.py @@ -385,28 +385,38 @@ class LTX2VideoDecoder(nn.Module): return sanitized @classmethod - def from_pretrained(cls, model_path: Path, timestep_conditioning: Optional[bool] = None, strict: bool = True) -> "LTX2VideoDecoder": - from safetensors import safe_open + def from_pretrained(cls, model_path: Path, strict: bool = True) -> "LTX2VideoDecoder": + """Load a pretrained decoder from a directory with config.json and weights. + + Args: + model_path: Path to directory containing config.json and safetensors files, + or path to a single safetensors file. + strict: Whether to require all weight keys to match. + + Returns: + Loaded LTX2VideoDecoder instance + """ import json - weights = mx.load(str(model_path)) - # Read config from safetensors metadata to auto-detect timestep_conditioning - if timestep_conditioning is None: - try: - with safe_open(str(model_path), framework="numpy") as f: - metadata = f.metadata() - if metadata and "config" in metadata: - configs = json.loads(metadata["config"]) - vae_config = configs.get("vae", {}) - timestep_conditioning = vae_config.get("timestep_conditioning", False) - print(f" Auto-detected timestep_conditioning={timestep_conditioning} from weights") - else: - timestep_conditioning = False - except Exception as e: - print(f" Could not read config from metadata: {e}, defaulting to timestep_conditioning=False") - timestep_conditioning = False + model_path = Path(model_path) - model = cls(timestep_conditioning=timestep_conditioning) + if model_path.is_dir(): + # Load config from directory + config_path = model_path / "config.json" + if config_path.exists(): + with open(config_path) as f: + config_dict = json.load(f) + + # Load weights from directory + weight_files = sorted(model_path.glob("*.safetensors")) + if not weight_files: + raise FileNotFoundError(f"No safetensors files found in {model_path}") + weights = {} + for wf in weight_files: + weights.update(mx.load(str(wf))) + + + model = cls(timestep_conditioning=config_dict.get("timestep_conditioning", False)) weights = model.sanitize(weights) model.load_weights(list(weights.items()), strict=strict) return model