File size: 1,671 Bytes
07e1105 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
import torchvision
import cv2
import numpy as np
from models import monet as MoNet
import argparse
from utils.dataset.process import ToTensor, Normalize
import gradio as gr
def load_image(img_path):
    """Turn an in-memory RGB image into a 224x224 channel-first float32 array.

    Despite its name, ``img_path`` is not a filesystem path: it is passed
    straight to ``np.asarray`` and treated as RGB pixel data (presumably the
    array Gradio hands to ``predict`` -- confirm against the caller).

    Args:
        img_path: Array-like RGB image, height x width x channels.

    Returns:
        A ``float32`` array of shape ``(3, 224, 224)`` in RGB channel order,
        with pixel values divided by 255.
    """
    target_size = (224, 224)

    # The resize is done in OpenCV's BGR channel order and the result is
    # converted back to RGB afterwards.
    bgr = cv2.cvtColor(np.asarray(img_path), cv2.COLOR_RGB2BGR)
    bgr_resized = cv2.resize(bgr, target_size, interpolation=cv2.INTER_CUBIC)
    rgb = cv2.cvtColor(bgr_resized, cv2.COLOR_BGR2RGB)

    # Scale by 1/255 (maps 8-bit input to [0, 1]), then go from HWC to CHW.
    scaled = np.array(rgb).astype('float32') / 255
    return np.transpose(scaled, (2, 0, 1))
# Holds the lazily-built ``(model, device)`` pair so the checkpoint is read
# from disk once per process instead of once per request.
_MODEL_CACHE = {}


def _get_model():
    """Build MoNet, load its checkpoint, and return ``(model, device)``; cached after the first call."""
    cached = _MODEL_CACHE.get('entry')
    if cached is not None:
        return cached

    parser = argparse.ArgumentParser()
    # model related
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3,
                        help='The number of the MAL modules.')
    # parse_known_args still honours --backbone / --mal_num from the command
    # line, but no longer aborts the request with SystemExit when the process
    # was started with unrelated flags.
    config, _unknown = parser.parse_known_args()

    # Use the GPU when present (as before); otherwise fall back to the CPU
    # instead of crashing on .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MoNet.MoNet(config).to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load('./checkpoints/best_model.pkl', map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Store both values under a single key so a concurrent reader can never
    # observe a half-populated cache.
    _MODEL_CACHE['entry'] = (model, device)
    return _MODEL_CACHE['entry']


def predict(image):
    """Run a single prediction on the model.

    Args:
        image: The image supplied by the Gradio ``"image"`` input; it is
            forwarded unchanged to ``load_image``.

    Returns:
        A human-readable string containing the predicted quality score
        rounded to four decimal places.
    """
    model, device = _get_model()
    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

    img = load_image(image)
    # unsqueeze(0) adds the batch dimension expected by the model.
    img_tensor = trans(img).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # Assumes the model returns one score per batch item -- TODO confirm.
        iq = model(img_tensor).cpu().detach().numpy().tolist()[0]
    return "The image quality of the image is: {}".format(round(iq, 4))
# The checkpoint read from ./checkpoints/best_model.pkl can be fetched with:
#   wget -O ./checkpoints/best_model.pkl https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl

# Expose `predict` through a Gradio UI: one image in, one text string out.
interface = gr.Interface(fn=predict, inputs="image", outputs="text")
# Serve on the loopback address only, on port 8088.
interface.launch(server_name='127.0.0.1', server_port=8088)