from transformers import pipeline
from imgutils.data import rgb_encode, load_image
from onnx_ import _open_onnx_model
from PIL import Image
import gradio as gr
import numpy as np
import os
import requests
import torch
import json


def _img_encode(image, size=(384, 384), normalize=(0.5, 0.5)):
    """Resize ``image`` and encode it as a normalized float32 CHW array.

    Args:
        image: PIL image to encode; it is resized with bilinear filtering.
        size: target ``(width, height)`` passed to ``Image.resize``.
        normalize: ``(mean, std)`` applied as ``(data - mean) / std``. Each
            value may be a scalar (applied to every channel) or a per-channel
            sequence; pass ``None`` to skip normalization.

    Returns:
        ``np.float32`` array in channel-first layout, as produced by
        ``rgb_encode(..., order_='CHW')``.
    """
    image = image.resize(size, Image.BILINEAR)
    data = rgb_encode(image, order_='CHW')
    if normalize is not None:
        mean_, std_ = normalize
        # Reshape to (C, 1, 1) (or (1, 1, 1) for scalars) so the statistics
        # broadcast over the H and W axes of the CHW array.
        mean = np.asarray([mean_]).reshape((-1, 1, 1))
        std = np.asarray([std_]).reshape((-1, 1, 1))
        data = (data - mean) / std
    return data.astype(np.float32)


def _download(url, path, timeout=60):
    """Download ``url`` to ``path`` without ever leaving a partial file behind.

    The response is streamed into ``path + ".part"`` and renamed onto ``path``
    only after it has been written completely. A failed or interrupted download
    therefore never produces a truncated file that a later run would mistake
    for a finished one.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as resp:
        # Without this check an HTML error page would be saved as the model.
        resp.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(tmp_path, path)


nsfw_tf = pipeline(model="carbon225/vit-base-patch16-224-hentai")

_MODEL_BASE_URL = (
    "https://huggingface.co/deepghs/anime_rating/resolve/main/caformer_s36_plus/"
)
# Local filename -> remote URL. Each file is checked separately, so a run that
# previously fetched only one of them still downloads the other.
_MODEL_FILES = {
    "timm.onnx": _MODEL_BASE_URL + "model.onnx",
    "timmcfg.json": _MODEL_BASE_URL + "meta.json",
}

_missing = [name for name in _MODEL_FILES if not os.path.exists(name)]
if _missing:
    for _name in _missing:
        _download(_MODEL_FILES[_name], _name)
else:
    print("Model already exists, skipping redownload")

with open("timmcfg.json", encoding="utf-8") as file:
    tm_cfg = json.load(file)
nsfw_tm = _open_onnx_model("timm.onnx")

# Weight that each model's top label adds to the NSFW score. A strictly
# positive total flags the image. Labels not listed here add 0.
# NOTE(review): the ONNX model's "r15" counts as much as "r18" (+2), while the
# pipeline's "suggestive" counts only +1. Confirm this asymmetry is intended.
_TM_LABEL_WEIGHTS = {"safe": -1, "r15": 2, "r18": 2}
_TF_LABEL_WEIGHTS = {"safe": -1, "suggestive": 1, "r18": 2}


def launch(img):
    """Rate ``img`` with both classifiers and return True if it is judged NSFW.

    The top label of the ONNX rating model and the top label of the
    transformers pipeline are each mapped to a weight through the label tables
    defined just above this function. The image is flagged when the summed
    weight is greater than zero. The full ranked ONNX scores and the pipeline
    label are printed for debugging.

    Raises:
        ValueError: if no image was supplied (``img`` is ``None``).
    """
    if img is None:
        raise ValueError("No image provided")
    img = img.convert('RGB')

    tm_image = load_image(img, mode='RGB')
    # [None, ...] adds a leading batch axis of size 1.
    tm_input_ = _img_encode(tm_image, size=(256, 256))[None, ...]
    tm_items, = nsfw_tm.run(['output'], {'input': tm_input_})
    # Rank (label, score) pairs once; used for both the decision and the log.
    tm_scores = sorted(
        zip(tm_cfg["labels"], (x.item() for x in tm_items[0])),
        key=lambda pair: pair[1],
        reverse=True,
    )
    tm_output = tm_scores[0][0]

    tf_output = nsfw_tf(img)[0]["label"]

    print(tm_scores, tf_output)
    weight = _TM_LABEL_WEIGHTS.get(tm_output, 0) + _TF_LABEL_WEIGHTS.get(tf_output, 0)
    return weight > 0


app = gr.Interface(fn=launch, inputs="pil", outputs="text")
app.launch()