artwork-scorer / app.py
Muinez's picture
Update app.py
9e7770f
raw
history blame
1.13 kB
import gradio as gr
import torch
from transformers import AutoImageProcessor, ConvNextV2ForImageClassification
from transformers import AutoModelForImageClassification
from torch import nn
import dbimutils as utils
# Use the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Both the preprocessing config and the weights come from the same Hub repo.
_REPO_ID = "Muinez/artwork-scorer"
image_processor = AutoImageProcessor.from_pretrained(_REPO_ID)
# problem_type marks the outputs as independent labels (one sigmoid per logit
# in `predict`), not mutually exclusive classes.
model = AutoModelForImageClassification.from_pretrained(
    _REPO_ID,
    problem_type="multi_label_classification",
).to(DEVICE)
def predict(img):
    """Score a single artwork image with the fine-tuned classifier.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``(score, view_count_ratio)`` tuple of Python floats. Each value is
        the sigmoid of the corresponding model logit, so it lies in ``[0, 1]``.

    Raises:
        gr.Error: If no image was provided.
    """
    # Gradio hands us None for an empty image box. Show a readable message in
    # the UI instead of surfacing an exception from inside preprocessing.
    if img is None:
        raise gr.Error("Please upload an image to score.")
    # NOTE(review): dbimutils.preprocess_image is project-local; assumed to
    # return something image_processor accepts (PIL image / array) — confirm.
    file = utils.preprocess_image(img)
    encoded = image_processor(file, return_tensors="pt").to(DEVICE)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded).logits.cpu()
    # The model is loaded as multi-label, so each output gets its own sigmoid
    # rather than a softmax across outputs.
    probs = torch.sigmoid(logits)[0]
    return probs[0].item(), probs[1].item()
# Build the web UI around `predict` and start serving it. Keep a handle on the
# Interface (conventionally named `demo` in Gradio apps) before launching.
demo = gr.Interface(
    title="Artwork scorer",
    # `description` is rendered as Markdown, where a single "\n" collapses
    # into a space. Blank lines keep each note on its own line.
    description=(
        "Predicts a score (0-1) for an artwork.\n\n"
        "Could be wrong!!!\n\n"
        "Does not work very well with NSFW images, since it was not trained on them."
    ),
    fn=predict,
    allow_flagging="never",
    inputs=gr.Image(type="pil"),
    # Order must match the (score, view_count_ratio) tuple returned by predict.
    outputs=[
        gr.Number(label="Score"),
        gr.Number(label="View count ratio (probably useless)"),
    ],
)
demo.launch()