Zevin2023 commited on
Commit
a32d7be
1 Parent(s): 51e71e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -22,17 +22,7 @@ def load_image(img_path):
22
  return d_img
23
 
24
  def predict(image):
25
- parser = argparse.ArgumentParser()
26
- # model related
27
- parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
28
- help='The backbone for MoNet.')
29
- parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
30
- config = parser.parse_args()
31
-
32
- model = MoNet.MoNet(config).cuda()
33
- model.load_state_dict(torch.load('best_model.pkl'))
34
- model.eval()
35
-
36
  trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
37
 
38
  """Run a single prediction on the model"""
@@ -42,7 +32,17 @@ def predict(image):
42
 
43
  return "The image quality of the image is: {}".format(round(iq, 4))
44
 
45
- os.system("wget -O best_model.pkl https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl")
 
 
 
 
 
 
 
 
 
 
46
 
47
  interface = gr.Interface(fn=predict, inputs="image", outputs="text")
48
  interface.launch()
 
22
  return d_img
23
 
24
  def predict(image):
25
+ global model
 
 
 
 
 
 
 
 
 
 
26
  trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
27
 
28
  """Run a single prediction on the model"""
 
32
 
33
  return "The image quality of the image is: {}".format(round(iq, 4))
34
 
35
+ # os.system("wget https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl")
36
+
37
+ parser = argparse.ArgumentParser()
38
+ # model related
39
+ parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224', help='The backbone for MoNet.')
40
+ parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
41
+ config = parser.parse_args()
42
+
43
+ model = MoNet.MoNet(config).cuda()
44
+ model.load_state_dict(torch.load('best_model.pkl'))
45
+ model.eval()
46
 
47
  interface = gr.Interface(fn=predict, inputs="image", outputs="text")
48
  interface.launch()