Update app.py
Browse files
app.py
CHANGED
@@ -21,14 +21,22 @@ def load_image(img_path):
|
|
21 |
|
22 |
return d_img
|
23 |
|
|
|
24 |
def predict(image):
|
25 |
global model
|
26 |
trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
|
27 |
|
28 |
"""Run a single prediction on the model"""
|
29 |
img = load_image(image)
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
return "The image quality of the image is: {}".format(round(iq, 4))
|
34 |
|
@@ -40,8 +48,13 @@ parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_p
|
|
40 |
parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
|
41 |
config = parser.parse_args()
|
42 |
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
45 |
model.eval()
|
46 |
|
47 |
interface = gr.Interface(fn=predict, inputs="image", outputs="text")
|
|
|
21 |
|
22 |
return d_img
|
23 |
|
24 |
+
# import time
|
25 |
def predict(image):
|
26 |
global model
|
27 |
trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
|
28 |
|
29 |
"""Run a single prediction on the model"""
|
30 |
img = load_image(image)
|
31 |
+
# t = time.time()
|
32 |
+
if is_gpu:
|
33 |
+
img_tensor = trans(img).unsqueeze(0).cuda()
|
34 |
+
iq = model(img_tensor).cpu().detach().numpy().tolist()[0]
|
35 |
+
# print('GPU ', time.time() - t)
|
36 |
+
else:
|
37 |
+
img_tensor = trans(img).unsqueeze(0)
|
38 |
+
iq = model(img_tensor).detach().numpy().tolist()[0]
|
39 |
+
# print('CPU Time: ', time.time() - t)
|
40 |
|
41 |
return "The image quality of the image is: {}".format(round(iq, 4))
|
42 |
|
|
|
48 |
parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
|
49 |
config = parser.parse_args()
|
50 |
|
51 |
+
is_gpu = torch.cuda.is_available()
|
52 |
+
if is_gpu:
|
53 |
+
model = MoNet.MoNet(config, is_gpu=is_gpu).cuda()
|
54 |
+
model.load_state_dict(torch.load('best_model.pkl'))
|
55 |
+
else:
|
56 |
+
model = MoNet.MoNet(config, is_gpu=is_gpu)
|
57 |
+
model.load_state_dict(torch.load('best_model.pkl', map_location="cpu"))
|
58 |
model.eval()
|
59 |
|
60 |
interface = gr.Interface(fn=predict, inputs="image", outputs="text")
|