File size: 1,671 Bytes
07e1105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision

from models import monet as MoNet
from utils.dataset.process import ToTensor, Normalize

def load_image(img_path, size=(224, 224)):
    """Convert an RGB image into a scaled CHW float32 array for the model.

    Despite its name, ``img_path`` is the decoded image itself (anything
    ``np.asarray`` accepts, e.g. the array Gradio hands to ``predict``),
    not a file path.

    Args:
        img_path: RGB image of shape (H, W, 3). Grayscale (H, W) / (H, W, 1)
            and RGBA (H, W, 4) inputs are converted to 3-channel RGB.
            Presumably uint8, since values are divided by 255 — confirm for
            other dtypes.
        size: Target ``(width, height)`` passed to ``cv2.resize``;
            defaults to 224x224.

    Returns:
        float32 array of shape (3, height, width), values divided by 255.
    """
    img = np.ascontiguousarray(img_path)
    # Promote single-channel and alpha images to plain RGB so the rest of
    # the pipeline always sees three channels.
    if img.ndim == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    # Resizing treats each channel independently, so the image can stay in
    # RGB order; an RGB->BGR->RGB round trip around it would be pure overhead.
    img = cv2.resize(img, size, interpolation=cv2.INTER_CUBIC)
    img = img.astype(np.float32) / 255
    # HWC -> CHW channel-first layout.
    return np.transpose(img, (2, 0, 1))

def _build_config():
    """Parse the MoNet model options (backbone, number of MAL modules) from argv.

    ``parse_known_args`` is used so that unrelated command-line flags do not
    make the web app exit with an argparse usage error.
    """
    parser = argparse.ArgumentParser()
    # model related
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
    config, _unknown = parser.parse_known_args()
    return config


@functools.lru_cache(maxsize=1)
def _get_model_and_transform():
    """Build MoNet, load its checkpoint and return ``(model, transform, device)``.

    Cached so model construction and weight loading happen once per process
    rather than on every Gradio request. Uses CUDA when available and falls
    back to CPU otherwise.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MoNet.MoNet(_build_config()).to(device)
    # torch.load unpickles the file: only load trusted checkpoints.
    state_dict = torch.load('./checkpoints/best_model.pkl', map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
    return model, trans, device


def predict(image):
    """Run a single quality prediction and format it for display.

    Args:
        image: RGB image from the Gradio ``"image"`` input, or ``None`` when
            the user submits without uploading anything.

    Returns:
        A sentence containing the predicted quality score rounded to 4 digits,
        or a prompt to upload an image when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image first."

    model, trans, device = _get_model_and_transform()

    img = load_image(image)
    img_tensor = trans(img).unsqueeze(0).to(device)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        output = model(img_tensor)
    # assumes one scalar score per image (output shape (1,) or (1, 1)) — TODO confirm
    iq = output.reshape(-1)[0].item()

    return "The image quality of the image is: {}".format(round(iq, 4))

# The model checkpoint expected at ./checkpoints/best_model.pkl can be downloaded with:
#   wget -O ./checkpoints/best_model.pkl https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl

# Web UI: one image in, the quality-score sentence from predict() out.
interface = gr.Interface(fn=predict, inputs="image", outputs="text")

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    # 127.0.0.1 keeps the app local; use "0.0.0.0" to expose it on the network.
    interface.launch(server_name='127.0.0.1', server_port=8088)