Spaces:
Runtime error
Runtime error
File size: 3,145 Bytes
da3eeba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from typing import Any, Dict, List, Union
import numpy as np
from PIL import Image
def invert_mask(mask: np.ndarray) -> np.ndarray:
    """Invert mask.

    Non-boolean masks are cast to ``uint8`` and bitwise-inverted, so a 0/255
    mask becomes 255/0. Boolean masks are treated as 0/255 (True -> 255):
    True pixels become 0 and False pixels become 255.

    Args:
        mask (np.ndarray): mask

    Returns:
        np.ndarray: inverted mask with dtype ``uint8``

    Raises:
        ValueError: if ``mask`` is None or not a NumPy array.
    """
    if mask is None or not isinstance(mask, np.ndarray):
        raise ValueError("Invalid mask")
    if mask.dtype == np.bool_:
        # Casting a bool mask to uint8 yields 0/1, and bitwise NOT of that is
        # 255/254 -- both non-zero -- so invert logically for bool input.
        return np.where(mask, 0, 255).astype(np.uint8)
    return np.invert(mask.astype(np.uint8))
def check_inputs_create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> None:
    """Validate the arguments passed to ``create_mask_image``.

    The checks run in argument order and the first failure is reported.

    Args:
        mask (Union[np.ndarray, Image.Image]): must be a NumPy array or PIL image.
        sam_masks (List[Dict[str, Any]]): must be a list.
        ignore_black_chk (bool): must be a ``bool``.

    Returns:
        None

    Raises:
        ValueError: if any argument has the wrong type (``None`` included,
            since ``None`` is not an instance of any accepted type).
    """
    validations = (
        (isinstance(mask, (np.ndarray, Image.Image)), "Invalid mask"),
        (isinstance(sam_masks, list), "Invalid SAM masks"),
        (isinstance(ignore_black_chk, bool), "Invalid ignore black check"),
    )
    for passed, message in validations:
        if not passed:
            raise ValueError(message)
def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray:
    """Convert a mask to a single-channel ``(H, W, 1)`` NumPy array.

    PIL images are first converted with ``np.array``. A 2-D array gains a
    trailing channel axis; a 3-D array with more than one channel keeps only
    its first channel (e.g. the R channel of an RGB image).

    Args:
        mask (Union[np.ndarray, Image.Image]): mask

    Returns:
        np.ndarray: array of shape (H, W, 1), same dtype as the array form
        of the input.

    Raises:
        ValueError: if the mask is neither 2-D nor 3-D.
    """
    if isinstance(mask, Image.Image):
        mask = np.array(mask)
    if mask.ndim == 2:
        mask = mask[:, :, np.newaxis]
    elif mask.ndim != 3:
        # Without this, 1-D input fails with an opaque IndexError below and
        # 4-D input is silently returned with the wrong rank.
        raise ValueError(f"Invalid mask shape: {mask.shape}")
    if mask.shape[2] != 1:
        mask = mask[:, :, 0:1]
    return mask
def create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> np.ndarray:
    """Create a mask image from the SAM segments that overlap ``mask``.

    Segments are visited in list order. Each segment may only claim pixels
    that no earlier segment has claimed. If its claimed pixels overlap any
    non-zero pixel of ``mask``, they are added to the output. With
    ``ignore_black_chk=False``, the pixels left unclaimed by every segment
    are treated as one extra region and are added too when they overlap
    ``mask``.

    Args:
        mask (Union[np.ndarray, Image.Image]): reference mask; its non-zero
            pixels mark the area of interest.
        sam_masks (List[Dict[str, Any]]): SAM masks; each dict's
            ``"segmentation"`` entry is cast to uint8 and given a trailing
            channel axis (presumably a boolean H x W array -- confirm
            against the producer).
        ignore_black_chk (bool): when False, also consider pixels that no
            segment claimed.

    Returns:
        np.ndarray: uint8 image with the selected region replicated over 3
        channels; selected pixels are 255 and the rest 0 when the
        segmentations are 0/1 valued.

    Raises:
        ValueError: if the inputs fail validation.
    """
    check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk)
    mask = convert_mask(mask)

    def unclaimed(claimed_so_far: np.ndarray) -> np.ndarray:
        # 1 where no segment has claimed the pixel yet, 0 elsewhere.
        return np.logical_not(claimed_so_far.astype(bool)).astype(np.uint8)

    claimed = np.zeros(mask.shape, dtype=np.uint8)
    selected = np.zeros(mask.shape, dtype=np.uint8)

    for seg_dict in sam_masks:
        segment = np.expand_dims(seg_dict["segmentation"].astype(np.uint8), axis=-1)
        # Restrict this segment to the pixels earlier segments left free.
        fresh = segment * unclaimed(claimed)
        if (fresh * mask).astype(bool).any():
            selected = selected + fresh
        claimed = claimed + fresh

    if not ignore_black_chk:
        leftover = unclaimed(claimed)
        if (leftover * mask).astype(bool).any():
            selected = selected + leftover

    return np.tile(selected * 255, (1, 1, 3)).astype(np.uint8)
|