import argparse
import functools
import json
import os
from pathlib import Path

import faiss
import gradio as gr
import numpy as np
import PIL.Image
import requests
import tensorflow as tf
from huggingface_hub import hf_hub_download

from Utils import dbimutils

TITLE = "## Danbooru Explorer"
DESCRIPTION = """
Image similarity-based retrieval tool using:
- [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2) as feature extractor
- [Faiss](https://github.com/facebookresearch/faiss) and [autofaiss](https://github.com/criteo/autofaiss) for indexing

Also, check out [SmilingWolf/danbooru2022_embeddings_playground](https://huggingface.co/spaces/SmilingWolf/danbooru2022_embeddings_playground) for a similar space with experimental support for text input combined with image input.
"""

CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV_MODEL_REVISION = "v2.0"
CONV_FEXT_LAYER = "predictions_norm"

# UI rating labels -> single-letter rating codes found in Danbooru post JSON.
RATINGS_TO_LETTERS = {
    "General": "g",
    "Sensitive": "s",
    "Questionable": "q",
    "Explicit": "e",
}


def parse_args() -> argparse.Namespace:
    """Parse command-line flags (currently only ``--share``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def download_model(model_repo, model_revision):
    """Download the SavedModel files from the Hugging Face Hub.

    Args:
        model_repo: Hub repository id.
        model_revision: Branch, tag or commit to fetch.

    Returns:
        Path of the local directory holding ``saved_model.pb`` (the parent of
        the first downloaded file).
    """
    model_files = [
        {"filename": "saved_model.pb", "subfolder": ""},
        {"filename": "keras_metadata.pb", "subfolder": ""},
        {"filename": "variables.index", "subfolder": "variables"},
        {"filename": "variables.data-00000-of-00001", "subfolder": "variables"},
    ]

    model_file_paths = [
        Path(hf_hub_download(model_repo, revision=model_revision, **elem))
        for elem in model_files
    ]

    return model_file_paths[0].parents[0]


def load_model(model_repo, model_revision, feature_extraction_layer):
    """Load the tagger and truncate it at ``feature_extraction_layer``.

    Returns:
        A ``tf.keras.Model`` sharing the full model's inputs and whose output
        is the named layer's output (used as the image embedding).
    """
    model_path = download_model(model_repo, model_revision)
    full_model = tf.keras.models.load_model(model_path)
    model = tf.keras.models.Model(
        full_model.inputs, full_model.get_layer(feature_extraction_layer).output
    )
    return model


def danbooru_id_to_url(
    image_id, selected_ratings, api_username="", api_key="", timeout=10.0
):
    """Look up a Danbooru post and return its ``large_file_url`` if allowed.

    Args:
        image_id: Danbooru post id.
        selected_ratings: Iterable of UI rating labels (keys of
            ``RATINGS_TO_LETTERS``); the post's rating must be one of them.
        api_username: Optional Danbooru login; used only with ``api_key``.
        api_key: Optional Danbooru API key; used only with ``api_username``.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The post's ``large_file_url``, or ``None`` if the request fails, the
        response is not a 200 JSON body, the post's rating is not selected,
        or the post has no ``large_file_url``.

    Raises:
        KeyError: If ``selected_ratings`` contains an unknown label.
    """
    acceptable_ratings = {RATINGS_TO_LETTERS[x] for x in selected_ratings}
    if not acceptable_ratings:
        # Nothing can pass the rating filter; skip the network round-trip.
        return None

    params = {}
    if api_username and api_key:
        # Let requests URL-encode the credentials instead of splicing them
        # verbatim into the URL.
        params = {"api_key": api_key, "login": api_username}

    try:
        r = requests.get(
            f"https://danbooru.donmai.us/posts/{image_id}.json",
            headers={"User-Agent": "image_similarity_tool"},
            params=params,
            timeout=timeout,
        )
    except requests.RequestException:
        # One unreachable post should not abort the whole gallery.
        return None
    if r.status_code != 200:
        return None

    try:
        content = r.json()
    except ValueError:
        return None

    if content.get("rating") not in acceptable_ratings:
        return None
    return content.get("large_file_url")


class SimilaritySearcher:
    """Embed a query image and find its nearest neighbours in a Faiss index.

    Attributes:
        knn_index: Currently loaded Faiss index, or ``None`` before first use.
        knn_metric: Metric name of the loaded index (``"ip"``/``"cosine"``).
        model: Keras feature extractor producing the query embedding.
        images_ids: Array mapping Faiss result positions to Danbooru post ids.
    """

    # metric name -> (Faiss index file, autofaiss infos JSON with "index_param")
    INDEX_FILES = {
        "ip": ("index/ip_knn.index", "index/ip_infos.json"),
        "cosine": ("index/cosine_knn.index", "index/cosine_infos.json"),
    }

    def __init__(self, model, images_ids):
        self.knn_index = None
        self.knn_metric = None

        self.model = model
        self.images_ids = images_ids

    def change_index(self, knn_metric):
        """Load and configure the index for ``knn_metric`` unless already loaded.

        Raises:
            ValueError: If ``knn_metric`` is not a key of ``INDEX_FILES``.
        """
        if knn_metric == self.knn_metric:
            return

        try:
            index_path, infos_path = self.INDEX_FILES[knn_metric]
        except KeyError:
            raise ValueError(f"Unsupported kNN metric: {knn_metric!r}") from None

        knn_index = faiss.read_index(index_path)
        with open(infos_path, encoding="utf-8") as f:
            config = json.load(f)["index_param"]
        faiss.ParameterSpace().set_index_parameters(knn_index, config)

        # Commit state only after everything loaded successfully.
        self.knn_index = knn_index
        self.knn_metric = knn_metric

    @staticmethod
    def _preprocess(image, size):
        """Turn a PIL image into a (1, H, W, C) float32 BGR batch for the model.

        Transparent pixels are composited onto white first; squaring and
        resizing to ``size`` are delegated to ``dbimutils``.
        """
        image = image.convert("RGBA")
        canvas = PIL.Image.new("RGBA", image.size, "WHITE")
        canvas.paste(image, mask=image)
        pixels = np.asarray(canvas.convert("RGB"))

        # PIL RGB to OpenCV BGR
        pixels = pixels[:, :, ::-1]

        pixels = dbimutils.make_square(pixels, size)
        pixels = dbimutils.smart_resize(pixels, size)
        return np.expand_dims(pixels.astype(np.float32), 0)

    def predict(
        self, image, selected_ratings, knn_metric, api_username, api_key, n_neighbours
    ):
        """Return ``(image_url, caption)`` pairs for the nearest neighbours.

        Args:
            image: Query PIL image, or ``None`` when nothing was uploaded.
            selected_ratings: Rating labels forwarded to ``danbooru_id_to_url``.
            knn_metric: Index to search (see ``INDEX_FILES``).
            api_username: Optional Danbooru login.
            api_key: Optional Danbooru API key.
            n_neighbours: Number of neighbours to retrieve (may arrive as a
                float from the UI slider).

        Returns:
            List of ``(url, "<post_id>/<distance>")`` tuples. Neighbours whose
            URL cannot be resolved or whose rating is filtered out are dropped.
        """
        if image is None:
            return []

        _, height, _, _ = self.model.inputs[0].shape
        self.change_index(knn_metric)

        batch = self._preprocess(image, height)
        # Faiss needs a C-contiguous float32 matrix.
        target = np.ascontiguousarray(self.model(batch).numpy(), dtype=np.float32)
        if self.knn_metric == "cosine":
            faiss.normalize_L2(target)

        dists, indexes = self.knn_index.search(target, k=int(n_neighbours))

        gallery = []
        for idx, dist in zip(indexes[0], dists[0]):
            if idx < 0:
                # Faiss fills unused result slots with -1; indexing images_ids
                # with it would silently return the last post id.
                continue
            image_id = int(self.images_ids[idx])
            current_url = danbooru_id_to_url(
                image_id, selected_ratings, api_username, api_key
            )
            if current_url is not None:
                gallery.append((current_url, f"{image_id}/{dist:.2f}"))
        return gallery


def main():
    """Load model and ids, build the Gradio UI and launch it."""
    args = parse_args()
    model = load_model(CONV_MODEL_REPO, CONV_MODEL_REVISION, CONV_FEXT_LAYER)
    images_ids = np.load("index/cosine_ids.npy")

    searcher = SimilaritySearcher(model=model, images_ids=images_ids)

    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            input_image = gr.Image(type="pil", label="Input")
            with gr.Column():
                with gr.Row():
                    api_username = gr.Textbox(label="Danbooru API Username")
                    api_key = gr.Textbox(label="Danbooru API Key")
                selected_ratings = gr.CheckboxGroup(
                    choices=["General", "Sensitive", "Questionable", "Explicit"],
                    value=["General", "Sensitive"],
                    label="Ratings",
                )
                with gr.Row():
                    selected_metric = gr.Radio(
                        choices=["cosine"],
                        value="cosine",
                        label="Metric selection",
                        visible=False,
                    )
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")
        similar_images = gr.Gallery(label="Similar images", columns=[5])
        find_btn.click(
            fn=searcher.predict,
            inputs=[
                input_image,
                selected_ratings,
                selected_metric,
                api_username,
                api_key,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue()
    demo.launch(share=args.share)


if __name__ == "__main__":
    main()