"""Gradio demo that rates an image with the deepghs NSFW ONNX classifiers."""
import os
from functools import lru_cache

import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.utils import open_onnx_model

# (model file name under ``nsfw/`` in the hub repo, square input side length).
_MODELS = [
    ('nsfwjs.onnx', 224),
    ('inception_v3.onnx', 299),
]
_MODEL_NAMES = [name for name, _ in _MODELS]
_DEFAULT_MODEL_NAME = _MODEL_NAMES[0]
_MODEL_TO_SIZE = dict(_MODELS)


@lru_cache()
def _onnx_model(name):
    """Fetch ``nsfw/<name>`` from ``deepghs/imgutils-models`` and open it as ONNX.

    Memoised so each model is downloaded and opened at most once per process.
    """
    return open_onnx_model(hf_hub_download(
        'deepghs/imgutils-models',
        f'nsfw/{name}'
    ))


def _image_preprocess(image, size: int = 224) -> np.ndarray:
    """Turn ``image`` into a batched model input.

    The image is loaded as RGB and resized to ``size`` x ``size`` with
    nearest-neighbour sampling. Pixel values are scaled to [0, 1].

    :param image: anything ``imgutils.data.load_image`` accepts.
    :param size: target side length in pixels.
    :return: float64 array of shape ``(1, size, size, 3)``.
    """
    image = load_image(image, mode='RGB').resize((size, size), Image.NEAREST)
    # [None, ...] prepends the batch axis expected by the ONNX graph.
    return (np.array(image) / 255.0)[None, ...]


# Class names zipped positionally with the model's ``dense_3`` output.
# NOTE(review): this order is assumed to match the exported models; confirm.
_LABELS = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']


def predict(image, model_name):
    """Score ``image`` with the selected model.

    :param image: image accepted by ``load_image`` (a PIL image from the UI),
        or ``None`` when nothing has been uploaded.
    :param model_name: one of ``_MODEL_NAMES``; an unknown name raises
        ``KeyError`` from the size lookup.
    :return: mapping of label -> float score, or ``None`` if ``image`` is None.
    """
    if image is None:
        # Button pressed before an image was provided: report no prediction
        # rather than handing None to load_image.
        return None
    input_ = _image_preprocess(image, _MODEL_TO_SIZE[model_name]).astype(np.float32)
    output_, = _onnx_model(model_name).run(['dense_3'], {'input_1': input_})
    return dict(zip(_LABELS, map(float, output_[0])))


if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr_input_image = gr.Image(type='pil', label='Original Image')
                gr_model = gr.Dropdown(_MODEL_NAMES, value=_DEFAULT_MODEL_NAME, label='Model')
                gr_btn_submit = gr.Button(value='Tagging', variant='primary')
            with gr.Column():
                gr_ratings = gr.Label(label='Ratings')

            gr_btn_submit.click(
                predict,
                inputs=[gr_input_image, gr_model],
                outputs=[gr_ratings],
            )

    # os.cpu_count() returns None when the count is undeterminable; fall back
    # to one worker so queue() always gets an int.
    # NOTE(review): passed positionally as in the original, which matches the
    # Gradio 3.x `queue(concurrency_count, ...)` signature — confirm the pinned
    # Gradio version before upgrading.
    demo.queue(os.cpu_count() or 1).launch()