|
import gradio as gr |
|
import timm |
|
import torch |
|
import pandas as pd |
|
|
|
|
|
# Heading shown at the top of the Gradio page (passed as `title=` to gr.Interface).
TITLE = "wd-eva02-large-tagger-v3-vector"

# Markdown rendered under the title (passed as `description=` to gr.Interface);
# links to the Hugging Face model whose classifier-head weights this app compares.
DESCRIPTION = """

[model](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3)

"""
|
|
|
# Hugging Face Hub repository of the tagger; the weights and the tag list both
# come from here, so keep the id in one place.
MODEL_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# The full tagger is loaded only to extract its classification-head weight
# matrix. The rest of this script assumes row i of that matrix belongs to the
# tag on row i of selected_tags.csv.
# `.detach()` (preferred over `.data`) gives a grad-free tensor sharing the
# parameter's storage, so it stays usable after the model reference is dropped.
model = timm.create_model(f"hf_hub:{MODEL_REPO}", pretrained=True)
head = model.head.weight.detach()
# Drop the reference to the full network; only `head` is needed from here on.
del model

# Tag vocabulary. Columns used here: "name" (tag string) and "category"
# (0 -> general tags, 4 -> character tags, per the filters below).
df = pd.read_csv(f"https://huggingface.co/{MODEL_REPO}/resolve/main/selected_tags.csv")

# read_csv's default RangeIndex makes these ids the CSV row numbers, which are
# used directly as row indices into `head`.
id2label = df["name"].to_dict()
label2id = {v: k for k, v in id2label.items()}

# Candidate pools for the "Search in" dropdown, as index arrays into `head`.
general_tags = df[df["category"] == 0].index
character_tags = df[df["category"] == 4].index
all_tags = df.index
|
|
|
def predict(target_tags, search_in):
    """Rank tags by how similar their classifier-head vectors are to the query tags'.

    Args:
        target_tags: Comma-separated tag names as typed by the user. Spaces are
            converted to underscores (the spelling used in selected_tags.csv);
            empty entries, e.g. from a trailing comma, are ignored.
        search_in: "general" or "character" restricts the candidates to that
            category; any other value searches all tags.

    Returns:
        Dict mapping each candidate tag name to the mean cosine similarity
        between its head row and each query tag's head row (the format
        ``gr.Label`` expects).

    Raises:
        gr.Error: If no tag was entered or a tag is not in the vocabulary, so
            the user sees a readable message instead of a bare ``KeyError``.
    """
    names = [part.strip().replace(" ", "_") for part in target_tags.split(",")]
    names = [name for name in names if name]
    if not names:
        raise gr.Error("Please enter at least one tag.")

    unknown = [name for name in names if name not in label2id]
    if unknown:
        raise gr.Error(f"Unknown tag(s): {', '.join(unknown)}")

    target_ids = [label2id[name] for name in names]

    # Assuming `head` is (num_tags, D) like a linear layer's weight:
    # query (k, 1, D) vs head (1, num_tags, D) broadcasts to a (k, num_tags)
    # similarity matrix; averaging over the k query tags yields one score per tag.
    query = head[target_ids].unsqueeze(1)
    sim = torch.cosine_similarity(query, head.unsqueeze(0), dim=2).mean(dim=0)

    pool = {"general": general_tags, "character": character_tags}.get(search_in, all_tags)

    # Convert once rather than calling .item() for each of the candidate tags.
    scores = sim.tolist()
    return {id2label[i]: scores[i] for i in pool}
|
|
|
# Build the UI: a comma-separated tag box and a category dropdown feed
# predict() (in that order); its score dict is shown as the 50 top labels.
demo = gr.Interface(
    predict,
    [
        gr.Text(label="Target tags", value="pink hair, braid"),
        gr.Dropdown(choices=["all", "general", "character"], value="all", label="Search in"),
    ],
    gr.Label(num_top_classes=50),
    title=TITLE,
    description=DESCRIPTION,
)

demo.launch()