import cv2
import numpy as np
from PIL import Image
from typing import Any, Dict, List


def load_img_to_array(img_p):
    """Load the image at ``img_p`` into a NumPy array via PIL.

    The array's shape and dtype follow the image's PIL mode.

    The file is opened in a ``with`` block so its handle is released once the
    pixels are copied. ``Image.open`` is lazy and would otherwise keep the file
    open until garbage collection.
    """
    with Image.open(img_p) as img:
        # np.array() forces PIL to load the pixel data before the file closes.
        return np.array(img)


def save_array_to_img(img_arr, img_p):
    """Save ``img_arr`` to ``img_p`` after casting it to ``uint8``.

    No scaling or clipping is applied. Callers should pass values already in
    [0, 255] (e.g. a 0/1 mask must be multiplied by 255 first).
    """
    Image.fromarray(img_arr.astype(np.uint8)).save(img_p)


def dilate_mask(mask, dilate_factor=15):
    """Return ``mask`` as ``uint8``, dilated once.

    The dilation uses a ``dilate_factor`` x ``dilate_factor`` square kernel of
    ones.
    """
    mask = mask.astype(np.uint8)
    mask = cv2.dilate(
        mask,
        np.ones((dilate_factor, dilate_factor), np.uint8),
        iterations=1,
    )
    return mask


def erode_mask(mask, dilate_factor=15):
    """Return ``mask`` as ``uint8``, eroded once.

    The erosion uses a ``dilate_factor`` x ``dilate_factor`` square kernel of
    ones.

    ``dilate_factor`` is really the erosion kernel size. The name is kept so
    existing keyword callers keep working.
    """
    mask = mask.astype(np.uint8)
    mask = cv2.erode(
        mask,
        np.ones((dilate_factor, dilate_factor), np.uint8),
        iterations=1,
    )
    return mask


def show_mask(ax, mask: np.ndarray, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay (alpha 0.6).

    Args:
        ax: object with an ``imshow`` method (presumably a matplotlib Axes).
        mask: binary mask whose last two dims are (H, W). Its total size must
            equal H * W because of the reshape below. Values 0/1 and 0/255
            are both accepted.
        random_color: if True, pick a random RGB colour. Otherwise use a
            fixed blue (30, 144, 255).
    """
    mask = mask.astype(np.uint8)
    # Normalise 0/255 masks to 0/1 so they scale the colour correctly.
    if np.max(mask) == 255:
        mask = mask / 255
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4) RGBA image.
    mask_img = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_img)


def show_points(ax, coords: List[List[float]], labels: List[int], size=375):
    """Scatter prompt points on ``ax`` as white-edged star markers.

    Points with label 0 are drawn red and points with label 1 green. Points
    with any other label are not drawn.

    Args:
        ax: object with a ``scatter`` method (presumably a matplotlib Axes).
        coords: list of [x, y] pairs.
        labels: one label per entry in ``coords``.
        size: marker size passed to ``scatter`` as ``s``.
    """
    coords = np.asarray(coords)
    labels = np.asarray(labels)
    if coords.size == 0:
        # Nothing to draw. An empty list becomes a 1-D array, so coords[:, 0]
        # would raise IndexError.
        return
    color_table = {0: 'red', 1: 'green'}
    for label_value, color in color_table.items():
        points = coords[labels == label_value]
        ax.scatter(points[:, 0], points[:, 1], color=color, marker='*',
                   s=size, edgecolor='white', linewidth=1.25)