"""Color-analysis helpers: frequent-color extraction, color quantization,
and per-color binary mask export."""

from datetime import datetime

import cv2
import numpy as np
from PIL import Image

# Initial cap passed to PIL's ``Image.getcolors``. It returns None when the
# image has more distinct colors than the cap; in that case we retry with a
# cap equal to the pixel count, which can never be exceeded.
_GETCOLORS_INITIAL_CAP = 1024 * 1024


def get_high_freq_colors(image):
    """Return the colors of a PIL image that occur unusually often.

    Args:
        image: A PIL ``Image``.

    Returns:
        A list of ``(count, color)`` tuples as produced by
        ``Image.getcolors``, sorted by descending count and filtered to the
        colors whose count is strictly greater than both 2 and the mean count
        over all distinct colors. Empty if the image has no colors.
    """
    colors = image.getcolors(maxcolors=_GETCOLORS_INITIAL_CAP)
    if colors is None:
        # Too many distinct colors for the initial cap. The pixel count is an
        # upper bound on the number of distinct colors, so this retry succeeds.
        colors = image.getcolors(maxcolors=image.width * image.height)
    if not colors:
        # Avoid dividing by zero when computing the mean below.
        return []

    sorted_colors = sorted(colors, key=lambda c: c[0], reverse=True)
    mean_freq = sum(count for count, _ in sorted_colors) / len(sorted_colors)
    # Keep colors seen more than twice AND more often than the average color.
    threshold = max(2, mean_freq)
    return [c for c in sorted_colors if c[0] > threshold]


def color_quantization_old(image, n_colors):
    """Quantize an image to its ``n_colors`` most frequent colors.

    Every pixel is replaced by the nearest (squared Euclidean distance) of the
    ``n_colors`` colors that occur most often in ``image``.

    Args:
        image: Array of shape ``(H, W, 3)``; assumed to hold integer values in
            ``[0, 255]`` (e.g. ``uint8``). Values are cast to int64.
        n_colors: Maximum number of palette colors to keep.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``.
    """
    pixels = image.reshape(-1, 3).astype(np.int64)
    # Count distinct colors directly instead of allocating a dense
    # 256x256x256 float64 histogram (~134 MB). Rows come back in
    # lexicographic order, the same order np.argwhere gave on the histogram.
    palette, counts = np.unique(pixels, axis=0, return_counts=True)
    # Most frequent first; the stable sort breaks ties by color order so the
    # selected palette is deterministic.
    order = np.argsort(-counts, kind="stable")
    palette = palette[order[:n_colors]]

    # Signed int64 arithmetic, so the subtraction cannot wrap around.
    dists = np.sum((pixels[:, None, :] - palette[None, :, :]) ** 2, axis=2)
    labels = np.argmin(dists, axis=1)
    quantized = palette[labels].reshape((image.shape[0], image.shape[1], 3))
    return quantized.astype(np.uint8)


def color_quantization(image, n_colors=8, rounds=1):
    """Quantize an image to ``n_colors`` colors with OpenCV k-means.

    Args:
        image: Array of shape ``(H, W, C)`` or ``(H, W)``.
        n_colors: Number of k-means clusters (palette size). Clamped to the
            number of pixels, because ``cv2.kmeans`` needs at least K samples.
        rounds: Passed as the ``attempts`` argument of ``cv2.kmeans``.

    Returns:
        ``uint8`` array with the same shape as ``image``, where each pixel is
        replaced by its cluster center rounded to the nearest integer. Centers
        are initialized randomly (``KMEANS_RANDOM_CENTERS``), so results may
        vary between calls.

    Raises:
        ValueError: if ``n_colors`` is less than 1.
    """
    if n_colors < 1:
        raise ValueError(f"n_colors must be >= 1, got {n_colors}")

    h, w = image.shape[:2]
    if h * w == 0:
        # Nothing to cluster; cv2.kmeans rejects empty input.
        return np.zeros(image.shape, dtype=np.uint8)

    # One row per pixel, one column per channel (also works for grayscale).
    samples = image.reshape(h * w, -1).astype(np.float32)
    k = min(n_colors, samples.shape[0])
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10000, 0.0001)
    _compactness, labels, centers = cv2.kmeans(
        samples, k, None, criteria, rounds, cv2.KMEANS_RANDOM_CENTERS
    )
    # Round (rather than truncate) the float centers back to 8-bit values.
    centers = np.clip(np.rint(centers), 0, 255).astype(np.uint8)
    return centers[labels.flatten()].reshape(image.shape)


def create_binary_matrix(img_arr, target_color):
    """Save a black/white mask of the pixels equal to ``target_color``.

    Args:
        img_arr: Array whose last axis holds the color channels,
            e.g. ``(H, W, C)``.
        target_color: Color compared (with broadcasting) against the last
            axis of ``img_arr``.

    Returns:
        The relative file name of the written PNG, ``mask-<timestamp>.png``.
        Matching pixels are 255 in the image and all others are 0.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written.
    """
    mask = np.all(img_arr == target_color, axis=-1)
    # Build the mask as 8-bit data directly (PNG stores 8/16-bit samples)
    # instead of an int64 array that OpenCV would have to convert.
    binary_image = mask.astype(np.uint8) * 255
    binary_file_name = f'mask-{datetime.now().timestamp()}.png'
    if not cv2.imwrite(binary_file_name, binary_image):
        raise OSError(f"cv2.imwrite failed to write {binary_file_name!r}")
    return binary_file_name