|
import argparse
import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision

from models import monet as MoNet
from utils.dataset.process import ToTensor, Normalize
|
|
|
def load_image(img_path):
    """Resize an RGB image to 224x224 and convert it to a CHW float32 array.

    Despite its name, ``img_path`` is not a file path: it is the image data
    itself (anything ``np.asarray`` accepts, e.g. the array/PIL image produced
    by the Gradio ``"image"`` input). Pixel values are assumed to be in
    [0, 255] (uint8) -- TODO confirm for other input dtypes.

    Grayscale (2-D or single-channel) input is replicated to three channels
    and an alpha channel (4-channel input) is dropped, so the result always
    has three channels.

    Args:
        img_path: RGB image data of shape ``(H, W)``, ``(H, W, 1)``,
            ``(H, W, 3)`` or ``(H, W, 4)``.

    Returns:
        ``np.ndarray`` of shape ``(3, 224, 224)``, dtype float32, with values
        divided by 255.

    Raises:
        ValueError: if no image is given or its shape is not one of the above.
    """
    d_img = np.asarray(img_path)

    # Normalise the channel layout to (H, W, 3) RGB before resizing.
    if d_img.ndim == 3 and d_img.shape[2] == 1:
        d_img = d_img[:, :, 0]
    if d_img.ndim == 2:
        d_img = np.repeat(d_img[:, :, None], 3, axis=2)
    elif d_img.ndim == 3 and d_img.shape[2] == 4:
        # Same effect as the former COLOR_RGB2BGR/BGR2RGB pair: alpha dropped.
        d_img = d_img[:, :, :3]
    if d_img.ndim != 3 or d_img.shape[2] != 3:
        raise ValueError(
            "Expected an (H, W[, C]) image with 1, 3 or 4 channels, got shape {}".format(d_img.shape)
        )

    # cv2.resize interpolates each channel independently, so it can operate on
    # RGB directly; converting to BGR and back around it was a no-op.
    d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC)

    d_img = d_img.astype(np.float32) / 255

    # HWC -> CHW, the layout expected by the torch model.
    return np.transpose(d_img, (2, 0, 1))
|
|
|
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build MoNet from command-line options and load its checkpoint, once per process.

    Returns:
        ``(model, device)``: the model in eval mode, placed on ``device``
        (CUDA when available, otherwise CPU).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3,
                        help='The number of the MAL modules.')
    # parse_known_args: ignore unrelated argv entries (e.g. added by a launcher)
    # instead of raising SystemExit inside a request handler.
    config, _ = parser.parse_known_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MoNet.MoNet(config).to(device)
    # torch.load unpickles the file: only point this at a trusted checkpoint.
    model.load_state_dict(torch.load('./checkpoints/best_model.pkl', map_location=device))
    model.eval()
    return model, device


def predict(image):
    """Run a single prediction on the model.

    The model is built and its checkpoint loaded on the first call only;
    later calls reuse it.

    Args:
        image: image data from the Gradio ``"image"`` input, in the form
            accepted by ``load_image``.

    Returns:
        A message containing the predicted quality score rounded to 4 decimals.
    """
    model, device = _load_model()
    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

    img = load_image(image)
    img_tensor = trans(img).unsqueeze(0).to(device)

    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        output = model(img_tensor)
    # Assumes the model returns one score per batch item -- TODO confirm shape.
    iq = output.detach().cpu().numpy().tolist()[0]

    return "The image quality of the image is: {}".format(round(iq, 4))
|
|
|
|
|
|
|
# Gradio UI: one image in, the predict() message out.
interface = gr.Interface(fn=predict, inputs="image", outputs="text")

if __name__ == "__main__":
    # Start the (blocking) server only when run as a script, so importing this
    # module (e.g. to reuse predict/load_image) has no server side effect.
    interface.launch(server_name='127.0.0.1', server_port=8088)