From 281750f0a99f95b536bbe7df74b9ec269f2a16d9 Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 11 Mar 2026 12:34:28 +0100 Subject: [PATCH] Revert changes to existing files by copying some code. --- mlx_video/convert_wan.py | 4 ++-- mlx_video/generate.py | 13 ++++++++++- mlx_video/generate_wan.py | 16 +++++++++++-- mlx_video/models/wan/postprocess.py | 35 +++++++++++++++++++++++++++++ mlx_video/postprocess.py | 34 ---------------------------- mlx_video/utils.py | 14 ------------ 6 files changed, 63 insertions(+), 53 deletions(-) create mode 100644 mlx_video/models/wan/postprocess.py diff --git a/mlx_video/convert_wan.py b/mlx_video/convert_wan.py index a7930c1..5636565 100644 --- a/mlx_video/convert_wan.py +++ b/mlx_video/convert_wan.py @@ -247,7 +247,7 @@ def _load_lora_configs( Shared between weight-merging and runtime-wrapping paths. """ from mlx_video.lora import LoRAConfig, load_multiple_loras - from mlx_video.utils import Colors + from mlx_video.generate_wan import Colors print(f"\n{Colors.CYAN}Loading {len(lora_configs)} LoRA(s)...{Colors.RESET}") @@ -280,7 +280,7 @@ def load_and_apply_loras( For non-quantized (bf16) models. For quantized models, use apply_loras_to_model(). """ from mlx_video.lora import apply_loras_to_weights - from mlx_video.utils import Colors + from mlx_video.generate_wan import Colors if not lora_configs: return model_weights diff --git a/mlx_video/generate.py b/mlx_video/generate.py index 2e32a95..f5a5bc8 100644 --- a/mlx_video/generate.py +++ b/mlx_video/generate.py @@ -9,7 +9,18 @@ import numpy as np from PIL import Image from tqdm import tqdm -from mlx_video.utils import Colors +class Colors: + """ANSI color codes for terminal output.""" + + CYAN = "\033[96m" + BLUE = "\033[94m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + MAGENTA = "\033[95m" + BOLD = "\033[1m" + DIM = "\033[2m" + RESET = "\033[0m" from mlx_video.models.ltx.config import LTXModelConfig, LTXModelType, LTXRopeType from mlx_video.models.ltx.ltx import LTXModel diff --git a/mlx_video/generate_wan.py b/mlx_video/generate_wan.py index 697ce50..cc5d895 100644 --- a/mlx_video/generate_wan.py +++ b/mlx_video/generate_wan.py @@ -22,8 +22,20 @@ from mlx_video.models.wan.loading import ( load_vae_encoder, load_wan_model, ) -from mlx_video.postprocess import save_video -from mlx_video.utils import Colors +from mlx_video.models.wan.postprocess import save_video + +class Colors: + """ANSI color codes for terminal output.""" + + CYAN = "\033[96m" + BLUE = "\033[94m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + MAGENTA = "\033[95m" + BOLD = "\033[1m" + DIM = "\033[2m" + RESET = "\033[0m" # Backward-compat alias (tests and external code may use the old name) _build_i2v_mask = build_i2v_mask diff --git a/mlx_video/models/wan/postprocess.py b/mlx_video/models/wan/postprocess.py new file mode 100644 index 0000000..4c24fc6 --- /dev/null +++ b/mlx_video/models/wan/postprocess.py @@ -0,0 +1,35 @@ +import numpy as np +from pathlib import Path + +def save_video(frames: np.ndarray, output_path: str, fps: int = 16): + """Save video frames to MP4. + + Args: + frames: Video frames [T, H, W, 3] uint8 + output_path: Output file path + fps: Frames per second + """ + try: + import imageio + writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8) + for frame in frames: + writer.append_data(frame) + writer.close() + except ImportError: + try: + import cv2 + h, w = frames.shape[1], frames.shape[2] + fourcc = cv2.VideoWriter_fourcc(*"avc1") + writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h)) + for frame in frames: + writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) + writer.release() + except (ImportError, Exception): + # Last resort: save as individual PNGs + from PIL import Image + out_dir = Path(output_path).parent / Path(output_path).stem + out_dir.mkdir(parents=True, exist_ok=True) + for i, frame in enumerate(frames): + Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png") + print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)") + diff --git a/mlx_video/postprocess.py b/mlx_video/postprocess.py index 9d579bc..03ef61d 100644 --- a/mlx_video/postprocess.py +++ b/mlx_video/postprocess.py @@ -1,42 +1,8 @@ import numpy as np -from pathlib import Path from typing import Optional -def save_video(frames: np.ndarray, output_path: str, fps: int = 16): - """Save video frames to MP4. - - Args: - frames: Video frames [T, H, W, 3] uint8 - output_path: Output file path - fps: Frames per second - """ - try: - import imageio - writer = imageio.get_writer(output_path, fps=fps, codec="libx264", quality=8) - for frame in frames: - writer.append_data(frame) - writer.close() - except ImportError: - try: - import cv2 - h, w = frames.shape[1], frames.shape[2] - fourcc = cv2.VideoWriter_fourcc(*"avc1") - writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h)) - for frame in frames: - writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) - writer.release() - except (ImportError, Exception): - # Last resort: save as individual PNGs - from PIL import Image - out_dir = Path(output_path).parent / Path(output_path).stem - out_dir.mkdir(parents=True, exist_ok=True) - for i, frame in enumerate(frames): - Image.fromarray(frame).save(out_dir / f"frame_{i:04d}.png") - print(f" (no video encoder available, saved {len(frames)} frames to {out_dir}/)") - - def bilateral_filter(image: np.ndarray, d: int = 5, sigma_color: float = 75, sigma_space: float = 75) -> np.ndarray: """Apply bilateral filter to reduce grid artifacts while preserving edges. diff --git a/mlx_video/utils.py b/mlx_video/utils.py index da99eb3..cebbed7 100644 --- a/mlx_video/utils.py +++ b/mlx_video/utils.py @@ -9,20 +9,6 @@ from pathlib import Path from huggingface_hub import snapshot_download from PIL import Image - -class Colors: - """ANSI color codes for terminal output.""" - - CYAN = "\033[96m" - BLUE = "\033[94m" - GREEN = "\033[92m" - YELLOW = "\033[93m" - RED = "\033[91m" - MAGENTA = "\033[95m" - BOLD = "\033[1m" - DIM = "\033[2m" - RESET = "\033[0m" - def get_model_path(model_repo: str): """Get or download LTX-2 model path.""" try: