Spaces:
Running
Running
# %% | |
import cv2 | |
from sklearn.cluster import KMeans | |
from PIL import Image | |
import numpy as np | |
import gradio.components as gc | |
import gradio as gr | |
def _local_contrast_weights(hsv, blur_radius):
    """Return one k-means sample weight per pixel of the uint8 HSV image ``hsv``.

    (sharp - blurred) is large where a pixel stands out from its surroundings;
    raising its norm to the 4th power makes the difference more pronounced.
    This preserves rare specks of colour by increasing the probability that
    they get their own cluster.
    """
    # cv2.GaussianBlur only accepts odd, positive kernel sizes.
    ksize = max(1, int(blur_radius))
    if ksize % 2 == 0:
        ksize += 1
    blurred = cv2.GaussianBlur(
        hsv,
        (ksize, ksize),
        0,
        borderType=cv2.BORDER_REPLICATE,
    )
    diff = hsv.astype(np.float32) - blurred
    weights = np.linalg.norm(diff, axis=-1).reshape(-1) ** 4
    # A perfectly flat image has no local contrast anywhere. All-zero weights
    # give k-means nothing to build centres from, so weight pixels equally.
    if not weights.any():
        weights = np.ones_like(weights)
    return weights


def _representative_colors(hsv, labels, centers, hsv_weights):
    """Return a (n_clusters, 3) HSV palette with one real image pixel per cluster.

    For each cluster, the member pixel with the largest ``pixel @ hsv_weights``
    score is chosen, so the palette only contains colours from the image.
    """
    pixels = hsv.reshape(-1, 3)
    flat_labels = np.asarray(labels).reshape(-1)
    channel_weights = np.asarray(hsv_weights)
    palette = np.empty((len(centers), 3), dtype=hsv.dtype)
    for cluster, center in enumerate(centers):
        members = pixels[flat_labels == cluster]
        if members.size:
            palette[cluster] = members[(members @ channel_weights).argmax()]
        else:
            # k-means can leave a cluster without members (e.g. when the image
            # has fewer distinct colours than clusters). Fall back to its centre.
            palette[cluster] = np.clip(np.rint(center), 0, 255)
    return palette


def _nearest_center_labels(colors, centers):
    """Return the index of the Euclidean-nearest centre for each row of ``colors``."""
    distances = np.linalg.norm(
        colors[:, None, :] - centers[None, :, :], axis=-1)
    return distances.argmin(axis=1)


