100 lines
2.8 KiB
Python
100 lines
2.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
patches diffusers to run ltx-2 on apple silicon (mps).
|
|
|
|
ltx-2 uses float64 for rope, but mps doesn't support it.
|
|
this forces float32 instead - works fine.
|
|
|
|
usage: python patch_mps.py
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import site
|
|
|
|
|
|
def find_diffusers_path():
|
|
"""find where diffusers is installed"""
|
|
for path in site.getsitepackages():
|
|
diffusers_path = os.path.join(path, "diffusers")
|
|
if os.path.exists(diffusers_path):
|
|
return diffusers_path
|
|
|
|
# Check user site-packages
|
|
user_site = site.getusersitepackages()
|
|
if user_site:
|
|
diffusers_path = os.path.join(user_site, "diffusers")
|
|
if os.path.exists(diffusers_path):
|
|
return diffusers_path
|
|
|
|
return None
|
|
|
|
|
|
def patch_file(filepath, old_text, new_text, description):
|
|
"""replace text in a file"""
|
|
if not os.path.exists(filepath):
|
|
print(f" skip: {filepath} not found")
|
|
return False
|
|
|
|
with open(filepath, 'r') as f:
|
|
content = f.read()
|
|
|
|
if new_text in content:
|
|
print(f" ok: {description} (already patched)")
|
|
return True
|
|
|
|
if old_text not in content:
|
|
print(f" skip: {description} (pattern not found)")
|
|
return False
|
|
|
|
content = content.replace(old_text, new_text)
|
|
|
|
with open(filepath, 'w') as f:
|
|
f.write(content)
|
|
|
|
print(f" patched: {description}")
|
|
return True
|
|
|
|
|
|
def main():
|
|
print("ltx-2 mps patcher")
|
|
print("-" * 40)
|
|
|
|
diffusers_path = find_diffusers_path()
|
|
|
|
if not diffusers_path:
|
|
print("error: diffusers not found")
|
|
print(" pip install git+https://github.com/huggingface/diffusers.git")
|
|
sys.exit(1)
|
|
|
|
print(f"found diffusers at: {diffusers_path}")
|
|
print()
|
|
|
|
# Patch 1: connectors.py
|
|
connectors_path = os.path.join(diffusers_path, "pipelines", "ltx2", "connectors.py")
|
|
patch_file(
|
|
connectors_path,
|
|
"freqs_dtype = torch.float64 if self.double_precision else torch.float32",
|
|
"# MPS fix: force float32 as MPS doesn't support float64\n freqs_dtype = torch.float32",
|
|
"connectors.py RoPE dtype"
|
|
)
|
|
|
|
# Patch 2: transformer_ltx2.py
|
|
transformer_path = os.path.join(diffusers_path, "models", "transformers", "transformer_ltx2.py")
|
|
patch_file(
|
|
transformer_path,
|
|
" # 3. Create a 1D grid of frequencies for RoPE\n freqs_dtype = torch.float64 if self.double_precision else torch.float32",
|
|
" # 3. Create a 1D grid of frequencies for RoPE\n # MPS fix: force float32 as MPS doesn't support float64\n freqs_dtype = torch.float32",
|
|
"transformer_ltx2.py RoPE dtype"
|
|
)
|
|
|
|
print()
|
|
print("done. ltx-2 should work on mps now.")
|
|
print()
|
|
print("test with:")
|
|
print(" python generate.py 'a cat walking' -o test.mp4")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|