Spaces:
Sleeping
Sleeping
File size: 1,249 Bytes
582519c 2023a9f 582519c 2023a9f 582519c 2023a9f 582519c 2023a9f 582519c 2023a9f 582519c 2023a9f 582519c 2023a9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import json
import os
from functools import lru_cache
from typing import Mapping, List
from huggingface_hub import hf_hub_download, HfFileSystem
from imgutils.data import ImageTyping, load_image
from natsort import natsorted
from onnx_ import _open_onnx_model
from preprocess import _img_encode
# Shared filesystem view of the Hugging Face Hub, used below to list the repo.
hfs = HfFileSystem()

# Hub repository holding one sub-directory per classification model.
_REPO = 'deepghs/anime_classification'

# Available model names: every sub-directory of ``_REPO`` that ships a
# ``model.onnx``, taken relative to the repo root and naturally sorted
# (so e.g. "...v2" precedes "...v10"). This glob runs against the Hub via
# ``hfs`` once, at import time.
_CLS_MODELS = natsorted(
    os.path.dirname(os.path.relpath(onnx_path, _REPO))
    for onnx_path in hfs.glob(f'{_REPO}/*/model.onnx')
)

# Model name presumably used as the initial/default choice by callers;
# NOTE(review): not checked here to be a member of ``_CLS_MODELS``.
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
@lru_cache()
def _open_anime_classify_model(model_name):
    """
    Fetch ``<model_name>/model.onnx`` from ``_REPO`` and open it as an ONNX model.

    Results are memoised per ``model_name`` by ``functools.lru_cache`` (default
    ``maxsize=128``), so repeated calls reuse the same opened model object.

    :param model_name: Name of the model sub-directory inside ``_REPO``.
    :return: Whatever ``_open_onnx_model`` returns for the downloaded file.
    """
    onnx_file = hf_hub_download(_REPO, f'{model_name}/model.onnx')
    return _open_onnx_model(onnx_file)
@lru_cache()
def _get_tags(model_name) -> List[str]:
    """
    Load the label names for ``model_name`` from its ``meta.json`` in ``_REPO``.

    The result is memoised by ``functools.lru_cache``: every call with the same
    ``model_name`` returns the *same* list object, so callers must not mutate it.

    :param model_name: Name of the model sub-directory inside ``_REPO``.
    :return: The ``labels`` entry of the parsed ``meta.json``.
    :raises KeyError: If the parsed JSON object has no ``labels`` key.
    """
    meta_path = hf_hub_download(_REPO, f'{model_name}/meta.json')
    # JSON text is UTF-8 by specification. Without an explicit encoding, open()
    # uses the locale's preferred encoding (e.g. cp1252 on Windows), which can
    # fail on or garble non-ASCII label names.
    with open(meta_path, 'r', encoding='utf-8') as f:
        return json.load(f)['labels']
def _gr_classification(image: ImageTyping, model_name: str, size: int = 384) -> Mapping[str, float]:
    """
    Classify ``image`` with the ONNX model ``model_name`` and score every label.

    :param image: Any input ``imgutils.data.load_image`` accepts; it is loaded
        in RGB mode.
    :param model_name: Name of the model sub-directory inside ``_REPO``.
    :param size: Side length passed to ``_img_encode`` as ``(size, size)``.
    :return: Mapping from each label listed in the model's ``meta.json`` to the
        corresponding value of the model's ``output`` for this image,
        converted with ``.item()``.
    :raises ValueError: If the number of labels differs from the number of
        output values.
    """
    image = load_image(image, mode='RGB')
    # ``[None, ...]`` prepends an axis of length 1, i.e. a batch of one image
    # (assumes ``_img_encode`` returns a numpy array -- TODO confirm).
    input_ = _img_encode(image, size=(size, size))[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})
    scores = output[0]  # row belonging to the single image in the batch
    labels = _get_tags(model_name)
    # zip() would silently drop the surplus on either side; a mismatch means
    # meta.json and model.onnx disagree, so fail loudly rather than return a
    # partial, mislabelled result.
    if len(labels) != len(scores):
        raise ValueError(
            f'Model {model_name!r} produced {len(scores)} output values '
            f'but its meta.json lists {len(labels)} labels.'
        )
    # .item() converts each (presumably numpy) scalar to a plain Python number.
    return {label: score.item() for label, score in zip(labels, scores)}
|