52Hz commited on
Commit
151ef24
·
verified ·
1 Parent(s): 5cd8ead

Update main_test_SRMNet.py

Browse files
Files changed (1) hide show
  1. main_test_SRMNet.py +20 -19
main_test_SRMNet.py CHANGED
@@ -23,7 +23,24 @@ def clean_folder(folder):
23
  shutil.rmtree(file_path)
24
  except Exception as e:
25
  print('Failed to delete %s. Reason: %s' % (file_path, e))
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def main():
28
  parser = argparse.ArgumentParser(description='Demo Image Denoising')
29
  parser.add_argument('--input_dir', default='test', type=str, help='Input images')
@@ -74,24 +91,8 @@ def main():
74
 
75
  f = os.path.splitext(os.path.split(file_)[-1])[0]
76
  save_img((os.path.join(out_dir, f + '.png')), restored)
77
- clean_folder(inp_dir)
78
-
79
- def save_img(filepath, img):
80
- cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
81
-
82
-
83
- def load_checkpoint(model, weights):
84
- checkpoint = torch.load(weights, map_location=torch.device('cpu'))
85
- try:
86
- model.load_state_dict(checkpoint["state_dict"])
87
- except:
88
- state_dict = checkpoint["state_dict"]
89
- new_state_dict = OrderedDict()
90
- for k, v in state_dict.items():
91
- name = k[7:] # remove `module.`
92
- new_state_dict[name] = v
93
- model.load_state_dict(new_state_dict)
94
-
95
 
96
  if __name__ == '__main__':
97
  main()
 
23
  shutil.rmtree(file_path)
24
  except Exception as e:
25
  print('Failed to delete %s. Reason: %s' % (file_path, e))
26
+
27
+ def save_img(filepath, img):
28
+ cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
29
+
30
+
31
+ def load_checkpoint(model, weights):
32
+ checkpoint = torch.load(weights, map_location=torch.device('cpu'))
33
+ try:
34
+ model.load_state_dict(checkpoint["state_dict"])
35
+ except:
36
+ state_dict = checkpoint["state_dict"]
37
+ new_state_dict = OrderedDict()
38
+ for k, v in state_dict.items():
39
+ name = k[7:] # remove `module.`
40
+ new_state_dict[name] = v
41
+ model.load_state_dict(new_state_dict)
42
+
43
+
44
  def main():
45
  parser = argparse.ArgumentParser(description='Demo Image Denoising')
46
  parser.add_argument('--input_dir', default='test', type=str, help='Input images')
 
91
 
92
  f = os.path.splitext(os.path.split(file_)[-1])[0]
93
  save_img((os.path.join(out_dir, f + '.png')), restored)
94
+ clean_folder(inp_dir)
95
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  if __name__ == '__main__':
98
  main()