Initial release
134
README.md
Normal file
@@ -0,0 +1,134 @@
# LTX-2 on Apple Silicon (MPS)

Run [Lightricks LTX-2](https://huggingface.co/Lightricks/LTX-2) video generation on a Mac with Apple Silicon using Metal Performance Shaders (MPS).

## The Problem

LTX-2 uses `float64` (double precision) for rotary position embeddings (RoPE), but Apple's MPS backend does not support `float64`; `float32` is the widest floating-point type it offers. This causes the error:

```
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64
```
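
To confirm the limitation on a given machine, one option is to try allocating a `float64` tensor on the MPS device. The helper below is only a sketch: the name `mps_supports_float64` is hypothetical, and it returns `None` when PyTorch or an MPS device is not available, so it can be run anywhere.

```python
def mps_supports_float64():
    """Return True/False if MPS is present, or None if it cannot be tested."""
    try:
        import torch
    except ImportError:
        return None  # PyTorch is not installed

    mps = getattr(torch.backends, "mps", None)
    if mps is None or not mps.is_available():
        return None  # no MPS device on this machine

    try:
        torch.zeros(1, dtype=torch.float64, device="mps")
    except (TypeError, RuntimeError):
        return False  # the allocation was rejected, as in the error above
    return True


print("float64 on MPS:", mps_supports_float64())
```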

## The Solution

This repo provides a patch that forces `float32` for RoPE calculations. The quality difference is negligible, and the patch lets LTX-2 run on a Mac.
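
For a rough sense of how small the rounding error is, the sketch below compares rotary angles computed in double precision with the same angles rounded to single precision. It assumes a textbook RoPE layout (base `theta = 10000`, 64 dimensions, integer positions up to 96); these values are illustrative and are not taken from the LTX-2 configuration.

```python
import math
import struct


def to_f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def max_rope_error(num_positions=97, dim=64, theta=10000.0):
    """Largest cos/sin difference between float64 and float32 angles."""
    worst = 0.0
    for pos in range(num_positions):
        for i in range(dim // 2):
            inv_freq = theta ** (-2.0 * i / dim)
            angle64 = pos * inv_freq
            # Emulate float32: round the frequency, then round the product.
            angle32 = to_f32(pos * to_f32(inv_freq))
            worst = max(
                worst,
                abs(math.cos(angle64) - math.cos(angle32)),
                abs(math.sin(angle64) - math.sin(angle32)),
            )
    return worst


print(f"max float32 vs float64 deviation: {max_rope_error():.2e}")
```

Under these assumptions the deviation stays far below the resolution of `bfloat16`, the dtype the model weights are loaded in.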

## Requirements

- **macOS** with Apple Silicon (M1, M2, M3, or M4, any variant)
- **Python 3.11+**
- **64GB+ RAM recommended** (the model is ~40GB; 128GB is ideal for the highest settings)
- **PyTorch 2.0+**

## Quick Start

```bash
# 1. Clone this repo
git clone https://github.com/YOUR_USERNAME/ltx2-mps.git
cd ltx2-mps

# 2. Create virtual environment
python3 -m venv venv
source venv/bin/activate

# 3. Install dependencies
pip install torch torchvision torchaudio
pip install git+https://github.com/huggingface/diffusers.git
pip install transformers accelerate safetensors sentencepiece
pip install imageio imageio-ffmpeg

# 4. Apply MPS patches
python patch_mps.py

# 5. Generate a video!
python generate.py "A cat walking through grass" -o output.mp4
```

## Usage

```bash
python generate.py "Your prompt here" -o output.mp4 [options]
```

### Options

| Option | Default | Description |
|--------|---------|-------------|
| `-o`, `--output` | output.mp4 | Output video path |
| `--width` | 512 | Video width (must be divisible by 32) |
| `--height` | 320 | Video height (must be divisible by 32) |
| `--frames` | 25 | Number of frames (must be 8n+1: 9, 17, 25, 33, 41, 49, 57, 65, 73, 81, 89, 97) |
| `--steps` | 20 | Inference steps (more = better quality, slower) |
| `--guidance` | 5.0 | Guidance scale |
| `--fps` | 24 | Output video FPS |
| `--seed` | random | Random seed for reproducibility |
| `-n`, `--negative-prompt` | "" | Negative prompt |
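
The size constraints in the table can be checked before launching a long run. The sketch below mirrors the rules above; the helper names `validate_video_args` and `nearest_valid_frames` are hypothetical and are not part of `generate.py`.

```python
def validate_video_args(width, height, frames):
    """Return a list of error messages; an empty list means the values are valid."""
    errors = []
    if width % 32 != 0:
        errors.append(f"width must be divisible by 32 (got {width})")
    if height % 32 != 0:
        errors.append(f"height must be divisible by 32 (got {height})")
    if frames < 1 or (frames - 1) % 8 != 0:
        errors.append(f"frames must be 8n+1 (got {frames})")
    return errors


def nearest_valid_frames(frames):
    """Round a frame count to a nearby value of the form 8n+1 (n >= 1)."""
    n = max(1, round((frames - 1) / 8))
    return 8 * n + 1


print(validate_video_args(512, 320, 25))  # []
print(nearest_valid_frames(24))           # 25
```

Returning a list of messages rather than exiting on the first problem lets a caller report every invalid value at once.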

### Examples

```bash
# Quick preview (fast)
python generate.py "A sunset over mountains" -o preview.mp4 --frames 25 --steps 10 --width 512 --height 320

# Standard quality
python generate.py "A dog running on the beach" -o standard.mp4 --frames 49 --steps 20 --width 768 --height 448

# High quality (slow, needs 128GB RAM)
python generate.py "Cinematic shot of a forest" -o hq.mp4 --frames 97 --steps 30 --width 1024 --height 576
```

## Performance

Tested on Mac with M-series chips:

| Resolution | Frames | Steps | Time (approx) | RAM Usage |
|------------|--------|-------|---------------|-----------|
| 512x320 | 25 | 10 | ~1 min | ~45GB |
| 768x448 | 49 | 20 | ~10 min | ~60GB |
| 1024x576 | 97 | 30 | ~30 min | ~80GB |

## How the Patch Works

Two files in diffusers are patched:

### 1. `diffusers/pipelines/ltx2/connectors.py`
```python
# Before:
freqs_dtype = torch.float64 if self.double_precision else torch.float32

# After:
freqs_dtype = torch.float32  # MPS fix
```

### 2. `diffusers/models/transformers/transformer_ltx2.py`
```python
# Before:
freqs_dtype = torch.float64 if self.double_precision else torch.float32

# After:
freqs_dtype = torch.float32  # MPS fix
```
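
The replacement can be pictured as a plain, idempotent text substitution. The sketch below works on an in-memory string rather than on the installed diffusers files; `apply_patch` and the `OLD`/`NEW` constants are hypothetical names used only for illustration.

```python
OLD = "freqs_dtype = torch.float64 if self.double_precision else torch.float32"
NEW = "freqs_dtype = torch.float32  # MPS fix"


def apply_patch(source):
    """Return (new_source, status) for one file's text."""
    if OLD in source:
        return source.replace(OLD, NEW), "patched"
    if NEW in source:
        return source, "already patched"
    return source, "pattern not found"


sample = "        " + OLD + "\n"
patched, status = apply_patch(sample)
print(status)                   # patched
print(apply_patch(patched)[1])  # already patched
```

Because only the assignment itself is swapped, the surrounding indentation is left untouched, and running the patch a second time changes nothing.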

## Troubleshooting

### "MPS backend out of memory"
- Reduce the resolution or frame count, or close other apps
- Try `--width 512 --height 320 --frames 25`

### Model download fails
- Check your internet connection
- The model is ~40GB, so the first run takes a while to download

### Import errors
- Make sure you installed diffusers from git (the dev version is needed for `LTX2Pipeline`)
- Run `pip install git+https://github.com/huggingface/diffusers.git`
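
A quick way to tell whether the installed diffusers build is recent enough is to look for the pipeline class. This is a sketch; the helper name `has_ltx2_pipeline` is hypothetical.

```python
def has_ltx2_pipeline():
    """Return True if the installed diffusers exposes LTX2Pipeline."""
    try:
        import diffusers
        getattr(diffusers, "LTX2Pipeline")
    except (ImportError, AttributeError, RuntimeError):
        # diffusers is missing, too old, or failed to load the class
        return False
    return True


if has_ltx2_pipeline():
    print("LTX2Pipeline is available")
else:
    print("LTX2Pipeline not found; install diffusers from git")
```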

## Credits

- [Lightricks](https://github.com/Lightricks) for LTX-2
- [Hugging Face](https://github.com/huggingface/diffusers) for diffusers
- MPS patch discovered while debugging with Claude

## License

MIT
106
generate.py
Normal file
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
"""
LTX-2 Video Generator for Apple Silicon (MPS)

Usage:
    python generate.py "Your prompt here" -o output.mp4 [options]

Options:
    -o          Output video path (default: output.mp4)
    --width     Video width (default: 512, must be divisible by 32)
    --height    Video height (default: 320, must be divisible by 32)
    --frames    Number of frames (default: 25, must be 8n+1)
    --steps     Inference steps (default: 20)
    --guidance  Guidance scale (default: 5.0)
    --fps       Output FPS (default: 24)
    --seed      Random seed (optional)
    -n          Negative prompt (optional)
"""

import argparse
import sys

import torch
from diffusers import LTX2Pipeline
from diffusers.utils import export_to_video


def main():
    parser = argparse.ArgumentParser(description="LTX-2 Video Generator for MPS")
    parser.add_argument("prompt", help="Text prompt for video generation")
    parser.add_argument("-o", "--output", default="output.mp4", help="Output video path")
    parser.add_argument("-n", "--negative-prompt", default="", help="Negative prompt")
    parser.add_argument("--steps", type=int, default=20, help="Inference steps")
    parser.add_argument("--guidance", type=float, default=5.0, help="Guidance scale")
    parser.add_argument("--width", type=int, default=512, help="Video width")
    parser.add_argument("--height", type=int, default=320, help="Video height")
    parser.add_argument("--frames", type=int, default=25, help="Number of frames")
    parser.add_argument("--fps", type=int, default=24, help="Frames per second")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")

    args = parser.parse_args()

    # Validate dimensions
    if args.width % 32 != 0:
        print(f"Error: width must be divisible by 32 (got {args.width})")
        sys.exit(1)
    if args.height % 32 != 0:
        print(f"Error: height must be divisible by 32 (got {args.height})")
        sys.exit(1)
    if (args.frames - 1) % 8 != 0:
        valid = [8 * i + 1 for i in range(1, 13)]
        print(f"Error: frames must be 8n+1 (valid: {valid})")
        sys.exit(1)

    # Check MPS availability
    if not torch.backends.mps.is_available():
        print("Warning: MPS not available, falling back to CPU (will be slow)")
        device = "cpu"
    else:
        device = "mps"
        print("Using MPS (Apple Silicon GPU)")

    # Load model
    print("Loading LTX-2 model (this may take a while on first run)...")
    pipe = LTX2Pipeline.from_pretrained(
        "Lightricks/LTX-2",
        torch_dtype=torch.bfloat16,
    )
    pipe.to(device)
    print("Model loaded!")

    # Set up generator; pick a random seed if none was given so it can be reported
    if args.seed is None:
        args.seed = torch.randint(0, 2**31, (1,)).item()

    generator = torch.Generator(device="cpu")  # CPU generator more stable with MPS
    generator.manual_seed(args.seed)

    print("\nGenerating video...")
    print(f"  Prompt: {args.prompt}")
    print(f"  Size: {args.width}x{args.height}, {args.frames} frames")
    print(f"  Steps: {args.steps}, Guidance: {args.guidance}")
    print(f"  Seed: {args.seed}")
    print()

    # Generate (an empty negative prompt is passed as None)
    result = pipe(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt or None,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance,
        width=args.width,
        height=args.height,
        num_frames=args.frames,
        generator=generator,
    )

    # Export video
    video_frames = result.frames[0]
    export_to_video(video_frames, args.output, fps=args.fps)

    print(f"\nVideo saved to: {args.output}")
    print(f"Seed: {args.seed}")


if __name__ == "__main__":
    main()
105
patch_mps.py
Normal file
@@ -0,0 +1,105 @@
#!/usr/bin/env python3
"""
LTX-2 MPS Patcher

Patches the diffusers library to enable LTX-2 on Apple Silicon (MPS).
The issue is that LTX-2 uses float64 for RoPE calculations, but MPS doesn't support float64.
This script forces float32, which works fine for video generation.

Usage:
    python patch_mps.py

Requirements:
    - diffusers (dev version with LTX2Pipeline)
    - pip install git+https://github.com/huggingface/diffusers.git
"""

import os
import sys
import site


def find_diffusers_path():
    """Find the diffusers installation path."""
    for path in site.getsitepackages():
        diffusers_path = os.path.join(path, "diffusers")
        if os.path.exists(diffusers_path):
            return diffusers_path

    # Check user site-packages
    user_site = site.getusersitepackages()
    if user_site:
        diffusers_path = os.path.join(user_site, "diffusers")
        if os.path.exists(diffusers_path):
            return diffusers_path

    return None


def patch_file(filepath, old_text, new_text, description):
    """Patch a file by replacing text. Return True if the file ends up patched."""
    if not os.path.exists(filepath):
        print(f"  SKIP: {filepath} not found")
        return False

    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    # Running the script twice is safe: an already-patched file is left alone
    if new_text in content:
        print(f"  OK: {description} (already patched)")
        return True

    if old_text not in content:
        print(f"  SKIP: {description} (pattern not found)")
        return False

    content = content.replace(old_text, new_text)

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)

    print(f"  PATCHED: {description}")
    return True


def main():
    print("LTX-2 MPS Patcher")
    print("=" * 50)

    diffusers_path = find_diffusers_path()

    if not diffusers_path:
        print("ERROR: diffusers not found. Install it first:")
        print("  pip install git+https://github.com/huggingface/diffusers.git")
        sys.exit(1)

    print(f"Found diffusers at: {diffusers_path}")
    print()

    # Patch 1: connectors.py
    connectors_path = os.path.join(diffusers_path, "pipelines", "ltx2", "connectors.py")
    connectors_ok = patch_file(
        connectors_path,
        "freqs_dtype = torch.float64 if self.double_precision else torch.float32",
        "# MPS fix: force float32 as MPS doesn't support float64\n        freqs_dtype = torch.float32",
        "connectors.py RoPE dtype",
    )

    # Patch 2: transformer_ltx2.py
    transformer_path = os.path.join(diffusers_path, "models", "transformers", "transformer_ltx2.py")
    transformer_ok = patch_file(
        transformer_path,
        "        # 3. Create a 1D grid of frequencies for RoPE\n        freqs_dtype = torch.float64 if self.double_precision else torch.float32",
        "        # 3. Create a 1D grid of frequencies for RoPE\n        # MPS fix: force float32 as MPS doesn't support float64\n        freqs_dtype = torch.float32",
        "transformer_ltx2.py RoPE dtype",
    )

    print()
    # Do not report success if either file was skipped
    if not (connectors_ok and transformer_ok):
        print("WARNING: not all patches were applied; see the messages above.")
        sys.exit(1)

    print("Done! LTX-2 should now work on Apple Silicon MPS.")
    print()
    print("Test with:")
    print("  python generate.py 'A cat walking' -o test.mp4")


if __name__ == "__main__":
    main()
11
requirements.txt
Normal file
@@ -0,0 +1,11 @@
torch>=2.0.0
torchvision
torchaudio
transformers>=4.40.0
accelerate>=0.25.0
safetensors>=0.4.0
sentencepiece>=0.1.99
imageio>=2.30.0
imageio-ffmpeg>=0.4.9
# Install diffusers from git for LTX2Pipeline:
# pip install git+https://github.com/huggingface/diffusers.git