import os
from functools import lru_cache

import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.utils import open_onnx_model

# (model name, square input resolution) pairs; each name is a folder in the
# deepghs/anime_dbrating Hub repository containing a ``model.onnx``.
_MODELS = [
    ('caformer_s36_v0_ls0.2', 224),
    ('mobilenetv3_large_100_v0_ls0.2', 224),
    # ('swinv2pv3_v0_ls0.2', 224),
]
_MODEL_NAMES = [name for name, _ in _MODELS]
_DEFAULT_MODEL_NAME = _MODEL_NAMES[0]
_MODEL_TO_SIZE = dict(_MODELS)

# Rating labels; zipped positionally with the model's output row in `predict`.
_LABELS = ['general', 'sensitive', 'questionable', 'explicit']


@lru_cache()
def _onnx_model(name):
    """Fetch ``{name}/model.onnx`` from ``deepghs/anime_dbrating`` and open it.

    Memoized with ``lru_cache`` so repeated requests for the same model reuse
    the already-opened ONNX session instead of reloading it.
    """
    return open_onnx_model(hf_hub_download(
        'deepghs/anime_dbrating',
        f'{name}/model.onnx',
    ))


def _image_preprocess(image, size: int = 224) -> np.ndarray:
    """Turn ``image`` into a ``(1, 3, size, size)`` array with values in [0, 1].

    The image is loaded as RGB, resized to ``size`` x ``size``, scaled by
    1/255, transposed from HWC to CHW, and given a leading batch axis.
    The result is float64; the caller casts it to the model's dtype.
    """
    # NOTE(review): nearest-neighbour resampling is preserved as-is; confirm
    # it matches the preprocessing the models were trained with.
    rgb = load_image(image, mode='RGB').resize((size, size), Image.NEAREST)
    pixels = np.array(rgb) / 255.0
    chw = pixels.transpose(2, 0, 1)
    return chw[None, ...]


def predict(image, model_name):
    """Rate ``image`` with the ONNX model ``model_name``.

    Args:
        image: The uploaded image (a PIL image from the Gradio component),
            or ``None`` when nothing has been uploaded yet.
        model_name: One of the names in ``_MODEL_NAMES``.

    Returns:
        A ``{label: score}`` dict over ``_LABELS``, as expected by ``gr.Label``.

    Raises:
        gr.Error: If no image was provided or ``model_name`` is unknown, so the
            UI shows a readable message instead of an internal traceback.
    """
    if image is None:
        # Submitting before uploading would otherwise hand None to load_image.
        raise gr.Error('Please upload an image first.')
    size = _MODEL_TO_SIZE.get(model_name)
    if size is None:
        raise gr.Error(f'Unknown model: {model_name!r}')

    input_ = _image_preprocess(image, size).astype(np.float32)
    # The ONNX graph exposes a single input named 'input' and output 'output'.
    output_, = _onnx_model(model_name).run(['output'], {'input': input_})
    # output_[0] is the score row for the single image in the batch.
    return dict(zip(_LABELS, map(float, output_[0])))


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                gr_model = gr.Dropdown(_MODEL_NAMES, value=_DEFAULT_MODEL_NAME, label='Model')
                gr_btn_submit = gr.Button(value='Tagging', variant='primary')
            with gr.Column():
                gr_ratings = gr.Label(label='Ratings')

        gr_btn_submit.click(
            predict,
            inputs=[gr_input_image, gr_model],
            outputs=[gr_ratings],
        )

    # Allow up to one concurrent request per CPU core.
    demo.queue(os.cpu_count()).launch()