Spaces:
Running
Running
import json | |
import os | |
from collections import defaultdict | |
from functools import lru_cache | |
from typing import List, Dict | |
import faiss | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool, \ | |
KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool | |
from hfutils.operate import get_hf_fs, get_hf_client | |
from hfutils.utils import TemporaryDirectory | |
from imgutils.tagging import wd14 | |
from pools import quick_webp_pool | |
_REPO_ID = 'deepghs/index_experiments'
# Shared Hugging Face filesystem / API clients used by the loaders below.
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Index pre-selected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_dgzyka_23325111_8GB'
# Names of the repository sub-directories that contain a ``knn.index`` file,
# discovered once at import time (path made relative to the repo, then the
# file name stripped off).
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(index_file, _REPO_ID))
    for index_file in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]

# Site prefix (the part of a raw id before its last ``_``) -> data-pool class.
# Sites missing from this table are resolved via ``quick_webp_pool`` in
# ``_get_from_ids``.
_SITE_CLS = dict(
    danbooru=DanbooruNewestWebpDataPool,
    yandere=YandeWebpDataPool,
    zerochan=ZerochanWebpDataPool,
    gelbooru=GelbooruWebpDataPool,
    konachan=KonachanWebpDataPool,
    anime_pictures=AnimePicturesWebpDataPool,
    rule34=Rule34WebpDataPool,
)
def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """Download images of one site and return them keyed by numeric id.

    :param site_name: Key into ``_SITE_CLS``; names missing from that table
        fall back to ``quick_webp_pool(site_name, 3)``.
    :param ids: Numeric resource ids to fetch from that site.
    :return: Mapping from numeric id to a fully in-memory PIL image. Ids the
        data pool did not deliver are simply absent from the mapping.
    """
    with TemporaryDirectory() as td:
        site_cls = _SITE_CLS.get(site_name) or quick_webp_pool(site_name, 3)
        datapool = site_cls()
        datapool.batch_download_to_directory(
            resource_ids=ids,
            dst_dir=td,
        )

        retval = {}
        for file in os.listdir(td):
            path = os.path.join(td, file)
            # Only ``<numeric id>.<ext>`` files are images we asked for; skip
            # anything else instead of failing the whole batch on int().
            try:
                id_ = int(os.path.splitext(file)[0])
            except ValueError:
                continue
            if not os.path.isfile(path):
                continue
            # Decode inside ``with`` and keep a detached copy: the file handle
            # is closed here, before ``td`` is removed when the outer ``with``
            # exits, regardless of how the format's loader treats its fp.
            with Image.open(path) as image:
                image.load()
                retval[id_] = image.copy()
        return retval
def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """Fetch images for raw ids of the form ``<site>_<number>``.

    Ids are grouped per site (split on the last ``_``) so each site's data
    pool is queried once. Result keys are rebuilt as ``f'{site}_{number}'``
    from the integer id, so they equal the input ids only when the numeric
    part has no leading zeros or sign.

    :param ids: Raw ids such as ``'danbooru_123'``.
    :return: Mapping from raw id to image, for the ids that could be fetched.
    """
    ids_by_site = defaultdict(list)
    for raw_id in ids:
        site, number = raw_id.rsplit('_', maxsplit=1)
        ids_by_site[site].append(int(number))

    images_by_raw_id = {}
    for site, numeric_ids in ids_by_site.items():
        fetched = _get_from_ids(site, numeric_ids)
        for number, picture in fetched.items():
            images_by_raw_id[f'{site}_{number}'] = picture
    return images_by_raw_id
@lru_cache(maxsize=1)
def _get_index_info(repo_id: str, model_name: str):
    """Load the id table and the configured faiss index of one model.

    The result is memoised with ``lru_cache(maxsize=1)``. Repeated searches
    against the same index therefore reuse the in-memory index instead of
    re-reading it from disk on every request. Only the most recently used
    index is retained, to bound memory. While switching models, the old and
    new index are briefly resident together. Because of the cache, an index
    updated on the Hub is only picked up after a process restart.

    :param repo_id: Hugging Face model repository holding the indices.
    :param model_name: Sub-directory containing ``ids.npy``, ``knn.index``
        and ``infos.json``.
    :return: ``(image_ids, knn_index)``: the array loaded from ``ids.npy``,
        and the faiss index with the search parameters from
        ``infos.json["index_param"]`` applied.
    """

    def _download(filename: str) -> str:
        # Local path of ``<model_name>/<filename>`` downloaded from the repo.
        return hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/{filename}',
        )

    image_ids = np.load(_download('ids.npy'))
    knn_index = faiss.read_index(_download('knn.index'))
    # ``with`` closes the metadata file as soon as the JSON has been parsed.
    with open(_download('infos.json'), 'r', encoding='utf-8') as f:
        config = json.load(f)["index_param"]
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index
def search(model_name: str, img_input, n_neighbours: int):
    """Find images similar to ``img_input`` in the selected faiss index.

    :param model_name: Index sub-directory in ``_REPO_ID`` to search.
    :param img_input: Query image (a PIL image, as delivered by the
        ``gr.Image(type="pil")`` input of the UI).
    :param n_neighbours: Number of neighbours to request from faiss.
    :return: List of ``(image, caption)`` pairs for ``gr.Gallery``. Each
        caption is ``"<raw id>/<score>"``, where the score is the raw faiss
        distance/similarity (its meaning depends on the index metric).
        Neighbours whose image could not be fetched are omitted.
    :raises gr.Error: If no input image was provided.
    """
    if img_input is None:
        raise gr.Error('Please upload an input image first.')

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
    embedding = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    # faiss needs a C-contiguous float32 matrix of shape (n_queries, dim);
    # the tagger output is assumed to be a single embedding vector.
    embeddings = np.ascontiguousarray(embedding, dtype=np.float32).reshape(1, -1)
    faiss.normalize_L2(embeddings)
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))

    # faiss pads with label -1 when it finds fewer than k neighbours. Drop
    # those slots so ``images_ids[-1]`` doesn't alias the last id.
    found = indexes[0] >= 0
    neighbours_ids = images_ids[indexes[0][found]]
    neighbours_dists = dists[0][found]

    ids_to_images = _get_from_raw_ids(neighbours_ids)
    return [
        (ids_to_images[image_id], f"{image_id}/{dist:.2f}")
        for image_id, dist in zip(neighbours_ids, neighbours_dists)
        if image_id in ids_to_images
    ]
if __name__ == "__main__":
    # Layout: query image on the left; index selector, neighbour-count slider
    # and search button on the right; the result gallery in a row below.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                query_image = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    model_selector = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    neighbour_count = gr.Slider(
                        minimum=1, maximum=50, value=20, step=1,
                        label="# of images",
                    )
                search_button = gr.Button("Find similar images")
        with gr.Row():
            result_gallery = gr.Gallery(label="Similar images", columns=[5])

        # Input order must match search(model_name, img_input, n_neighbours).
        search_button.click(
            fn=search,
            inputs=[model_selector, query_image, neighbour_count],
            outputs=[result_gallery],
        )

    demo.queue().launch()