from typing import Any, Dict, List

import cv2
import numpy as np
from PIL import Image


def load_img_to_array(img_p):
    """Load an image file into a NumPy array.

    Images whose PIL mode is ``"RGBA"`` are converted to ``"RGB"`` first;
    every other mode is returned as PIL decodes it.

    Args:
        img_p: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``np.ndarray`` built from the (possibly converted) image.
    """
    # The context manager guarantees the underlying file is closed even if
    # conversion or array construction raises.
    with Image.open(img_p) as opened:
        img = opened.convert("RGB") if opened.mode == "RGBA" else opened
        # Build the array while the file is still open: PIL loads lazily.
        return np.array(img)


def save_array_to_img(img_arr, img_p):
    """Save an array as an image file.

    Args:
        img_arr: Array of pixel values; it is cast to ``uint8`` before saving.
        img_p: Destination passed to ``PIL.Image.Image.save``.
    """
    Image.fromarray(img_arr.astype(np.uint8)).save(img_p)


def _square_kernel(size):
    """Return a ``size`` x ``size`` all-ones ``uint8`` structuring element."""
    return np.ones((size, size), np.uint8)


def dilate_mask(mask, dilate_factor=15):
    """Dilate a mask once with a square all-ones kernel.

    Args:
        mask: Array-like mask; it is cast to ``uint8`` before dilation.
        dilate_factor: Side length of the square kernel.

    Returns:
        The dilated ``uint8`` mask produced by ``cv2.dilate``.
    """
    mask = mask.astype(np.uint8)
    return cv2.dilate(mask, _square_kernel(dilate_factor), iterations=1)


def erode_mask(mask, dilate_factor=15):
    """Erode a mask once with a square all-ones kernel.

    Args:
        mask: Array-like mask; it is cast to ``uint8`` before erosion.
        dilate_factor: Side length of the square kernel. The name mirrors
            ``dilate_mask`` and is kept so existing keyword callers work.

    Returns:
        The eroded ``uint8`` mask produced by ``cv2.erode``.
    """
    mask = mask.astype(np.uint8)
    return cv2.erode(mask, _square_kernel(dilate_factor), iterations=1)


def show_mask(ax, mask: np.ndarray, random_color=False):
    """Draw a semi-transparent coloured overlay of ``mask`` on ``ax``.

    Args:
        ax: Object exposing ``imshow`` (presumably a matplotlib Axes --
            confirm against callers).
        mask: Mask array; its last two dimensions are used as height/width.
            It is cast to ``uint8``; if its maximum is 255 it is scaled
            down by 255.
        random_color: When true, use a random RGB colour; otherwise use
            the fixed colour (30, 144, 255). Alpha is 0.6 in both cases.
    """
    mask = mask.astype(np.uint8)
    if np.max(mask) == 255:
        mask = mask / 255

    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    h, w = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4) -> an (h, w, 4) RGBA overlay.
    mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_img)


def show_points(ax, coords: List[List[float]], labels: List[int], size=375):
    """Scatter labelled points on ``ax`` as star markers.

    Points labelled ``0`` are drawn red and points labelled ``1`` green;
    other label values are not drawn.

    Args:
        ax: Object exposing ``scatter`` (presumably a matplotlib Axes --
            confirm against callers).
        coords: Sequence of ``[x, y]`` pairs.
        labels: One label per entry of ``coords``.
        size: Marker size forwarded to ``scatter`` as ``s``.
    """
    coords = np.array(coords)
    labels = np.array(labels)

    # An empty list becomes a 1-D array, which cannot be column-indexed
    # below; there is nothing to draw in that case.
    if coords.size == 0:
        return

    color_table = {0: 'red', 1: 'green'}
    for label_value, color in color_table.items():
        points = coords[labels == label_value]
        ax.scatter(points[:, 0], points[:, 1], color=color, marker='*',
                   s=size, edgecolor='white', linewidth=1.25)


def _window_is_visible(window_name):
    """Return True if OpenCV reports ``window_name`` as currently visible."""
    try:
        return cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) >= 1
    except cv2.error:
        # Some backends raise instead of returning a value once the
        # window has been destroyed.
        return False


def get_clicked_point(img_path):
    """Show an image and let the user pick a point with the mouse.

    A left click selects a point (marked with a filled red circle; the
    previously selected point is painted over in black). A right click,
    or closing the window, ends the interaction.

    Args:
        img_path: Path of the image, read with ``cv2.imread``.

    Returns:
        ``[x, y]`` of the last left click, or ``[]`` if there was none.

    Raises:
        OSError: If ``cv2.imread`` could not read ``img_path``.
    """
    window_name = "image"

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the path instead of deep inside cv2.imshow.
        raise OSError("Could not read image: {!r}".format(img_path))

    cv2.namedWindow(window_name)
    cv2.imshow(window_name, img)

    last_point = []
    keep_looping = True

    def mouse_callback(event, x, y, flags, param):
        # `img` is only mutated in place, so it does not need `nonlocal`.
        nonlocal last_point, keep_looping
        if event == cv2.EVENT_LBUTTONDOWN:
            if last_point:
                cv2.circle(img, tuple(last_point), 5, (0, 0, 0), -1)
            last_point = [x, y]
            cv2.circle(img, tuple(last_point), 5, (0, 0, 255), -1)
            cv2.imshow(window_name, img)
        elif event == cv2.EVENT_RBUTTONDOWN:
            keep_looping = False

    cv2.setMouseCallback(window_name, mouse_callback)

    # Pump the event loop once so the window is realised, then probe
    # whether this GUI backend reports visibility at all. If it does not,
    # fall back to right-click-only termination.
    cv2.waitKey(1)
    can_detect_close = _window_is_visible(window_name)

    while keep_looping:
        cv2.waitKey(1)
        # Without this check, closing the window via the window manager
        # would leave the loop spinning forever.
        if can_detect_close and not _window_is_visible(window_name):
            break

    cv2.destroyAllWindows()
    return last_point