Norbert Schmidt
2026-01-11 11:01:07 +01:00
parent c75e87b9be
commit 707f2447cf
3 changed files with 101 additions and 143 deletions

131
README.md

@@ -1,134 +1,115 @@
# LTX-2 on Apple Silicon (MPS) # ltx2-mps
Run [Lightricks LTX-2](https://huggingface.co/Lightricks/LTX-2) video generation on Mac with Apple Silicon using Metal Performance Shaders (MPS). run [LTX-2](https://huggingface.co/Lightricks/LTX-2) video generation on mac using MPS (metal).
## The Problem ## what's this about
LTX-2 uses `float64` (double precision) for rotary position embeddings (RoPE), but Apple's MPS backend doesn't support float64 - only float32. This causes the error: LTX-2 uses float64 for rotary position embeddings, but MPS doesn't support float64. you get this error:
``` ```
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64
``` ```
## The Solution this repo patches diffusers to use float32 instead. works fine, no noticeable quality loss.
This repo provides a patch that forces `float32` for RoPE calculations. The quality difference is negligible, and it enables LTX-2 to run on Mac. ## requirements
## Requirements - mac with apple silicon (m1/m2/m3/m4)
- python 3.11+
- 64GB+ ram recommended (model is ~40GB)
- **macOS** with Apple Silicon (M1, M2, M3, M4 - any variant) ## setup
- **Python 3.11+**
- **64GB+ RAM recommended** (model is ~40GB, 128GB ideal for max settings)
- **PyTorch 2.0+**
## Quick Start
```bash ```bash
# 1. Clone this repo git clone https://github.com/Pocket-science/ltx2-mps.git
git clone https://github.com/YOUR_USERNAME/ltx2-mps.git
cd ltx2-mps cd ltx2-mps
# 2. Create virtual environment
python3 -m venv venv python3 -m venv venv
source venv/bin/activate source venv/bin/activate
# 3. Install dependencies
pip install torch torchvision torchaudio pip install torch torchvision torchaudio
pip install git+https://github.com/huggingface/diffusers.git pip install git+https://github.com/huggingface/diffusers.git
pip install transformers accelerate safetensors sentencepiece pip install transformers accelerate safetensors sentencepiece
pip install imageio imageio-ffmpeg pip install imageio imageio-ffmpeg
# 4. Apply MPS patches
python patch_mps.py python patch_mps.py
# 5. Generate a video!
python generate.py "A cat walking through grass" -o output.mp4
``` ```
## Usage ## usage
```bash ```bash
python generate.py "Your prompt here" -o output.mp4 [options] python generate.py "a cat walking through grass" -o output.mp4
``` ```
### Options ### options
| Option | Default | Description | | flag | default | description |
|--------|---------|-------------| |------|---------|-------------|
| `--width` | 512 | Video width (must be divisible by 32) | | `--width` | 512 | video width (divisible by 32) |
| `--height` | 320 | Video height (must be divisible by 32) | | `--height` | 320 | video height (divisible by 32) |
| `--frames` | 25 | Number of frames (must be 8n+1: 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97) | | `--frames` | 25 | frame count (must be 8n+1: 9, 17, 25, 33...) |
| `--steps` | 20 | Inference steps (more = better quality, slower) | | `--steps` | 20 | inference steps |
| `--guidance` | 5.0 | Guidance scale | | `--guidance` | 5.0 | guidance scale |
| `--fps` | 24 | Output video FPS | | `--fps` | 24 | output fps |
| `--seed` | random | Random seed for reproducibility | | `--seed` | random | seed for reproducibility |
| `-n` | "" | Negative prompt | | `-n` | "" | negative prompt |
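
The width, height, and frame constraints in the table can be checked up front, before the ~40GB model is loaded. A minimal sketch that mirrors the modulus checks in `generate.py`; the function name `validate_args` is hypothetical and not part of the repo:

```python
def validate_args(width: int, height: int, frames: int) -> list[str]:
    """Return a list of problems; an empty list means the settings are valid."""
    problems = []
    if width % 32 != 0:
        problems.append(f"width must be divisible by 32 (got {width})")
    if height % 32 != 0:
        problems.append(f"height must be divisible by 32 (got {height})")
    # Frame counts must have the form 8n+1 (9, 17, 25, 33, ...).
    if (frames - 1) % 8 != 0:
        problems.append(f"frames must be 8n+1 (got {frames})")
    return problems


# Defaults from the table pass; a 500-pixel width and 24 frames do not.
print(validate_args(512, 320, 25))
print(validate_args(500, 320, 24))
```

Returning a list instead of exiting on the first failure lets a caller report every invalid setting in one pass.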
### Examples ### examples
```bash ```bash
# Quick preview (fast) # quick test
python generate.py "A sunset over mountains" -o preview.mp4 --frames 25 --steps 10 --width 512 --height 320 python generate.py "sunset over mountains" -o test.mp4 --steps 10
# Standard quality # higher quality
python generate.py "A dog running on the beach" -o standard.mp4 --frames 49 --steps 20 --width 768 --height 448 python generate.py "dog running on beach" -o video.mp4 --frames 49 --steps 20 --width 768 --height 448
# High quality (slow, needs 128GB RAM) # max quality (needs 128GB ram, takes ~30 min)
python generate.py "Cinematic shot of a forest" -o hq.mp4 --frames 97 --steps 30 --width 1024 --height 576 python generate.py "cinematic forest shot" -o hq.mp4 --frames 97 --steps 30 --width 1024 --height 576
``` ```
## Performance ## performance
Tested on Mac with M-series chips: tested on m3 ultra:
| Resolution | Frames | Steps | Time (approx) | RAM Usage | | resolution | frames | steps | time |
|------------|--------|-------|---------------|-----------| |------------|--------|-------|------|
| 512x320 | 25 | 10 | ~1 min | ~45GB | | 512x320 | 25 | 10 | ~1 min |
| 768x448 | 49 | 20 | ~10 min | ~60GB | | 768x448 | 49 | 20 | ~10 min |
| 1024x576 | 97 | 30 | ~30 min | ~80GB | | 1024x576 | 97 | 30 | ~30 min |
## How the Patch Works ## how the patch works
Two files in diffusers are patched: two files get patched in diffusers:
### 1. `diffusers/pipelines/ltx2/connectors.py` **diffusers/pipelines/ltx2/connectors.py**
```python ```python
# Before: # before
freqs_dtype = torch.float64 if self.double_precision else torch.float32 freqs_dtype = torch.float64 if self.double_precision else torch.float32
# After: # after
freqs_dtype = torch.float32 # MPS fix freqs_dtype = torch.float32
``` ```
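
Hard-coding `float32` also changes behavior on backends that do support float64. A device-aware fallback is one possible alternative; the sketch below is an assumption about how that could look, not what `patch_mps.py` does. It uses dtype names as plain strings so it runs without PyTorch installed (in diffusers the values would be `torch.float64` and `torch.float32`), and `pick_freqs_dtype` is a hypothetical helper:

```python
def pick_freqs_dtype(device_type: str, double_precision: bool) -> str:
    """Choose the RoPE frequency dtype name for a given device type."""
    # MPS has no float64 support, so always use float32 there.
    if device_type == "mps":
        return "float32"
    # Elsewhere, keep the original behavior.
    return "float64" if double_precision else "float32"


print(pick_freqs_dtype("mps", True))
print(pick_freqs_dtype("cuda", True))
```

The trade-off is a slightly larger patch in exchange for leaving non-MPS behavior unchanged.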
### 2. `diffusers/models/transformers/transformer_ltx2.py` **diffusers/models/transformers/transformer_ltx2.py**
```python ```python
# Before: # same change
freqs_dtype = torch.float64 if self.double_precision else torch.float32 freqs_dtype = torch.float32
# After:
freqs_dtype = torch.float32 # MPS fix
``` ```
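
The replacement itself is a plain text substitution that is safe to run more than once. A minimal sketch of that idea, assuming the two strings shown above; `patch_text` is a hypothetical helper and the temporary file stands in for the real diffusers source file. Unlike `patch_file` in this repo, it checks for the old text first, so a file is reported as already patched only when no unpatched line remains:

```python
import os
import tempfile

OLD = "freqs_dtype = torch.float64 if self.double_precision else torch.float32"
NEW = "freqs_dtype = torch.float32"


def patch_text(content: str, old: str, new: str) -> tuple[str, str]:
    """Return (content, status); status is 'patched', 'already patched', or 'not found'."""
    if old in content:
        return content.replace(old, new), "patched"
    if new in content:
        return content, "already patched"
    return content, "not found"


# Demonstrate idempotence on a throwaway file instead of a real install.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "connectors.py")
    with open(path, "w") as f:
        f.write(f"        {OLD}\n")
    for _ in range(2):
        with open(path) as f:
            content = f.read()
        content, status = patch_text(content, OLD, NEW)
        with open(path, "w") as f:
            f.write(content)
        print(status)
```

The first pass reports `patched` and the second `already patched`, so re-running after a diffusers reinstall is harmless.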
## Troubleshooting ## troubleshooting
### "MPS backend out of memory" **out of memory** - reduce resolution/frames or close other apps
- Reduce resolution, frames, or close other apps
- Try `--width 512 --height 320 --frames 25`
### Model download fails **model download fails** - it's ~40GB, first run takes a while
- Check your internet connection
- The model is ~40GB, first run takes a while to download
### Import errors **import errors** - make sure you installed diffusers from git (the dev version is needed for LTX2Pipeline), not the PyPI release
- Make sure you installed diffusers from git (dev version needed for LTX2Pipeline)
- Run `pip install git+https://github.com/huggingface/diffusers.git`
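
To tell a missing `LTX2Pipeline` apart from other failures, it can help to check whether the installed package exposes the class at all. A minimal sketch using only the standard library; `has_attr` is a hypothetical helper, not part of this repo or of diffusers:

```python
import importlib


def has_attr(module_name: str, attr: str) -> bool:
    """Return True if the module imports and the attribute can be loaded."""
    try:
        module = importlib.import_module(module_name)
        getattr(module, attr)
    except Exception:
        # Missing package, missing attribute, or a broken lazy import:
        # in every case the attribute is not usable.
        return False
    return True


if not has_attr("diffusers", "LTX2Pipeline"):
    print("LTX2Pipeline not available: install diffusers from git")
```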
## Credits ## credits
- [Lightricks](https://github.com/Lightricks) for LTX-2 - [lightricks](https://github.com/Lightricks) for ltx-2
- [Hugging Face](https://github.com/huggingface/diffusers) for diffusers - [@ivanfioravanti](https://twitter.com/ivanfioravanti) for the mps fix approach
- MPS patch discovered while debugging with Claude - [huggingface](https://github.com/huggingface/diffusers) for diffusers
## License ## license
MIT MIT


@@ -1,19 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
LTX-2 Video Generator for Apple Silicon (MPS) ltx-2 video generator for mps
Usage: usage: python generate.py "your prompt" -o output.mp4
python generate.py "Your prompt here" -o output.mp4 [options]
Options:
--width Video width (default: 512, must be divisible by 32)
--height Video height (default: 320, must be divisible by 32)
--frames Number of frames (default: 25, must be 8n+1)
--steps Inference steps (default: 20)
--guidance Guidance scale (default: 5.0)
--fps Output FPS (default: 24)
--seed Random seed (optional)
-n Negative prompt (optional)
""" """
import argparse import argparse
@@ -25,64 +14,59 @@ from diffusers.utils import export_to_video
def main(): def main():
parser = argparse.ArgumentParser(description="LTX-2 Video Generator for MPS") parser = argparse.ArgumentParser(description="ltx-2 video generator for mps")
parser.add_argument("prompt", help="Text prompt for video generation") parser.add_argument("prompt", help="text prompt")
parser.add_argument("-o", "--output", default="output.mp4", help="Output video path") parser.add_argument("-o", "--output", default="output.mp4", help="output path")
parser.add_argument("-n", "--negative-prompt", default="", help="Negative prompt") parser.add_argument("-n", "--negative-prompt", default="", help="negative prompt")
parser.add_argument("--steps", type=int, default=20, help="Inference steps") parser.add_argument("--steps", type=int, default=20, help="inference steps")
parser.add_argument("--guidance", type=float, default=5.0, help="Guidance scale") parser.add_argument("--guidance", type=float, default=5.0, help="guidance scale")
parser.add_argument("--width", type=int, default=512, help="Video width") parser.add_argument("--width", type=int, default=512, help="video width")
parser.add_argument("--height", type=int, default=320, help="Video height") parser.add_argument("--height", type=int, default=320, help="video height")
parser.add_argument("--frames", type=int, default=25, help="Number of frames") parser.add_argument("--frames", type=int, default=25, help="frame count")
parser.add_argument("--fps", type=int, default=24, help="Frames per second") parser.add_argument("--fps", type=int, default=24, help="output fps")
parser.add_argument("--seed", type=int, default=None, help="Random seed") parser.add_argument("--seed", type=int, default=None, help="random seed")
args = parser.parse_args() args = parser.parse_args()
# Validate dimensions
if args.width % 32 != 0: if args.width % 32 != 0:
print(f"Error: width must be divisible by 32 (got {args.width})") print(f"error: width must be divisible by 32 (got {args.width})")
sys.exit(1) sys.exit(1)
if args.height % 32 != 0: if args.height % 32 != 0:
print(f"Error: height must be divisible by 32 (got {args.height})") print(f"error: height must be divisible by 32 (got {args.height})")
sys.exit(1) sys.exit(1)
if (args.frames - 1) % 8 != 0: if (args.frames - 1) % 8 != 0:
valid = [8*i + 1 for i in range(1, 13)] valid = [8*i + 1 for i in range(1, 13)]
print(f"Error: frames must be 8n+1 (valid: {valid})") print(f"error: frames must be 8n+1 (valid: {valid})")
sys.exit(1) sys.exit(1)
# Check MPS availability
if not torch.backends.mps.is_available(): if not torch.backends.mps.is_available():
print("Warning: MPS not available, falling back to CPU (will be slow)") print("warning: mps not available, using cpu (slow)")
device = "cpu" device = "cpu"
else: else:
device = "mps" device = "mps"
print(f"Using MPS (Apple Silicon GPU)") print("using mps")
# Load model print("loading model...")
print("Loading LTX-2 model (this may take a while on first run)...")
pipe = LTX2Pipeline.from_pretrained( pipe = LTX2Pipeline.from_pretrained(
"Lightricks/LTX-2", "Lightricks/LTX-2",
torch_dtype=torch.bfloat16 torch_dtype=torch.bfloat16
) )
pipe.to(device) pipe.to(device)
print("Model loaded!") print("model loaded")
# Set up generator
if args.seed is None: if args.seed is None:
args.seed = torch.randint(0, 2**31, (1,)).item() args.seed = torch.randint(0, 2**31, (1,)).item()
generator = torch.Generator(device="cpu") # CPU generator more stable with MPS generator = torch.Generator(device="cpu")
generator.manual_seed(args.seed) generator.manual_seed(args.seed)
print(f"\nGenerating video...") print(f"\ngenerating...")
print(f" Prompt: {args.prompt}") print(f" prompt: {args.prompt}")
print(f" Size: {args.width}x{args.height}, {args.frames} frames") print(f" size: {args.width}x{args.height}, {args.frames} frames")
print(f" Steps: {args.steps}, Guidance: {args.guidance}") print(f" steps: {args.steps}, guidance: {args.guidance}")
print(f" Seed: {args.seed}") print(f" seed: {args.seed}")
print() print()
# Generate
result = pipe( result = pipe(
prompt=args.prompt, prompt=args.prompt,
negative_prompt=args.negative_prompt if args.negative_prompt else None, negative_prompt=args.negative_prompt if args.negative_prompt else None,
@@ -94,12 +78,11 @@ def main():
generator=generator, generator=generator,
) )
# Export video
video_frames = result.frames[0] video_frames = result.frames[0]
export_to_video(video_frames, args.output, fps=args.fps) export_to_video(video_frames, args.output, fps=args.fps)
print(f"\nVideo saved to: {args.output}") print(f"\nsaved to: {args.output}")
print(f"Seed: {args.seed}") print(f"seed: {args.seed}")
if __name__ == "__main__": if __name__ == "__main__":


@@ -1,17 +1,11 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
LTX-2 MPS Patcher patches diffusers to run ltx-2 on apple silicon (mps).
Patches the diffusers library to enable LTX-2 on Apple Silicon (MPS). ltx-2 uses float64 for rope, but mps doesn't support it.
The issue is that LTX-2 uses float64 for RoPE calculations, but MPS doesn't support float64. this forces float32 instead - works fine.
This script forces float32 which works fine for video generation.
Usage: usage: python patch_mps.py
python patch_mps.py
Requirements:
- diffusers (dev version with LTX2Pipeline)
- pip install git+https://github.com/huggingface/diffusers.git
""" """
import os import os
@@ -20,7 +14,7 @@ import site
def find_diffusers_path(): def find_diffusers_path():
"""Find the diffusers installation path.""" """find where diffusers is installed"""
for path in site.getsitepackages(): for path in site.getsitepackages():
diffusers_path = os.path.join(path, "diffusers") diffusers_path = os.path.join(path, "diffusers")
if os.path.exists(diffusers_path): if os.path.exists(diffusers_path):
@@ -37,20 +31,20 @@ def find_diffusers_path():
def patch_file(filepath, old_text, new_text, description): def patch_file(filepath, old_text, new_text, description):
"""Patch a file by replacing text.""" """replace text in a file"""
if not os.path.exists(filepath): if not os.path.exists(filepath):
print(f" SKIP: {filepath} not found") print(f" skip: {filepath} not found")
return False return False
with open(filepath, 'r') as f: with open(filepath, 'r') as f:
content = f.read() content = f.read()
if new_text in content: if new_text in content:
print(f" OK: {description} (already patched)") print(f" ok: {description} (already patched)")
return True return True
if old_text not in content: if old_text not in content:
print(f" SKIP: {description} (pattern not found)") print(f" skip: {description} (pattern not found)")
return False return False
content = content.replace(old_text, new_text) content = content.replace(old_text, new_text)
@@ -58,22 +52,22 @@ def patch_file(filepath, old_text, new_text, description):
with open(filepath, 'w') as f: with open(filepath, 'w') as f:
f.write(content) f.write(content)
print(f" PATCHED: {description}") print(f" patched: {description}")
return True return True
def main(): def main():
print("LTX-2 MPS Patcher") print("ltx-2 mps patcher")
print("=" * 50) print("-" * 40)
diffusers_path = find_diffusers_path() diffusers_path = find_diffusers_path()
if not diffusers_path: if not diffusers_path:
print("ERROR: diffusers not found. Install it first:") print("error: diffusers not found")
print(" pip install git+https://github.com/huggingface/diffusers.git") print(" pip install git+https://github.com/huggingface/diffusers.git")
sys.exit(1) sys.exit(1)
print(f"Found diffusers at: {diffusers_path}") print(f"found diffusers at: {diffusers_path}")
print() print()
# Patch 1: connectors.py # Patch 1: connectors.py
@@ -95,10 +89,10 @@ def main():
) )
print() print()
print("Done! LTX-2 should now work on Apple Silicon MPS.") print("done. ltx-2 should work on mps now.")
print() print()
print("Test with:") print("test with:")
print(" python generate.py 'A cat walking' -o test.mp4") print(" python generate.py 'a cat walking' -o test.mp4")
if __name__ == "__main__": if __name__ == "__main__":