Nicolò commited on
Commit
c5ab13c
1 Parent(s): 6bd8735

update detector

Browse files
Files changed (1) hide show
  1. gan_vs_real_detector.py +11 -5
gan_vs_real_detector.py CHANGED
@@ -5,6 +5,9 @@ import numpy as np
5
  import torch
6
 
7
  torch.manual_seed(21)
 
 
 
8
  import torch.multiprocessing
9
 
10
  torch.multiprocessing.set_sharing_strategy('file_system')
@@ -18,7 +21,9 @@ from PIL import Image
18
  class Detector:
19
  def __init__(self):
20
 
21
- self.weights_path_list = [os.path.join('/nas/home/nbonettini/projects/StyleGAN3-detection/weights', f'method_{x}.pth') for x in 'ABCDE']
 
 
22
 
23
  # GPU configuration if available
24
  self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
@@ -30,6 +35,7 @@ class Detector:
30
  net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
31
  print('Loading model...')
32
  state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
 
33
  if 'net' not in state_tmp.keys():
34
  state = OrderedDict({'net': OrderedDict()})
35
  [state['net'].update({'model.{}'.format(k): v}) for k, v in state_tmp.items()]
@@ -37,7 +43,7 @@ class Detector:
37
  state = state_tmp
38
  incomp_keys = net.load_state_dict(state['net'], strict=True)
39
  print(incomp_keys)
40
- print('Model loaded!')
41
 
42
  self.nets += [net]
43
 
@@ -85,7 +91,7 @@ class Detector:
85
  stride_1 = (((img.shape[1] - 128) // 10 + 7) // 8) * 8
86
  pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
87
  patches = pe.extract(img)
88
- patch_list = list(patches.reshape((patches.shape[0]*patches.shape[1], 128, 128, 3)))
89
 
90
  # Normalization
91
  transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
@@ -101,14 +107,14 @@ class Detector:
101
  img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting))
102
 
103
  # final score is the average among the 5 scores returned by the detectors
 
104
  img_score = np.mean(img_net_scores)
105
 
106
  return img_score
107
 
108
 
109
  def main():
110
-
111
- # img_path on fermi:
112
  img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'
113
 
114
  detector = Detector()
 
5
  import torch
6
 
7
  torch.manual_seed(21)
8
+ import random
9
+
10
+ random.seed(21)
11
  import torch.multiprocessing
12
 
13
  torch.multiprocessing.set_sharing_strategy('file_system')
 
21
  class Detector:
22
  def __init__(self):
23
 
24
+ # You need to download the weight.zip file from here https://www.dropbox.com/s/g1z2u8wl6srjh6v/weigths.zip?dl=0
25
+ # and uncompress it into the main folder.
26
+ self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']
27
 
28
  # GPU configuration if available
29
  self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
 
35
  net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
36
  print('Loading model...')
37
  state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
38
+
39
  if 'net' not in state_tmp.keys():
40
  state = OrderedDict({'net': OrderedDict()})
41
  [state['net'].update({'model.{}'.format(k): v}) for k, v in state_tmp.items()]
 
43
  state = state_tmp
44
  incomp_keys = net.load_state_dict(state['net'], strict=True)
45
  print(incomp_keys)
46
+ print('Model loaded!\n')
47
 
48
  self.nets += [net]
49
 
 
91
  stride_1 = (((img.shape[1] - 128) // 10 + 7) // 8) * 8
92
  pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
93
  patches = pe.extract(img)
94
+ patch_list = list(patches.reshape((patches.shape[0] * patches.shape[1], 128, 128, 3)))
95
 
96
  # Normalization
97
  transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
 
107
  img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting))
108
 
109
  # final score is the average among the 5 scores returned by the detectors
110
+ print(img_net_scores)
111
  img_score = np.mean(img_net_scores)
112
 
113
  return img_score
114
 
115
 
116
  def main():
117
+ # debug img_path on fermi:
 
118
  img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'
119
 
120
  detector = Detector()