# Rothfeld's picture
# Update app.py
# 473633f verified
# %%
import cv2
from sklearn.cluster import KMeans
from PIL import Image
import numpy as np
import gradio.components as gc
import gradio as gr
def pixart(
    i,
    block_size=4,
    n_clusters=5,
    hsv_weights=(0, 0, 1),
    local_contrast_blur_radius=51,  # has to be odd
    upscale=True,
    seed=None,
    output_scaling=1,
    dither_amount=15
):
    """Turn an image into pixel art with a k-means colour palette.

    The image is shrunk by ``block_size`` using nearest-neighbour sampling,
    its pixels are clustered in HSV space (weighted by local contrast so
    that rare specks of colour are more likely to get their own cluster),
    and every cluster is painted with one representative pixel taken from
    the shrunken image.

    Args:
        i: PIL image. Images that are not in ``RGB`` mode are converted
            first, because the HSV conversion expects three channels.
        block_size: Edge length, in source pixels, of one output "pixel".
            Values below 1 are treated as 1.
        n_clusters: Palette size. Capped at the number of pixels of the
            shrunken image, since k-means cannot fit more clusters than
            samples.
        hsv_weights: Three weights (H, S, V). Within each cluster the pixel
            with the largest weighted channel sum becomes the cluster colour.
        local_contrast_blur_radius: Size of the Gaussian kernel used to
            estimate a pixel's surroundings. OpenCV only accepts odd sizes,
            so an even value is bumped to the next odd one.
        upscale: If true, resize the result back to the input size.
        seed: Integer seed for k-means and for the dither noise. A random
            one is drawn when ``None``.
        output_scaling: Positive factor; when it is not 1 the result is
            resized to the input size multiplied by this factor.
        dither_amount: Standard deviation of the Gaussian noise added to the
            HSV values before pixels are re-assigned to the nearest cluster
            centre. ``0`` disables dithering.

    Returns:
        A ``(image, seed)`` tuple: the pixel-art PIL image and the seed that
        was actually used.

    Raises:
        ValueError: If ``output_scaling`` is not positive.
    """
    if output_scaling <= 0:
        # fail early with a clear message instead of a zero-sized resize
        raise ValueError('output_scaling must be positive')
    if i.mode != 'RGB':
        i = i.convert('RGB')

    block_size = max(1, int(block_size))
    w, h = i.size
    # never shrink an axis to zero, even if the block is larger than the image
    dw = max(1, w // block_size)
    dh = max(1, h // block_size)
    # always resize with NEAREST to keep the original colors
    i = i.resize((dw, dh), Image.Resampling.NEAREST)
    ai = np.array(i)

    if seed is None:
        seed = np.random.randint(0, 2**16 - 1)
    n_clusters = max(1, min(int(n_clusters), dw * dh))
    km = KMeans(n_clusters=n_clusters, random_state=seed)

    hsv = cv2.cvtColor(ai, cv2.COLOR_RGB2HSV)
    # "| 1" turns an even kernel size into the next odd one
    ksize = max(1, int(local_contrast_blur_radius)) | 1
    bhsv = cv2.GaussianBlur(
        hsv,
        (ksize, ksize),
        0,
        borderType=cv2.BORDER_REPLICATE
    )
    hsv32 = hsv.astype(np.float32)

    # (sharp - blurred) gives large values if a pixel stands out from its
    # surroundings; raising to the power of 4 makes the difference more
    # pronounced. This preserves rare specks of color by increasing the
    # probability of them getting their own cluster.
    contrast = np.linalg.norm(hsv32 - bhsv, axis=-1).reshape(-1) ** 4
    if not contrast.any():
        # a perfectly flat image has no contrast at all; all-zero weights
        # carry no information, so fall back to uniform weighting
        contrast = None
    km.fit(hsv32.reshape(-1, hsv32.shape[-1]), sample_weight=contrast)

    fit_labels = km.labels_.reshape(hsv32.shape[:2])
    centers = km.cluster_centers_  # hsv values
    weights = np.asarray(hsv_weights)

    def pick_representative_pixel(cluster):
        '''pick the representative pixel for a cluster'''
        members = hsv[fit_labels == cluster]
        if members.shape[0] == 0:
            # nothing was assigned to this cluster: use its centre instead
            return np.clip(np.rint(centers[cluster]), 0, 255).astype(np.uint8)
        return members[(members @ weights).argmax()]

    cluster_colors = np.array([
        pick_representative_pixel(c)
        for c in range(centers.shape[0])])

    if dither_amount == 0:
        # assign each pixel the color of its cluster
        label_grid = fit_labels
    else:
        # add noise to the colors before selecting the nearest color, this
        # acts as a dithering effect; the noise is drawn from a generator
        # seeded with `seed` so the returned seed reproduces the output
        rng = np.random.default_rng(seed)
        noised_colors = hsv32 + rng.normal(0, dither_amount, hsv.shape)
        noised_colors = np.clip(noised_colors, 0, 255)
        flattened = noised_colors.reshape(-1, 3)
        # nearest cluster centre by euclidean distance
        distances = np.linalg.norm(centers - flattened[:, None], axis=-1)
        label_grid = distances.argmin(axis=1).reshape(hsv32.shape[:2])
    ki = cluster_colors[label_grid]

    rgb = cv2.cvtColor(ki.astype(np.uint8), cv2.COLOR_HSV2RGB)
    i = Image.fromarray(rgb)
    if upscale:
        i = i.resize((w, h), Image.Resampling.NEAREST)
    if output_scaling != 1:
        # the resize target has to be a pair of positive integers
        target = (
            max(1, int(round(w * output_scaling))),
            max(1, int(round(h * output_scaling))),
        )
        i = i.resize(target, Image.Resampling.NEAREST)
    return i, seed
def query(
    i: Image.Image,
    block_size: str,
    n_clusters,  # =5,
    hsv_weights,  # ='0,0,1'
    local_contrast_blur_radius,  # =51 has to be odd
    seed,  # =42,
    output_scaling,
    dither_amount
):
    """Parse the raw UI values and run ``pixart`` on the image.

    Args:
        i: Input PIL image.
        block_size: Block size as text. A value below 1 is a fraction of
            ``min(width, height)``; a value of 1 or more is a size in pixels.
        n_clusters: Palette size (number of colors).
        hsv_weights: Three comma-separated numbers, e.g. ``'0,0,1'``.
        local_contrast_blur_radius: Passed through to ``pixart``.
        seed: Seed as text; empty (or only whitespace) picks a random seed.
        output_scaling: Passed through to ``pixart``.
        dither_amount: Passed through to ``pixart``.

    Returns:
        A ``(image, used_seed)`` tuple. The image is converted to palette
        mode when the palette size is at most 256.

    Raises:
        ValueError: If ``block_size``, ``hsv_weights`` or ``seed`` cannot be
            parsed, or if ``hsv_weights`` does not hold exactly three values.
    """
    bs = float(block_size)
    width, height = i.size
    if bs < 1:
        # a small fraction of a small image truncates to 0, which would
        # divide by zero further down; one pixel is the smallest block
        blsz = max(1, int(bs * min(width, height)))
    else:
        blsz = int(bs)

    hw = [float(part) for part in hsv_weights.split(',')]
    if len(hw) != 3:
        raise ValueError(
            'hsv_weights must contain exactly three comma-separated numbers')

    n_clusters = int(n_clusters)
    seed_text = '' if seed is None else str(seed).strip()

    pxart, usedseed = pixart(
        i,
        block_size=blsz,
        n_clusters=n_clusters,
        hsv_weights=hw,
        local_contrast_blur_radius=local_contrast_blur_radius,
        upscale=True,
        seed=int(seed_text) if seed_text else None,
        output_scaling=output_scaling,
        dither_amount=dither_amount
    )
    if n_clusters <= 256:
        # palette mode holds at most 256 colors
        pxart = pxart.convert('P', palette=Image.Palette.ADAPTIVE, colors=n_clusters)
    return pxart, usedseed
# %%
# Input widgets, listed in the order of the parameters of `query`.
searchimage = gc.Image(
    # shape=(512, 512),
    label="Search image", type='pil')
block_size = gc.Textbox(
    "0.01",
    label='Block Size ',
    placeholder="e.g. 8 for 8 pixels. 0.01 for 1% of min(w,h) (<1 for percentages, >= 1 for pixels)")
palette_size = gc.Slider(
    1, 1024, 32, step=1, label='Palette Size (Number of Colors)')
hsv_weights = gc.Textbox(
    "0,0,1",
    label='HSV Weights. Weights of the channels when selecting a "representative pixel"/centroid from a cluster of pixels',
    placeholder='e.g. 0,0,1 to only consider the V channel (which seems to work well)')
# starting at 3 with a step of 2 keeps the slider on odd values
lcbr = gc.Slider(
    3, 512, 51, step=2, label='Blur radius to calculate local contrast')
seed = gc.Textbox(
    "",
    label='Seed for the random number generator (empty to randomize)',
    placeholder='e.g. 42')

# Output widgets: the generated image and the seed that produced it.
outimage = gc.Image(
    # shape=(224, 224),
    label="Output", type='pil')
seedout = gc.Textbox(label='used seed')

# The minimum is 1: a factor of 0 would ask for a zero-sized output image.
output_scaling = gc.Slider(
    1, 16, 1, step=1, label='Output scaling factor')
dither_amount = gc.Slider(
    0, 255, 0, step=1, label='Dithering amount')

# Build the interface around `query` and start serving it.
gr.Interface(
    query,
    [searchimage, block_size, palette_size, hsv_weights, lcbr, seed, output_scaling, dither_amount],
    [outimage, seedout],
    title="kmeans-Pixartifier",
    description="Turns images into pixel art using kmeans clustering",
    analytics_enabled=False,
    allow_flagging='never',
).launch()