File size: 1,249 Bytes
582519c
 
2023a9f
582519c
2023a9f
582519c
2023a9f
582519c
2023a9f
 
 
 
582519c
 
 
 
 
 
 
2023a9f
 
 
 
 
582519c
 
 
 
 
 
 
2023a9f
 
 
 
 
 
 
582519c
 
2023a9f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import json
import os
from functools import lru_cache
from typing import Mapping, List

from huggingface_hub import hf_hub_download, HfFileSystem
from imgutils.data import ImageTyping, load_image
from natsort import natsorted

from onnx_ import _open_onnx_model
from preprocess import _img_encode

# Shared Hugging Face filesystem client; used below to discover published models.
hfs = HfFileSystem()

# Hugging Face repository that hosts every classification model sub-directory.
_REPO = 'deepghs/anime_classification'
# Names of all available models: each is a sub-directory of _REPO that contains a
# `model.onnx`, listed in natural-sort order. NOTE: this lists the remote repo at
# import time, so importing this module performs a network call.
_CLS_MODELS = natsorted(
    os.path.dirname(os.path.relpath(onnx_path, _REPO))
    for onnx_path in hfs.glob(f'{_REPO}/*/model.onnx')
)
# Model used when the caller does not pick one explicitly.
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'


@lru_cache()
def _open_anime_classify_model(model_name):
    """Load the ONNX session for classification model ``model_name``.

    The file ``<model_name>/model.onnx`` is fetched from ``_REPO`` with
    ``hf_hub_download`` and opened with ``_open_onnx_model``. The result is
    memoized per model name by ``lru_cache``, so repeated calls reuse the same
    loaded model.
    """
    onnx_file = hf_hub_download(_REPO, f'{model_name}/model.onnx')
    return _open_onnx_model(onnx_file)


@lru_cache()
def _get_tags(model_name) -> List[str]:
    """Return the ordered label names for classification model ``model_name``.

    Fetches ``<model_name>/meta.json`` from ``_REPO`` via ``hf_hub_download``
    and returns its ``labels`` entry. Results are memoized per model name by
    ``lru_cache``, so every caller receives the same list object and must not
    mutate it.

    :raises KeyError: If ``meta.json`` has no ``labels`` key.
    """
    meta_path = hf_hub_download(_REPO, f'{model_name}/meta.json')
    # JSON text is UTF-8 by specification. Without an explicit encoding, open()
    # uses the platform locale encoding (e.g. cp1252 on Windows), which can
    # mis-decode or reject non-ASCII label names.
    with open(meta_path, 'r', encoding='utf-8') as f:
        return json.load(f)['labels']


def _gr_classification(image: ImageTyping, model_name: str, size: int = 384) -> Mapping[str, float]:
    """Classify ``image`` with model ``model_name`` and return a score per label.

    :param image: Any input accepted by ``imgutils.data.load_image``; it is
        loaded in RGB mode.
    :param model_name: Sub-directory of ``_REPO`` that holds the model's
        ``model.onnx`` and ``meta.json``.
    :param size: Side length passed to ``_img_encode`` as ``(size, size)``.
    :return: Mapping from each label listed in ``meta.json`` to the model's
        score for it. NOTE(review): whether scores are probabilities or raw
        logits depends on how the model was exported; confirm.
    :raises ValueError: If the model's output length differs from the number
        of labels in ``meta.json``.
    """
    image = load_image(image, mode='RGB')
    # Prepend a batch axis of length 1 for the single image.
    input_ = _img_encode(image, size=(size, size))[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})

    labels = _get_tags(model_name)
    # One image was fed in, so read the first (only) batch row; .item() turns
    # each element into a plain Python number.
    scores = [x.item() for x in output[0]]
    # A length mismatch means meta.json and model.onnx disagree. zip() would
    # silently truncate and hide that, so fail loudly instead.
    if len(scores) != len(labels):
        raise ValueError(
            f'Model {model_name!r} produced {len(scores)} scores, '
            f'but its meta.json lists {len(labels)} labels.'
        )
    return dict(zip(labels, scores))