def pixart(
    i,
    block_size=4,
    n_clusters=5,
    hsv_weights=(0, 0, 1),
    local_contrast_blur_radius=51,  # has to be odd (even values are bumped up)
    upscale=True,
    seed=None,
    output_scaling=1,
    dither_amount=15,
):
    """Turn a PIL image into pixel art using k-means colour quantisation.

    Steps:

    1. Downsample the image by ``block_size`` with nearest-neighbour sampling.
    2. Cluster the pixels in OpenCV's 8-bit HSV space with k-means.
    3. Replace every pixel with a representative *real* pixel of its cluster.

    Args:
        i: Input ``PIL.Image.Image``. It is converted to RGB first, so
            RGBA/L/P images are accepted.
        block_size: Edge length, in source pixels, of one output "pixel".
            Values below 1 are treated as 1.
        n_clusters: Requested palette size. It is clamped to the number of
            downsampled pixels, because k-means cannot form more clusters than
            there are samples.
        hsv_weights: Three weights for the H, S and V channels. For each
            cluster, the member pixel with the largest weighted sum becomes
            the cluster's output colour.
        local_contrast_blur_radius: Gaussian kernel size used to estimate each
            pixel's surroundings. OpenCV needs an odd size, so even values are
            bumped to the next odd number.
        upscale: If true, scale the result back to the input size.
        seed: Seed for k-means and for the dithering noise. ``None`` picks a
            random seed. The seed is returned so the result can be reproduced.
        output_scaling: Extra integer nearest-neighbour scale factor, applied
            to the output as it stands after the ``upscale`` step. Values <= 1
            leave the size unchanged.
        dither_amount: Standard deviation of the Gaussian noise added to the
            HSV values before each pixel is assigned to its nearest cluster
            centre. 0 disables dithering.

    Returns:
        ``(image, seed)``: the RGB pixel-art ``PIL.Image.Image`` and the
        integer seed that was used.
    """
    # cv2.COLOR_RGB2HSV needs exactly three channels.
    i = i.convert("RGB")
    block_size = max(1, int(block_size))
    w, h = i.size
    # Keep at least one pixel per axis, even if block_size exceeds the image.
    dw = max(1, w // block_size)
    dh = max(1, h // block_size)
    # Always resize with NEAREST to keep the original colors.
    small = i.resize((dw, dh), Image.Resampling.NEAREST)
    hsv = cv2.cvtColor(np.array(small), cv2.COLOR_RGB2HSV)
    pixels = hsv.reshape(-1, 3).astype(np.float32)

    if seed is None:
        seed = np.random.randint(0, 2**16 - 1)
    seed = int(seed)

    # NOTE: hue is circular, but it is treated as a linear axis here (as before).
    km = KMeans(n_clusters=min(int(n_clusters), len(pixels)), random_state=seed)
    km.fit(
        pixels,
        sample_weight=_local_contrast_weights(hsv, local_contrast_blur_radius),
    )
    centers = km.cluster_centers_  # float HSV values
    palette = _representative_colors(hsv, km.labels_, centers, hsv_weights)

    if dither_amount == 0:
        labels = km.labels_
    else:
        # Add noise to the colours before picking the nearest cluster centre;
        # this acts as a dithering effect. The noise comes from a generator
        # seeded with `seed`, so the reported seed also reproduces dithering.
        rng = np.random.default_rng(seed)
        noised = np.clip(
            pixels + rng.normal(0, dither_amount, pixels.shape), 0, 255)
        labels = _nearest_center_labels(noised, centers)

    ki = palette[np.asarray(labels).reshape(dh, dw)]
    rgb = cv2.cvtColor(ki, cv2.COLOR_HSV2RGB)
    out = Image.fromarray(rgb)
    if upscale:
        out = out.resize((w, h), Image.Resampling.NEAREST)
    scale = int(output_scaling)
    # A factor of 0 (or less) would ask PIL for an empty image, so skip it.
    if scale > 1:
        out = out.resize(
            (out.width * scale, out.height * scale), Image.Resampling.NEAREST)
    return out, seed
def query(
    i: Image.Image,
    block_size: str,
    n_clusters,  # =5,
    hsv_weights,  # ='0,0,1'
    local_contrast_blur_radius,  # =51 has to be odd
    seed,  # =42,
    output_scaling,
    dither_amount,
):
    """Gradio callback: parse the raw UI inputs and run ``pixart`` on them.

    Args:
        i: Input image, or ``None`` if the user submitted without one.
        block_size: Text field. A value below 1 is a fraction of
            ``min(width, height)``; otherwise it is a size in pixels. The
            result is always at least 1 pixel.
        n_clusters: Palette size.
        hsv_weights: Comma-separated text with exactly three numbers (H, S, V).
        local_contrast_blur_radius: Blur kernel size passed to ``pixart``.
        seed: Text (or int). Empty, whitespace-only or ``None`` means
            "pick a random seed".
        output_scaling: Output scale factor passed to ``pixart``.
        dither_amount: Dithering noise level passed to ``pixart``.

    Returns:
        ``(image, used_seed)``. The image is converted to a palette ('P')
        image when ``n_clusters <= 256``.

    Raises:
        gr.Error: if no image was given or the HSV weights are malformed.
    """
    if i is None:
        raise gr.Error('Please provide an input image.')
    w, h = i.size

    bs = float(block_size)
    if bs < 1:
        # A fraction of the shorter side. Clamp to 1: small images would
        # otherwise round down to 0 and pixart would divide by it.
        blsz = max(1, int(bs * min(w, h)))
    else:
        blsz = int(bs)

    hw = [float(part) for part in hsv_weights.split(',')]
    if len(hw) != 3:
        raise gr.Error('HSV weights must be three comma-separated numbers, e.g. 0,0,1')

    seed_text = '' if seed is None else str(seed).strip()
    pxart, usedseed = pixart(
        i,
        block_size=blsz,
        n_clusters=n_clusters,
        hsv_weights=hw,
        local_contrast_blur_radius=local_contrast_blur_radius,
        upscale=True,
        seed=int(seed_text) if seed_text else None,
        output_scaling=output_scaling,
        dither_amount=dither_amount,
    )
    # PIL palette images hold at most 256 colours. When the palette fits,
    # store the result as 'P' with exactly that many colours.
    if n_clusters <= 256:
        pxart = pxart.convert(
            'P', palette=Image.Palette.ADAPTIVE, colors=int(n_clusters))
    return pxart, usedseed
# %%
# Gradio UI: input widgets are passed to `query` in the order listed in the
# Interface below; `query` returns (image, used seed).
searchimage = gc.Image(label="Search image", type='pil')
block_size = gc.Textbox(
    "0.01",
    label='Block Size ',
    placeholder="e.g. 8 for 8 pixels. 0.01 for 1% of min(w,h) (<1 for percentages, >= 1 for pixels)")
palette_size = gc.Slider(
    1, 1024, 32, step=1, label='Palette Size (Number of Colors)')
hsv_weights = gc.Textbox(
    "0,0,1",
    label='HSV Weights. Weights of the channels when selecting a "representative pixel"/centroid from a cluster of pixels',
    placeholder='e.g. 0,0,1 to only consider the V channel (which seems to work well)')
# Starts at 3 with step 2, so only odd kernel sizes (required by GaussianBlur).
lcbr = gc.Slider(
    3, 512, 51, step=2, label='Blur radius to calculate local contrast')
seed = gc.Textbox(
    "",
    label='Seed for the random number generator (empty to randomize)',
    placeholder='e.g. 42')
outimage = gc.Image(label="Output", type='pil')
seedout = gc.Textbox(label='used seed')
# Minimum is 1: a factor of 0 would request a 0x0 output image.
output_scaling = gc.Slider(
    1, 16, 1, step=1, label='Output scaling factor')
dither_amount = gc.Slider(
    0, 255, 0, step=1, label='Dithering amount')

demo = gr.Interface(
    query,
    [searchimage, block_size, palette_size, hsv_weights, lcbr, seed, output_scaling, dither_amount],
    [outimage, seedout],
    title="kmeans-Pixartifier",
    description="Turns images into pixel art using kmeans clustering",
    analytics_enabled=False,
    allow_flagging='never',
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()