Spaces:
Sleeping
Sleeping
import cv2 | |
import numpy as np | |
from PIL import Image | |
from typing import Any, Dict, List | |
def load_img_to_array(img_p):
    """Load an image file into a NumPy array.

    RGBA images are converted to RGB (the alpha channel is dropped); images in
    any other mode are returned in their native Pillow mode.

    Args:
        img_p: image path (anything ``PIL.Image.open`` accepts).

    Returns:
        ``np.ndarray`` holding the decoded pixel data.
    """
    # The context manager releases the file handle deterministically; Pillow
    # only auto-closes single-frame images after load(), so multi-frame
    # formats (e.g. GIF/TIFF) would otherwise keep the file open.
    with Image.open(img_p) as img:
        if img.mode == "RGBA":
            img = img.convert("RGB")
        # np.array forces the pixel data to be read while the file is open.
        return np.array(img)
def save_array_to_img(img_arr, img_p):
    """Write a NumPy image array to ``img_p`` using Pillow.

    The array is cast to ``uint8`` before encoding. ``astype`` performs no
    clipping, so callers should pass values already in the 0-255 range.

    Args:
        img_arr: NumPy array of pixel values.
        img_p: destination path; Pillow picks the format from it.
    """
    as_uint8 = img_arr.astype(np.uint8)
    pil_image = Image.fromarray(as_uint8)
    pil_image.save(img_p)
def dilate_mask(mask, dilate_factor=15):
    """Apply one pass of morphological dilation to ``mask``.

    Args:
        mask: NumPy array; cast to ``uint8`` before dilation.
        dilate_factor: side length, in pixels, of the square all-ones kernel.

    Returns:
        The dilated mask produced by ``cv2.dilate``.
    """
    mask_u8 = mask.astype(np.uint8)
    side = dilate_factor
    structuring_element = np.ones((side, side), dtype=np.uint8)
    return cv2.dilate(mask_u8, structuring_element, iterations=1)
def erode_mask(mask, dilate_factor=15):
    """Apply one pass of morphological erosion to ``mask``.

    Args:
        mask: NumPy array; cast to ``uint8`` before erosion.
        dilate_factor: side length, in pixels, of the square all-ones kernel.
            The name mirrors ``dilate_mask``; here it sizes the erosion kernel.

    Returns:
        The eroded mask produced by ``cv2.erode``.
    """
    mask_u8 = mask.astype(np.uint8)
    side = dilate_factor
    structuring_element = np.ones((side, side), dtype=np.uint8)
    return cv2.erode(mask_u8, structuring_element, iterations=1)
def show_mask(ax, mask: np.ndarray, random_color=False) -> None:
    """Overlay ``mask`` on a Matplotlib axes as a semi-transparent RGBA image.

    Args:
        ax: Matplotlib axes (anything with an ``imshow`` method).
        mask: array whose last two dimensions are (height, width). It is cast
            to ``uint8``; if any value exceeds 1 the mask is treated as
            0-255 scale and divided by 255, so {0, 1}, boolean and {0, 255}
            masks all render identically.
        random_color: if True use a random RGB colour, otherwise a fixed
            blue. The alpha component is 0.6 in both cases (scaled by the
            mask value, so background pixels are fully transparent).
    """
    mask = mask.astype(np.uint8)
    # Any value above 1 means the mask is on a 0-255 scale. Normalising on
    # "> 1" rather than "== 255" keeps every RGBA component within [0, 1]
    # (Matplotlib clips float RGBA data outside that range) for masks such
    # as {0, 128}. The size guard avoids np.max raising on an empty array.
    if mask.size and mask.max() > 1:
        mask = mask / 255
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4) -> (h, w, 4) RGBA image.
    mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_img)
def show_points(ax, coords: List[List[float]], labels: List[int], size=375):
    """Scatter prompt points on a Matplotlib axes, coloured by label.

    Points labelled 0 are drawn in red and points labelled 1 in green, as
    star markers with a white edge. Points with any other label are not drawn.

    Args:
        ax: Matplotlib axes (anything with a ``scatter`` method).
        coords: (x, y) pairs, one per point. May be empty.
        labels: one integer label per point, aligned with ``coords``.
        size: marker size, forwarded to ``scatter`` as ``s``.
    """
    coords = np.array(coords)
    labels = np.array(labels)
    if coords.size == 0:
        # np.array([]) is 1-D, which would make points[:, 0] raise
        # IndexError; give it the (0, 2) shape of an empty list of pairs.
        coords = coords.reshape(0, 2)
    color_table = {0: 'red', 1: 'green'}
    for label_value, color in color_table.items():
        points = coords[labels == label_value]
        ax.scatter(points[:, 0], points[:, 1], color=color, marker='*',
                   s=size, edgecolor='white', linewidth=1.25)
def get_clicked_point(img_path):
    """Show an image in an OpenCV window and let the user pick one point.

    A left-click places (or moves) a filled red marker at the clicked pixel.
    A right-click, or closing the window, ends the selection.

    Args:
        img_path: path to the image, read with ``cv2.imread``.

    Returns:
        ``[x, y]`` of the last left-click, or ``[]`` if the user never
        left-clicked.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``img_path``.
    """
    window = "image"
    base = cv2.imread(img_path)
    if base is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    cv2.namedWindow(window)
    cv2.imshow(window, base)

    last_point = []
    keep_looping = True

    def mouse_callback(event, x, y, flags, param):
        nonlocal last_point, keep_looping
        if event == cv2.EVENT_LBUTTONDOWN:
            last_point = [x, y]
            # Redraw on a fresh copy of the untouched image so the previous
            # marker disappears instead of leaving a black dot behind.
            canvas = base.copy()
            cv2.circle(canvas, tuple(last_point), 5, (0, 0, 255), -1)
            cv2.imshow(window, canvas)
        elif event == cv2.EVENT_RBUTTONDOWN:
            keep_looping = False

    cv2.setMouseCallback(window, mouse_callback)
    try:
        while keep_looping:
            cv2.waitKey(1)
            # Closing the window with its close button never produces a
            # right-click; without this check the loop would spin forever.
            if cv2.getWindowProperty(window, cv2.WND_PROP_VISIBLE) < 1:
                break
    finally:
        cv2.destroyAllWindows()
    return last_point