|
import os |
|
from typing import Tuple |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
from sklearn.cluster import KMeans |
|
|
|
|
|
def _image_resize(image: Image.Image, pixels: int = 90000, **kwargs):
    """Return a copy of ``image`` scaled down to roughly ``pixels`` pixels.

    The aspect ratio is preserved. Images already within the pixel budget are
    copied unchanged; images are never upscaled.

    :param image: Source image.
    :param pixels: Target pixel budget (width * height); must be positive.
    :param kwargs: Extra keyword arguments forwarded to ``Image.resize``
        (for example ``resample``).
    :return: A new image object (never ``image`` itself).
    :raises ValueError: If ``pixels`` is not positive.
    """
    if pixels <= 0:
        raise ValueError(f'pixels must be positive, got {pixels!r}')

    width, height = image.size
    # Linear scale factor: dividing each side by it divides the area by ratio**2.
    ratio = (width * height / pixels) ** 0.5
    if ratio <= 1.0:
        return image.copy()

    # Clamp to 1 so very thin/wide images never collapse to a zero-sized side,
    # which Image.resize rejects.
    new_size = (max(1, int(width / ratio)), max(1, int(height / ratio)))
    return image.resize(new_size, **kwargs)
|
|
|
|
|
def get_main_colors(image: Image.Image, n: int = 28, pixels: int = 90000) \
        -> Tuple[Image.Image, np.ndarray, np.ndarray, np.ndarray]:
    """Quantize ``image`` to at most ``n`` colors with k-means.

    The k-means model is fitted on a downscaled version of the image (about
    ``pixels`` pixels, via ``_image_resize``) to keep clustering cheap. Every
    pixel of the full-size image is then assigned to its nearest cluster
    center.

    :param image: Source image; converted to RGB if it is in another mode.
    :param n: Requested number of clusters. Clamped to the number of sampled
        pixels, because k-means cannot fit more clusters than samples.
    :param pixels: Approximate pixel budget of the image used for fitting.
    :return: A tuple ``(new_image, colors, counts, prediction)``:
        ``new_image`` is the quantized RGB image;
        ``colors`` is a ``(k, 3)`` uint8 array of rounded cluster centers;
        ``counts[i]`` is the number of full-image pixels assigned to
        ``colors[i]`` (0 for clusters that received none), so
        ``len(counts) == len(colors)``;
        ``prediction`` is a ``(height, width)`` array of cluster indices.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    small_image = _image_resize(image, pixels)
    samples = np.asarray(small_image).reshape(-1, 3)

    # KMeans raises if asked for more clusters than there are samples.
    n_clusters = max(1, min(int(n), len(samples)))
    kmeans = KMeans(n_clusters=n_clusters)
    kmeans.fit(samples)

    width, height = image.size
    raw = np.asarray(image).reshape(-1, 3)
    colors = kmeans.cluster_centers_.round().astype(np.uint8)
    prediction = kmeans.predict(raw)

    new_data = colors[prediction].reshape((height, width, 3))
    # A uint8 (H, W, 3) array is inferred as RGB; the explicit ``mode``
    # argument is deprecated in recent Pillow releases.
    new_image = Image.fromarray(new_data)

    # bincount with minlength keeps counts index-aligned with ``colors`` even
    # when some clusters get no pixels (e.g. an image with fewer distinct
    # colors than n). np.unique would silently drop those entries.
    counts = np.bincount(prediction, minlength=len(colors))

    return new_image, colors, counts, prediction.reshape((height, width))
|
|
|
|
|
def main_func(image: Image.Image, n: int, pixels: int, fixed_width: bool, width: int):
    """Gradio callback: quantize ``image`` and build its color table.

    :param image: Uploaded image, or ``None`` when nothing has been uploaded.
    :param n: Number of color clusters.
    :param pixels: Pixel budget used when fitting the clustering model.
    :param fixed_width: When true, rescale the image to ``width`` pixels wide
        (keeping the aspect ratio) before quantizing.
    :param width: Target width used when ``fixed_width`` is true.
    :return: ``(new_image, table)``: the quantized image, and a DataFrame with
        columns Hex, Pixels, Ratio, Red, Green, Blue, sorted by Pixels
        descending.
    :raises gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error('Please upload an image first.')

    if fixed_width:
        old_width, old_height = image.size
        scale = width / old_width
        # Clamp so extreme aspect ratios never produce a zero-sized image.
        new_width = max(1, int(round(old_width * scale)))
        new_height = max(1, int(round(old_height * scale)))
        image = image.resize((new_width, new_height))

    # Slider values may arrive as floats; the clustering code needs ints.
    new_image, colors, _, predictions = get_main_colors(image, int(n), int(pixels))

    # Recount from the per-pixel predictions so every color row gets a count,
    # including clusters that were assigned no pixels. This keeps the
    # DataFrame columns the same length.
    counts = np.bincount(predictions.ravel(), minlength=len(colors))

    table = pd.DataFrame({
        'Hex': [f'#{r:02x}{g:02x}{b:02x}' for r, g, b in colors.tolist()],
        'Pixels': counts,
        'Ratio': counts / counts.sum(),
        'Red': colors[:, 0],
        'Green': colors[:, 1],
        'Blue': colors[:, 2],
    })
    return new_image, table.sort_values('Pixels', ascending=False)
|
|
|
|
|
if __name__ == '__main__':
    pd.set_option("display.precision", 3)

    # Layout: controls on the left, results (image / color table tabs) on the right.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                ch_image = gr.Image(type='pil', label='Original Image')
                with gr.Row():
                    ch_clusters = gr.Slider(value=8, minimum=2, maximum=256, step=2, label='Clusters')
                    ch_pixels = gr.Slider(value=100000, minimum=10000, maximum=1000000, step=10000,
                                          label='Pixels for Clustering')
                    ch_fixed_width = gr.Checkbox(value=True, label='Width Fixed')
                    ch_width = gr.Slider(value=200, minimum=12, maximum=2048, label='Width')

                ch_submit = gr.Button(value='Submit', variant='primary')

            with gr.Column():
                with gr.Tabs():
                    with gr.Tab('Output Image'):
                        ch_output = gr.Image(type='pil', label='Output Image')
                    with gr.Tab('Color Map'):
                        ch_color_map = gr.Dataframe(
                            headers=['Hex', 'Pixels', 'Ratio', 'Red', 'Green', 'Blue'],
                            label='Color Map'
                        )

        ch_submit.click(
            main_func,
            inputs=[ch_image, ch_clusters, ch_pixels, ch_fixed_width, ch_width],
            outputs=[ch_output, ch_color_map],
        )

    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to a single worker instead of passing None.
    # NOTE(review): this positional argument is `concurrency_count` in
    # Gradio 3.x; Gradio 4.x changed `queue()`'s parameters. Confirm against
    # the pinned Gradio version.
    demo.queue(os.cpu_count() or 1).launch()
|
|