Files
ltx2-mps/patch_mps.py
Norbert Schmidt c75e87b9be Initial release
2026-01-11 10:17:37 +01:00

106 lines
3.1 KiB
Python

#!/usr/bin/env python3
"""
LTX-2 MPS Patcher
Patches the diffusers library to enable LTX-2 on Apple Silicon (MPS).
The issue is that LTX-2 uses float64 for RoPE calculations, but MPS doesn't support float64.
This script forces float32 which works fine for video generation.
Usage:
python patch_mps.py
Requirements:
- diffusers (dev version with LTX2Pipeline)
- pip install git+https://github.com/huggingface/diffusers.git
"""
import os
import sys
import site
def find_diffusers_path():
    """Return the directory of the installed ``diffusers`` package.

    The import system is asked first, so the directory returned is the one
    the current interpreter would actually import (this also covers
    editable/development installs that do not live in a ``site-packages``
    directory). If that lookup yields nothing, the global and user
    ``site-packages`` directories are scanned for a ``diffusers`` folder.

    Returns:
        The package directory as a string, or None when it cannot be found.
    """
    # Imported locally so this function does not need a new file-level import.
    import importlib.util

    try:
        spec = importlib.util.find_spec("diffusers")
    except (ImportError, ValueError):
        # find_spec can raise for broken installs or a half-initialised
        # module in sys.modules; fall back to the directory scan below.
        spec = None
    # Only accept a regular, file-backed package (namespace packages have
    # no origin and cannot contain the files we want to patch).
    if spec is not None and spec.has_location and spec.origin:
        return os.path.dirname(os.path.abspath(spec.origin))

    candidates = []
    # Some virtualenv-patched ``site`` modules lack these helpers.
    get_global_sites = getattr(site, "getsitepackages", None)
    if get_global_sites is not None:
        candidates.extend(get_global_sites())
    get_user_site = getattr(site, "getusersitepackages", None)
    if get_user_site is not None:
        user_site = get_user_site()
        if user_site:
            candidates.append(user_site)

    for base in candidates:
        diffusers_path = os.path.join(base, "diffusers")
        if os.path.isdir(diffusers_path):
            return diffusers_path
    return None
def patch_file(filepath, old_text, new_text, description):
    """Replace every occurrence of ``old_text`` with ``new_text`` in a file.

    Args:
        filepath: Path of the file to modify in place.
        old_text: Exact text to search for.
        new_text: Text that replaces each occurrence of ``old_text``.
        description: Short label used in the progress messages.

    Returns:
        True if the file was patched or already contained ``new_text``;
        False if the file does not exist or ``old_text`` was not found.
    """
    if not os.path.exists(filepath):
        print(f" SKIP: {filepath} not found")
        return False

    # Read and write as UTF-8 explicitly: Python source files are UTF-8 by
    # default, and relying on the locale's encoding could fail or corrupt
    # non-ASCII characters when the file is written back.
    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    # Checked before old_text so that re-running the patcher is idempotent.
    if new_text in content:
        print(f" OK: {description} (already patched)")
        return True
    if old_text not in content:
        print(f" SKIP: {description} (pattern not found)")
        return False

    content = content.replace(old_text, new_text)
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)
    print(f" PATCHED: {description}")
    return True
def main():
    """Locate diffusers and force float32 for the LTX-2 RoPE frequencies.

    Prints progress to stdout. Exits with status 1 when diffusers cannot be
    found or when any of the patches could not be applied, so a failed run
    is never reported as a success.
    """
    print("LTX-2 MPS Patcher")
    print("=" * 50)
    diffusers_path = find_diffusers_path()
    if not diffusers_path:
        print("ERROR: diffusers not found. Install it first:")
        print(" pip install git+https://github.com/huggingface/diffusers.git")
        sys.exit(1)
    print(f"Found diffusers at: {diffusers_path}")
    print()

    # Match only the assignment itself, with no surrounding indentation or
    # neighbouring lines. The replacement stays on a single line, so the
    # patched statement keeps whatever indentation it already had and the
    # patch does not depend on how deeply the line is nested.
    old_line = "freqs_dtype = torch.float64 if self.double_precision else torch.float32"
    new_line = (
        "freqs_dtype = torch.float32"
        "  # MPS fix: force float32 as MPS doesn't support float64"
    )

    # (path components relative to the diffusers package, progress label)
    targets = (
        (("pipelines", "ltx2", "connectors.py"), "connectors.py RoPE dtype"),
        (
            ("models", "transformers", "transformer_ltx2.py"),
            "transformer_ltx2.py RoPE dtype",
        ),
    )
    # A list (not a generator) so every file is attempted even if one fails.
    results = [
        patch_file(os.path.join(diffusers_path, *parts), old_line, new_line, label)
        for parts, label in targets
    ]
    print()

    failed = results.count(False)
    if failed:
        print(f"WARNING: {failed} of {len(results)} patches could not be applied.")
        print("LTX-2 may still fail on MPS; check your diffusers version.")
        sys.exit(1)

    print("Done! LTX-2 should now work on Apple Silicon MPS.")
    print()
    print("Test with:")
    print(" python generate.py 'A cat walking' -o test.mp4")
# Run the patcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()