Cast dtype to bf16 in video and audio generation processes

This commit is contained in:
Prince Canuma
2026-01-17 17:20:22 +01:00
parent 883c6b0ad8
commit 78244a2d66
3 changed files with 86 additions and 76 deletions

View File

@@ -171,6 +171,7 @@ def load_image(
image_path: Union[str, Path],
height: Optional[int] = None,
width: Optional[int] = None,
dtype: mx.Dtype = mx.float32,
) -> mx.array:
"""Load and preprocess an image for I2V conditioning.
@@ -210,7 +211,7 @@ def load_image(
# Convert to numpy then MLX
image_np = np.array(image).astype(np.float32) / 255.0
return mx.array(image_np)
return mx.array(image_np, dtype=dtype)
def resize_image_aspect_ratio(