File size: 4,753 Bytes
3c7d8d9
 
681a350
3c7d8d9
 
 
 
 
 
 
a54d829
84a3dfc
3c7d8d9
 
 
 
4737e31
8537a9b
3c7d8d9
 
 
 
 
a54d829
3c7d8d9
 
 
 
 
681a350
71df424
9ecc7a5
1557500
 
0cc211e
532d724
84a3dfc
681a350
3c7d8d9
681a350
 
3c7d8d9
4737e31
 
3c7d8d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681a350
 
 
175cb6c
681a350
 
 
 
 
 
 
defe89e
681a350
 
3c7d8d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681a350
3c7d8d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import json
import os
from collections import defaultdict
from functools import lru_cache
from typing import List, Dict

import faiss
import gradio as gr
import numpy as np
from PIL import Image
from cheesechaser.datapool import YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool, \
    KonachanWebpDataPool, AnimePicturesWebpDataPool, DanbooruNewestWebpDataPool, Rule34WebpDataPool
from hfutils.operate import get_hf_fs, get_hf_client
from hfutils.utils import TemporaryDirectory
from imgutils.tagging import wd14

from pools import quick_webp_pool

# Hugging Face repository holding one sub-directory per index; it is globbed for
# `<model_name>/knn.index` below and passed to `_get_index_info` by `search`.
_REPO_ID = 'deepghs/index_experiments'

# Hugging Face filesystem / client handles, created once when the module is loaded.
hf_fs = get_hf_fs()
hf_client = get_hf_client()

# Index name preselected in the UI dropdown.
_DEFAULT_MODEL_NAME = 'SwinV2_v3_dgzyka_23325111_8GB'
# Names of every first-level directory of the repository that contains a `knn.index`
# file. NOTE: the glob is evaluated at import time (presumably a remote listing).
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(path, _REPO_ID))
    for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]

# Site name -> data pool class used to download that site's images.
# Site names missing from this table fall back to `quick_webp_pool` in `_get_from_ids`.
_SITE_CLS = {
    'danbooru': DanbooruNewestWebpDataPool,
    'yandere': YandeWebpDataPool,
    'zerochan': ZerochanWebpDataPool,
    'gelbooru': GelbooruWebpDataPool,
    'konachan': KonachanWebpDataPool,
    'anime_pictures': AnimePicturesWebpDataPool,
    'rule34': Rule34WebpDataPool,
}


def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """
    Download the images with the given numeric ids from a single site and load them.

    The files are fetched into a temporary directory through the site's data pool
    (the class registered in ``_SITE_CLS``, or ``quick_webp_pool(site_name, 3)`` when
    the site is not registered) and fully decoded before the directory is removed.

    :param site_name: Name of the site the ids belong to.
    :param ids: Numeric resource ids to download.
    :return: Mapping from numeric id (parsed from each downloaded file's base name)
        to the loaded image. Only files actually present after the download appear.
    """
    with TemporaryDirectory() as workdir:
        pool_cls = _SITE_CLS.get(site_name)
        if not pool_cls:
            pool_cls = quick_webp_pool(site_name, 3)
        pool_cls().batch_download_to_directory(resource_ids=ids, dst_dir=workdir)

        loaded = {}
        for filename in os.listdir(workdir):
            stem, _ = os.path.splitext(filename)
            resource_id = int(stem)
            picture = Image.open(os.path.join(workdir, filename))
            # Force decoding now: the backing file vanishes with the temp directory.
            picture.load()
            loaded[resource_id] = picture

        return loaded


def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """
    Load the images for a list of raw ids of the form ``<site_name>_<numeric_id>``.

    The ids are grouped by site so that each site is downloaded with a single
    ``_get_from_ids`` call.

    :param ids: Raw ids; each is split at its last underscore into site name and
        integer id.
    :return: Mapping from ``'<site_name>_<numeric_id>'`` to the loaded image, limited
        to the images ``_get_from_ids`` returned.
    """
    ids_by_site = defaultdict(list)
    for raw_id in ids:
        prefix, number = raw_id.rsplit('_', maxsplit=1)
        ids_by_site[prefix].append(int(number))

    images = {}
    for site_name, numeric_ids in ids_by_site.items():
        downloaded = _get_from_ids(site_name, numeric_ids)
        for numeric_id, image in downloaded.items():
            images[f'{site_name}_{numeric_id}'] = image
    return images


@lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """
    Download and load the id table and the FAISS index of one model.

    Three files are fetched from ``<model_name>/`` in the repository: ``ids.npy``,
    ``knn.index`` and ``infos.json``. The ``"index_param"`` entry of ``infos.json`` is
    applied to the index as its search parameters. Results are memoised per
    ``(repo_id, model_name)`` for up to three distinct argument pairs, so callers
    share the returned objects.

    :param repo_id: Hugging Face model repository containing the index files.
    :param model_name: Sub-directory of the repository holding the files.
    :return: Tuple ``(image_ids, knn_index)`` of the loaded id array and FAISS index.
    """

    def _download(name: str) -> str:
        # Fetch one file of this model and return its local path.
        return hf_client.hf_hub_download(
            repo_id=repo_id,
            repo_type='model',
            filename=f'{model_name}/{name}',
        )

    image_ids = np.load(_download('ids.npy'))
    knn_index = faiss.read_index(_download('knn.index'))

    # Close the file deterministically and decode as UTF-8 regardless of the locale.
    with open(_download('infos.json'), encoding='utf-8') as f:
        config = json.load(f)["index_param"]
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index


def search(model_name: str, img_input, n_neighbours: int):
    """
    Find the indexed images most similar to the given input image.

    The input is embedded with the ``SwinV2_v3`` wd14 tagger, L2-normalised and looked
    up in the FAISS index of ``model_name``. The matched images are then downloaded.

    :param model_name: Name of the index (sub-directory of ``_REPO_ID``) to query.
    :param img_input: Query image as delivered by the UI; ``None`` when nothing was
        uploaded.
    :param n_neighbours: Maximum number of neighbours to retrieve.
    :return: List of ``(image, caption)`` tuples in the order returned by the index,
        where the caption is ``'<image_id>/<distance with two decimals>'``. Neighbours
        whose image could not be downloaded are omitted.
    :raises gr.Error: If no input image was provided.
    """
    if img_input is None:
        # Surface a readable message in the UI instead of failing inside the tagger.
        raise gr.Error('Please provide an input image first.')

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)
    embeddings = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    # faiss operates on C-contiguous float32 matrices; this is a no-op when the
    # embedding already has that layout.
    embeddings = np.ascontiguousarray(np.expand_dims(embeddings, 0), dtype=np.float32)
    faiss.normalize_L2(embeddings)

    # The slider may hand over a float; faiss needs an integer k.
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))
    # faiss pads missing results with label -1; indexing with it would silently
    # select the LAST id of the table, so drop those slots first.
    found = indexes[0] >= 0
    neighbours_ids = images_ids[indexes[0][found]]
    neighbours_dists = dists[0][found]

    ids_to_images = _get_from_raw_ids(neighbours_ids)
    results = []
    for image_id, dist in zip(neighbours_ids, neighbours_dists):
        if image_id in ids_to_images:
            results.append((ids_to_images[image_id], f"{image_id}/{dist:.2f}"))

    return results


if __name__ == "__main__":
    # Build the Gradio UI: query image on the left, search options on the right,
    # result gallery underneath.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                # Query image, passed to `search` as a PIL image.
                img_input = gr.Image(type="pil", label="Input")

            with gr.Column():
                with gr.Row():
                    # Which index of the repository to search in.
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    # Number of neighbours to retrieve (1..50, default 20).
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")

        with gr.Row():
            # Results are shown five per row.
            similar_images = gr.Gallery(label="Similar images", columns=[5])

        # The input order must match the positional parameters of `search`.
        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    # Enable request queuing, then start the web server.
    demo.queue().launch()