multidiffusion-region-based / sketch_helper.py
multimodalart's picture
Update sketch_helper.py
fb8832c
raw
history blame
No virus
2.38 kB
import numpy as np
import cv2
from PIL import Image
from skimage.color import rgb2lab
from skimage.color import lab2rgb
from sklearn.cluster import KMeans
def color_quantization(image, n_colors):
    """Quantize an RGB image to ``n_colors`` colors using K-means in LAB space.

    Args:
        image: RGB image accepted by ``skimage.color.rgb2lab`` (channels on the
            last axis).
        n_colors: Number of K-means clusters, i.e. distinct output colors.

    Returns:
        ``uint8`` RGB array shaped like the LAB conversion of *image*, in which
        every pixel is replaced by the RGB value of its cluster centre.
    """
    # Cluster in LAB space: one row per pixel, one column per LAB channel.
    lab_image = rgb2lab(image)
    pixels = lab_image.reshape(-1, 3)
    # fit_predict returns the cluster index of every fitted pixel in one pass,
    # instead of re-assigning all pixels with a separate predict() call.
    kmeans = KMeans(n_clusters=n_colors, random_state=0)
    labels = kmeans.fit_predict(pixels)
    # Replace each pixel with its cluster centre and restore the image shape.
    quantized_lab_image = kmeans.cluster_centers_[labels].reshape(lab_image.shape)
    quantized_rgb_image = lab2rgb(quantized_lab_image)
    # lab2rgb yields floats in [0, 1]; round rather than truncate so values such
    # as 0.99999 (LAB round-trip error) map to 255 instead of 254.
    return np.round(quantized_rgb_image * 255).astype(np.uint8)
def get_high_freq_colors(image):
    """Return the colors of *image* that occur more often than average.

    Args:
        image: A PIL ``Image`` (anything exposing ``size`` and
            ``getcolors(maxcolors=...)`` with PIL semantics).

    Returns:
        List of ``(count, color)`` tuples as produced by ``Image.getcolors``,
        sorted by descending count (ties keep ``getcolors`` order) and limited
        to entries whose count is strictly greater than both 2 and the mean
        count over all distinct colors. Empty if the image reports no colors.
    """
    width, height = image.size
    # getcolors() returns None when the image has more distinct colors than
    # maxcolors. An image cannot have more distinct colors than pixels, so this
    # bound always yields a list, unlike the fixed 1024*1024 limit alone.
    colors = image.getcolors(maxcolors=max(width * height, 1024 * 1024))
    if not colors:
        return []
    sorted_colors = sorted(colors, key=lambda entry: entry[0], reverse=True)
    mean_freq = sum(count for count, _ in sorted_colors) / len(sorted_colors)
    # Keep colors seen more than twice AND more often than the average color.
    threshold = max(2, mean_freq)
    return [entry for entry in sorted_colors if entry[0] > threshold]
def color_quantization_old(image, n_colors):
    """Reduce *image* to its ``n_colors`` most frequent colors (legacy version).

    The palette is the ``n_colors`` most common colors; every pixel is then
    replaced by the palette color with the smallest squared Euclidean distance.

    Args:
        image: Array reshapeable to ``(-1, 3)``; the result is reshaped to
            ``(image.shape[0], image.shape[1], 3)``. For counting, each channel
            value ``v`` falls into integer bin ``floor(v)`` in 0..255 (exactly
            256 goes to bin 255); pixels with any channel outside ``[0, 256]``
            (or NaN) are not counted.
        n_colors: Number of palette colors to keep.

    Returns:
        ``uint8`` array of shape ``(image.shape[0], image.shape[1], 3)``.
    """
    # Count colors sparsely with np.unique instead of a dense 256**3
    # np.histogramdd (~128 MiB of float64 bins per call). The binning below
    # reproduces histogramdd(bins=(256,)*3, range=((0, 256),)*3).
    values = np.asarray(image.reshape(-1, 3), dtype=np.float64)
    in_range = np.all((values >= 0) & (values <= 256), axis=1)
    bins = np.minimum(np.floor(values[in_range]), 255).astype(np.int64)
    codes = (bins[:, 0] << 16) | (bins[:, 1] << 8) | bins[:, 2]
    unique_codes, counts = np.unique(codes, return_counts=True)
    colors = np.stack(
        [unique_codes >> 16, (unique_codes >> 8) & 0xFF, unique_codes & 0xFF],
        axis=1,
    )
    # Ascending codes match the row-major order np.argwhere produced on the
    # dense histogram, and sorting float64 counts keeps tie-breaking identical.
    colors = colors[np.argsort(counts.astype(np.float64))[::-1]]
    colors = colors[:n_colors]
    # Replace each pixel with the closest palette color (int64 palette, so the
    # subtraction cannot wrap around for uint8 images).
    dists = np.sum((image.reshape(-1, 1, 3) - colors.reshape(1, -1, 3))**2, axis=2)
    labels = np.argmin(dists, axis=1)
    return colors[labels].reshape((image.shape[0], image.shape[1], 3)).astype(np.uint8)
def create_binary_matrix(img_arr, target_color):
    """Write a black/white mask of the pixels equal to *target_color* as a PNG.

    Despite its name, this returns the path of the written file, not the mask.

    Args:
        img_arr: Image array whose last axis holds the color channels.
        target_color: Color compared (with NumPy broadcasting) against the last
            axis of *img_arr*.

    Returns:
        Relative path (current working directory) of the PNG written, in which
        matching pixels are 255 and all others 0.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file could not be written.
    """
    from datetime import datetime
    import uuid

    # A pixel matches only if every channel equals the target color.
    mask = np.all(img_arr == target_color, axis=-1)
    # Build the 0/255 image directly as uint8 (an 8-bit grayscale mask)
    # instead of going through a 64-bit integer array.
    binary_image = mask.astype(np.uint8) * 255
    # A timestamp alone can repeat for rapid or concurrent calls, which would
    # silently overwrite another mask; a random suffix keeps names unique.
    binary_file_name = f'mask-{datetime.now().timestamp()}-{uuid.uuid4().hex[:8]}.png'
    if not cv2.imwrite(binary_file_name, binary_image):
        raise OSError(f'cv2.imwrite failed to write {binary_file_name}')
    return binary_file_name