Zevin2023 committed
Commit
63ccb58
1 Parent(s): a32d7be

Update app.py

Files changed (1)
  1. app.py +17 -4
app.py CHANGED
@@ -21,14 +21,22 @@ def load_image(img_path):

     return d_img

+# import time
 def predict(image):
     global model
     trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

     """Run a single prediction on the model"""
     img = load_image(image)
-    img_tensor = trans(img).unsqueeze(0).cuda()
-    iq = model(img_tensor).cpu().detach().numpy().tolist()[0]
+    # t = time.time()
+    if is_gpu:
+        img_tensor = trans(img).unsqueeze(0).cuda()
+        iq = model(img_tensor).cpu().detach().numpy().tolist()[0]
+        # print('GPU ', time.time() - t)
+    else:
+        img_tensor = trans(img).unsqueeze(0)
+        iq = model(img_tensor).detach().numpy().tolist()[0]
+        # print('CPU Time: ', time.time() - t)

     return "The image quality of the image is: {}".format(round(iq, 4))

@@ -40,8 +48,13 @@ parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_p
 parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
 config = parser.parse_args()

-model = MoNet.MoNet(config).cuda()
-model.load_state_dict(torch.load('best_model.pkl'))
+is_gpu = torch.cuda.is_available()
+if is_gpu:
+    model = MoNet.MoNet(config, is_gpu=is_gpu).cuda()
+    model.load_state_dict(torch.load('best_model.pkl'))
+else:
+    model = MoNet.MoNet(config, is_gpu=is_gpu)
+    model.load_state_dict(torch.load('best_model.pkl', map_location="cpu"))
 model.eval()

 interface = gr.Interface(fn=predict, inputs="image", outputs="text")
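
The same CPU/GPU fallback that this commit adds can also be expressed once with a torch.device object instead of duplicating the tensor and model setup in both branches. The sketch below is only an illustration of that pattern, not part of the commit; MoNet, config, load_image, Normalize, and ToTensor are assumed to be the objects already defined or imported in app.py.

import torch
import torchvision

# Illustrative device-agnostic variant of the pattern introduced in this commit.
# MoNet, config, load_image, Normalize, and ToTensor are assumed to come from app.py.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MoNet.MoNet(config, is_gpu=(device.type == "cuda")).to(device)
model.load_state_dict(torch.load("best_model.pkl", map_location=device))
model.eval()

trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

def predict(image):
    """Run a single prediction on whichever device is available."""
    img = load_image(image)
    img_tensor = trans(img).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        iq = model(img_tensor).cpu().numpy().tolist()[0]
    return "The image quality of the image is: {}".format(round(iq, 4))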