# Hosting-page metadata captured together with this file (not part of the module):
#   RysonFeng
#   Add source code
#   cdb26a4
#   raw
#   history blame contribute delete
#   No virus
#   1.53 kB
import cv2
import numpy as np
from PIL import Image
from typing import Any, Dict, List
def load_img_to_array(img_p):
    """Open the image at ``img_p`` and return its pixel data as a NumPy array.

    Args:
        img_p: Anything accepted by ``PIL.Image.open``.

    Returns:
        The result of ``np.array`` applied to the opened image.
    """
    # The context manager releases the underlying file handle deterministically,
    # even if the conversion to an array raises.
    with Image.open(img_p) as img:
        return np.array(img)
def save_array_to_img(img_arr, img_p):
    """Cast ``img_arr`` to ``uint8`` and save it as an image at ``img_p``.

    Args:
        img_arr: Array exposing ``astype``; it is converted to ``np.uint8``
            before being handed to ``PIL.Image.fromarray``.
        img_p: Destination passed straight to ``Image.save``.
    """
    pixels = img_arr.astype(np.uint8)
    image = Image.fromarray(pixels)
    image.save(img_p)
def dilate_mask(mask, dilate_factor=15):
    """Dilate ``mask`` once with a square all-ones kernel.

    Args:
        mask: Array exposing ``astype``; it is cast to ``np.uint8`` first.
        dilate_factor: Side length of the square ``uint8`` kernel.

    Returns:
        The output of ``cv2.dilate`` for the cast mask (single iteration).
    """
    mask_u8 = mask.astype(np.uint8)
    kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    return cv2.dilate(mask_u8, kernel, iterations=1)
def erode_mask(mask, dilate_factor=15):
    """Erode ``mask`` once with a square all-ones kernel.

    Args:
        mask: Array exposing ``astype``; it is cast to ``np.uint8`` first.
        dilate_factor: Side length of the square ``uint8`` kernel (the name
            mirrors ``dilate_mask`` even though this function erodes).

    Returns:
        The output of ``cv2.erode`` for the cast mask (single iteration).
    """
    mask_u8 = mask.astype(np.uint8)
    kernel = np.ones((dilate_factor, dilate_factor), np.uint8)
    return cv2.erode(mask_u8, kernel, iterations=1)
def show_mask(ax, mask: np.ndarray, random_color=False):
    """Render ``mask`` on ``ax`` as a translucent RGBA overlay.

    Args:
        ax: Object exposing ``imshow`` (e.g. a matplotlib axes).
        mask: Mask array; it is cast to ``uint8`` and, when its maximum is
            255, rescaled by dividing by 255. Its last two dimensions are
            used as height and width.
        random_color: When true, the RGB part of the overlay colour is drawn
            from ``np.random.random``; otherwise a fixed blue is used. The
            alpha component is 0.6 in both cases.
    """
    binary = mask.astype(np.uint8)
    if binary.max() == 255:
        binary = binary / 255

    rgba = (
        np.append(np.random.random(3), 0.6)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )

    height, width = binary.shape[-2:]
    # Broadcast (H, W, 1) against (4,) to get an (H, W, 4) RGBA image.
    overlay = binary.reshape(height, width)[:, :, np.newaxis] * rgba
    ax.imshow(overlay)
def show_points(ax, coords: List[List[float]], labels: List[int], size=375):
    """Scatter labelled points on ``ax`` as star markers.

    Points whose label is 0 are drawn in red and points whose label is 1 in
    green; points with any other label are not drawn.

    Args:
        ax: Object exposing ``scatter`` (e.g. a matplotlib axes).
        coords: Point coordinates, one ``[x, y]`` pair per point. An empty
            sequence or a single flat ``[x, y]`` pair is also accepted.
        labels: One label per point.
        size: Marker size forwarded to ``scatter`` as ``s``.
    """
    coords = np.array(coords)
    # An empty list becomes shape (0,) and a flat [x, y] becomes shape (2,);
    # both would break the 2-D column indexing below, so normalise to (N, 2).
    if coords.ndim == 1:
        coords = coords.reshape(-1, 2)
    # Flatten so the boolean mask always indexes along the point axis.
    labels = np.array(labels).reshape(-1)

    color_table = {0: 'red', 1: 'green'}
    for label_value, color in color_table.items():
        points = coords[labels == label_value]
        ax.scatter(points[:, 0], points[:, 1], color=color, marker='*',
                   s=size, edgecolor='white', linewidth=1.25)