"""Gradio demo that scores image quality with a pretrained MoNet model."""
import argparse
import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision

from models import monet as MoNet
from utils.dataset.process import ToTensor, Normalize

# The checkpoint is expected at this path. It can be fetched with:
#   wget -O ./checkpoints/best_model.pkl \
#     https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl
CHECKPOINT_PATH = './checkpoints/best_model.pkl'

# Preprocessing applied to the CHW float array produced by `load_image`.
# Built once at import instead of on every request.
_TRANSFORM = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])


def load_image(img_path):
    """Resize an RGB image to 224x224 and return it as a CHW float32 array.

    Args:
        img_path: Despite the name, this is the image itself as an array-like
            in RGB channel order (e.g. the numpy array Gradio passes to
            ``predict``), not a filesystem path. 2-D (grayscale) input is
            replicated to three channels and a fourth (alpha) channel is dropped.

    Returns:
        ``np.ndarray`` of shape ``(3, 224, 224)`` and dtype float32, with pixel
        values divided by 255 (assumes 8-bit input — TODO confirm for other
        dtypes).
    """
    d_img = np.asarray(img_path)
    if d_img.ndim == 2:
        # Grayscale: replicate to 3 channels so the model sees RGB input.
        d_img = np.stack([d_img] * 3, axis=-1)
    elif d_img.shape[-1] == 4:
        # RGBA: drop alpha. Make contiguous because cv2 rejects some views.
        d_img = np.ascontiguousarray(d_img[..., :3])
    # cv2.resize interpolates each channel independently, so no RGB<->BGR
    # round trip is needed around it.
    d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC)
    d_img = d_img.astype('float32') / 255
    # HWC -> CHW, the layout PyTorch models consume.
    return np.transpose(d_img, (2, 0, 1))


def _build_parser():
    """Return the argument parser holding MoNet's model options."""
    parser = argparse.ArgumentParser()
    # model related
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
    return parser


@functools.lru_cache(maxsize=None)
def _get_model():
    """Build MoNet on the GPU and load the checkpoint, once per process."""
    # parse_known_args ignores unrelated argv entries (e.g. flags added by a
    # launcher) instead of exiting the process the way parse_args would.
    config, _ = _build_parser().parse_known_args()
    model = MoNet.MoNet(config).cuda()
    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    # Loading to CPU first avoids holding a second copy of the weights on the
    # GPU; load_state_dict copies them into the CUDA parameters.
    state = torch.load(CHECKPOINT_PATH, map_location='cpu')
    model.load_state_dict(state)
    model.eval()
    return model


def predict(image):
    """Run a single prediction on the model.

    Args:
        image: RGB image array as supplied by the Gradio ``"image"`` input, or
            ``None`` when nothing was uploaded.

    Returns:
        A human-readable string with the predicted quality score rounded to
        four decimals, or a prompt to upload an image when ``image`` is None.
    """
    if image is None:
        return "Please upload an image."
    model = _get_model()
    img = load_image(image)
    img_tensor = _TRANSFORM(img).unsqueeze(0).cuda()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(img_tensor)
    # Take the first scalar whether the model returns shape (1,) or (1, 1).
    iq = float(output.cpu().numpy().reshape(-1)[0])
    return "The image quality of the image is: {}".format(round(iq, 4))


interface = gr.Interface(fn=predict, inputs="image", outputs="text")

if __name__ == '__main__':
    # Load the model before serving so a missing checkpoint fails at startup
    # and concurrent first requests cannot trigger duplicate loads.
    _get_model()
    interface.launch(server_name='127.0.0.1', server_port=8088)