Spaces:
Sleeping
Sleeping
import json | |
import os | |
from functools import lru_cache | |
from typing import Mapping, List | |
from huggingface_hub import hf_hub_download, HfFileSystem | |
from imgutils.data import ImageTyping, load_image | |
from natsort import natsorted | |
from onnx_ import _open_onnx_model | |
from preprocess import _img_encode | |
# Hugging Face filesystem client, used below to list the repository's files.
hfs = HfFileSystem()

# Hub repository holding every classification model; each model is a
# top-level directory of this repo containing a `model.onnx`.
_REPO = 'deepghs/anime_classification'

# Model names = every top-level directory of _REPO that contains a
# `model.onnx`, in natural sort order.
# NOTE: `hfs.glob` queries the Hub, so this runs a remote listing at import time.
_CLS_MODELS = natsorted(
    os.path.dirname(os.path.relpath(onnx_path, _REPO))
    for onnx_path in hfs.glob(f'{_REPO}/*/model.onnx')
)

# Model name used by default; not validated against _CLS_MODELS here.
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
@lru_cache()
def _open_anime_classify_model(model_name):
    """
    Download (if needed) and open the ONNX model for ``model_name``.

    The result is memoized per ``model_name``, so repeated classification
    requests reuse the same loaded model instead of re-downloading and
    re-initialising it on every call. The number of distinct keys is bounded
    by the models available in the repo, so the default ``lru_cache`` size
    is sufficient.

    :param model_name: Sub-directory of ``_REPO`` containing ``model.onnx``.
    :return: Whatever ``_open_onnx_model`` returns. Presumably an ONNX Runtime
        session; callers use ``.run(['output'], {'input': ...})`` on it.
    """
    return _open_onnx_model(hf_hub_download(_REPO, f'{model_name}/model.onnx'))
@lru_cache()
def _get_tags(model_name) -> List[str]:
    """
    Return the label list stored in ``<model_name>/meta.json`` of ``_REPO``.

    The result is memoized per ``model_name``, so ``meta.json`` is fetched
    and parsed only once per model.

    NOTE: the cached list object is shared between all callers; treat it as
    read-only.

    :param model_name: Sub-directory of ``_REPO`` containing ``meta.json``.
    :return: The ``'labels'`` entry of that JSON file. Its order is expected
        to match the model's output columns, since ``_gr_classification``
        zips the two together.
    :raises KeyError: If ``meta.json`` has no ``'labels'`` key.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(hf_hub_download(_REPO, f'{model_name}/meta.json'), 'r', encoding='utf-8') as f:
        return json.load(f)['labels']
def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
    """
    Classify one image with the named model and return a score per label.

    :param image: Any input accepted by ``load_image``; it is converted to RGB.
    :param model_name: Model directory name inside ``_REPO``.
    :param size: The image is encoded to a square of ``size`` x ``size``.
    :return: Dict mapping each label from the model's ``meta.json`` to the
        corresponding value of the model's ``'output'`` (first row), converted
        with ``.item()``. Whether these are probabilities or raw logits depends
        on the exported model. TODO confirm.

    NOTE(review): ``zip`` silently truncates if the label count and the output
    width differ.
    """
    rgb_image = load_image(image, mode='RGB')
    # Prepend a length-1 axis, presumably the batch dimension the model expects.
    batch = _img_encode(rgb_image, size=(size, size))[None, ...]

    session = _open_anime_classify_model(model_name)
    # Exactly one output name is requested, so exactly one array is unpacked.
    (output,) = session.run(['output'], {'input': batch})

    labels = _get_tags(model_name)
    # Take the first (only) row; `.item()` turns each element into a Python scalar.
    return {label: score.item() for label, score in zip(labels, output[0])}