Revert changes to existing files by copying some code.
This commit is contained in:
@@ -247,7 +247,7 @@ def _load_lora_configs(
|
|||||||
Shared between weight-merging and runtime-wrapping paths.
|
Shared between weight-merging and runtime-wrapping paths.
|
||||||
"""
|
"""
|
||||||
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
from mlx_video.lora import LoRAConfig, load_multiple_loras
|
||||||
from mlx_video.utils import Colors
|
from mlx_video.generate_wan import Colors
|
||||||
|
|
||||||
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}")
|
||||||
|
|
||||||
@@ -280,7 +280,7 @@ def load_and_apply_loras(
|
|||||||
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
For non-quantized (bf16) models. For quantized models, use apply_loras_to_model().
|
||||||
"""
|
"""
|
||||||
from mlx_video.lora import apply_loras_to_weights
|
from mlx_video.lora import apply_loras_to_weights
|
||||||
from mlx_video.utils import Colors
|
from mlx_video.generate_wan import Colors
|
||||||
|
|
||||||
if not lora_configs:
|
if not lora_configs:
|
||||||
return model_weights
|
return model_weights
|
||||||
|
|||||||
@@ -9,7 +9,18 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from mlx_video.utils import Colors
|
class Colors:
|
||||||
|
"""ANSI color codes for terminal output."""
|
||||||
|
|
||||||
|
CYAN = "\033[96m"
|
||||||
|
BLUE = "\033[94m"
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
YELLOW = "\033[93m"
|
||||||
|
RED = "\033[91m"
|
||||||
|
MAGENTA = "\033[95m"
|
||||||
|
BOLD = "\033[1m"
|
||||||
|
DIM = "\033[2m"
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType
|
||||||
from mlx_video.models.ltx.ltx import LTXModel
|
from mlx_video.models.ltx.ltx import LTXModel
|
||||||
|
|||||||
@@ -22,8 +22,20 @@ from mlx_video.models.wan.loading import (
|
|||||||
load_vae_encoder,
|
load_vae_encoder,
|
||||||
load_wan_model,
|
load_wan_model,
|
||||||
)
|
)
|
||||||
from mlx_video.postprocess import save_video
|
from mlx_video.models.wan.postprocess import save_video
|
||||||
from mlx_video.utils import Colors
|
|
||||||
|
class Colors:
|
||||||
|
"""ANSI color codes for terminal output."""
|
||||||
|
|
||||||
|
CYAN = "\033[96m"
|
||||||
|
BLUE = "\033[94m"
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
YELLOW = "\033[93m"
|
||||||
|
RED = "\033[91m"
|
||||||
|
MAGENTA = "\033[95m"
|
||||||
|
BOLD = "\033[1m"
|
||||||
|
DIM = "\033[2m"
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
# Backward-compat alias (tests and external code may use the old name)
|
# Backward-compat alias (tests and external code may use the old name)
|
||||||
_build_i2v_mask = build_i2v_mask
|
_build_i2v_mask = build_i2v_mask
|
||||||
|
|||||||
35
mlx_video/models/wan/postprocess.py
Normal file
35
mlx_video/models/wan/postprocess.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
||||||
|
"""Save video frames to MP4.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frames: Video frames [T, H, W, 3] uint8
|
||||||
|
output_path: Output file path
|
||||||
|
fps: Frames per second
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import imageio
|
||||||
|
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
||||||
|
for frame in frames:
|
||||||
|
writer.append_data(frame)
|
||||||
|
writer.close()
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
import cv2
|
||||||
|
h, w = frames.shape[1], frames.shape[2]
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
||||||
|
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
||||||
|
for frame in frames:
|
||||||
|
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
||||||
|
writer.release()
|
||||||
|
except (ImportError, Exception):
|
||||||
|
# Last resort: save as individual PNGs
|
||||||
|
from PIL import Image
|
||||||
|
out_dir = Path(output_path).parent / Path(output_path).stem
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
for i, frame in enumerate(frames):
|
||||||
|
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
||||||
|
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
|
||||||
|
|
||||||
@@ -1,42 +1,8 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
def save_video(frames: np.ndarray, output_path: str, fps: int = 16):
|
|
||||||
"""Save video frames to MP4.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frames: Video frames [T, H, W, 3] uint8
|
|
||||||
output_path: Output file path
|
|
||||||
fps: Frames per second
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import imageio
|
|
||||||
writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8)
|
|
||||||
for frame in frames:
|
|
||||||
writer.append_data(frame)
|
|
||||||
writer.close()
|
|
||||||
except ImportError:
|
|
||||||
try:
|
|
||||||
import cv2
|
|
||||||
h, w = frames.shape[1], frames.shape[2]
|
|
||||||
fourcc = cv2.VideoWriter_fourcc(*"avc1")
|
|
||||||
writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
|
|
||||||
for frame in frames:
|
|
||||||
writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
|
||||||
writer.release()
|
|
||||||
except (ImportError, Exception):
|
|
||||||
# Last resort: save as individual PNGs
|
|
||||||
from PIL import Image
|
|
||||||
out_dir = Path(output_path).parent / Path(output_path).stem
|
|
||||||
out_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
for i, frame in enumerate(frames):
|
|
||||||
Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png")
|
|
||||||
print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)")
|
|
||||||
|
|
||||||
|
|
||||||
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray:
|
||||||
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
"""Apply bilateral filter to reduce grid artifacts while preserving edges.
|
||||||
|
|
||||||
|
|||||||
@@ -9,20 +9,6 @@ from pathlib import Path
|
|||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
class Colors:
|
|
||||||
"""ANSI color codes for terminal output."""
|
|
||||||
|
|
||||||
CYAN = "\033[96m"
|
|
||||||
BLUE = "\033[94m"
|
|
||||||
GREEN = "\033[92m"
|
|
||||||
YELLOW = "\033[93m"
|
|
||||||
RED = "\033[91m"
|
|
||||||
MAGENTA = "\033[95m"
|
|
||||||
BOLD = "\033[1m"
|
|
||||||
DIM = "\033[2m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
def get_model_path(model_repo: str):
|
def get_model_path(model_repo: str):
|
||||||
"""Get or download LTX-2 model path."""
|
"""Get or download LTX-2 model path."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user