Cast dtype to bf16 in video and audio generation processes
This commit is contained in:
@@ -171,6 +171,7 @@ def load_image(
|
||||
image_path: Union[str, Path],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
+ dtype: mx.Dtype = mx.float32,
|
||||
) -> mx.array:
|
||||
"""Load and preprocess an image for I2V conditioning.
|
||||
|
||||
@@ -210,7 +211,7 @@ def load_image(
|
||||
|
||||
# Convert to numpy then MLX
|
||||
image_np = np.array(image).astype(np.float32) / 255.0
|
||||
- return mx.array(image_np)
|
||||
+ return mx.array(image_np, dtype=dtype)
|
||||
|
||||
|
||||
def resize_image_aspect_ratio(
|
||||
|
||||
Reference in New Issue
Block a user