GMC-IQA / app.py
import torch
import torchvision
import cv2
import numpy as np
from models import monet as MoNet
import argparse
from utils.dataset.process import ToTensor, Normalize
import gradio as gr

def load_image(img_path):
    # Despite the parameter name, Gradio passes the uploaded image itself
    # (a PIL image), so it is converted to an array rather than read from disk.
    d_img = cv2.cvtColor(np.asarray(img_path), cv2.COLOR_RGB2BGR)
    # d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC)
    d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
    d_img = np.array(d_img).astype('float32') / 255
    d_img = np.transpose(d_img, (2, 0, 1))  # HWC -> CHW
    return d_img

def predict(image):
    """Run a single prediction on the model."""
    # Model-related configuration; the defaults match the released checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
    config = parser.parse_args()

    # Build MoNet, load the trained weights, and switch to evaluation mode.
    model = MoNet.MoNet(config).cuda()
    model.load_state_dict(torch.load('./checkpoints/best_model.pkl'))
    model.eval()

    # Preprocess the input and run a forward pass.
    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
    img = load_image(image)
    img_tensor = trans(img).unsqueeze(0).cuda()
    iq = model(img_tensor).cpu().detach().numpy().tolist()[0]
    return "The predicted image quality score is: {}".format(round(iq, 4))

# os.system("wget -O ./checkpoints/best_model.pkl https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl")

interface = gr.Interface(fn=predict, inputs="image", outputs="text")
interface.launch(server_name='127.0.0.1', server_port=8088)
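
# Usage note (a sketch, assuming the checkpoint has been downloaded to
# ./checkpoints/best_model.pkl, e.g. via the wget command above, and that a
# CUDA-capable GPU is available for the .cuda() calls):
#
#   python app.py
#
# then open http://127.0.0.1:8088 in a browser to submit an image and read
# back the predicted quality score.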