Nicolò commited on
Commit
f6b58ff
1 Parent(s): 4988f80

draft of image prediction

Browse files
Files changed (1) hide show
  1. gan_vs_real_detector.py +145 -0
gan_vs_real_detector.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ torch.manual_seed(21)
8
+ import torch.multiprocessing
9
+
10
+ torch.multiprocessing.set_sharing_strategy('file_system')
11
+ import albumentations as A
12
+ import albumentations.pytorch as Ap
13
+ from utils import architectures
14
+ from PIL import Image
15
+
16
+
17
+ class Detector:
18
+ def __init__(self):
19
+
20
+ # model directory and path for detector A
21
+ # model_A_dir = 'weights/method_A/net-EfficientNetB4_lr-0.001_img_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\', ' \
22
+ # '\'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\', \'resize\', \'jpeg\']' \
23
+ # '_img_aug_p-0.5_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
24
+ #
25
+ # # model directory and path for detector B
26
+ # model_B_dir = 'weights/method_B/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\', ' \
27
+ # '\'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
28
+ # '_aug_p-0.5_jpeg_aug_p-0.7_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
29
+ #
30
+ # # model directory and path for detector C
31
+ # model_C_dir = 'weights/method_C/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
32
+ # ' \'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
33
+ # '_aug_p-0.5_jpeg_aug_p-0_patch_size-128_patch_number-5_batch_size-50_num_classes-2'
34
+ #
35
+ # # model directory and path for detector D
36
+ # model_D_dir = 'weights/method_D/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
37
+ # '\'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
38
+ # '_aug_p-0.5_jpeg_aug_p-0_patch_size-128_patch_number-10_batch_size-25_num_classes-2'
39
+ #
40
+ # # model directory for detector E
41
+ # model_E_dir = 'weights/method_E/net-EfficientNetB4_lr-0.001_aug-[\'flip\', \'rotate\', \'clahe\', \'blur\',' \
42
+ # ' \'crop&resize\', \'brightness&contrast\', \'jitter\', \'downscale\', \'hsv\']' \
43
+ # '_aug_p-0.5_jpeg_aug_p-0.7_patch_size-128_patch_number-1_batch_size-250_num_classes-2'
44
+
45
+ self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']
46
+ # self.model_path = os.path.join(model_dir, 'bestval.pth')
47
+
48
+ # GPU configuration if available
49
+ self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
50
+
51
+ self.nets = []
52
+ for i in range(5):
53
+ # Instantiate and load network
54
+ network_class = getattr(architectures, 'EfficientNetB4')
55
+ net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
56
+ print('Loading model...')
57
+ state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
58
+ if 'net' not in state_tmp.keys():
59
+ state = OrderedDict({'net': OrderedDict()})
60
+ [state['net'].update({'model.{}'.format(k): v}) for k, v in state_tmp.items()]
61
+ else:
62
+ state = state_tmp
63
+ incomp_keys = net.load_state_dict(state['net'], strict=True)
64
+ print(incomp_keys)
65
+ print('Model loaded!')
66
+
67
+ self.nets += [net]
68
+
69
+ net_normalizer = net.get_normalizer() # pick normalizer from last network
70
+ transform = [
71
+ A.Normalize(mean=net_normalizer.mean, std=net_normalizer.std),
72
+ Ap.transforms.ToTensorV2()
73
+ ]
74
+ self.trans = A.Compose(transform)
75
+
76
+ self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
77
+
78
+ self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
79
+
80
+ def synth_real_detector(self, img_path: str, n_patch: int = 50):
81
+
82
+ # Load image:
83
+ img = np.asarray(Image.open(img_path))
84
+
85
+ # Optout if image is non conforming
86
+ if img.shape == ():
87
+ print('{} None dimension'.format(img_path))
88
+ return None
89
+ if img.shape[0] < 128 or img.shape[1] < 128:
90
+ print('Too small image')
91
+ return None
92
+ if img.ndim != 3:
93
+ print('RGB images only')
94
+ return None
95
+ if img.shape[2] > 3:
96
+ print('Omitting alpha channel')
97
+ img = img[:, :, :3]
98
+
99
+ # Extract test_N random patches from image:
100
+ patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)]
101
+
102
+ # Normalization
103
+ transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
104
+
105
+ # Compute scores
106
+ transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device)
107
+ with torch.no_grad():
108
+ patch_scores = self.net(transf_patch_tensor)
109
+ softmax_scores = torch.softmax(patch_scores, dim=1)
110
+ predictions = torch.argmax(softmax_scores, dim=1)
111
+
112
+ # Majority voting on patches
113
+ if sum(predictions) > len(predictions) // 2:
114
+ majority_voting = 1
115
+ else:
116
+ majority_voting = 0
117
+
118
+ # get an output score from softmax scores:
119
+ # LLR < 0: real
120
+ # LLR > 0: synthetic
121
+
122
+ sign_predictions = majority_voting * 2 - 1
123
+ # select only the scores associated with the estimated class (by majority voting)
124
+ softmax_scores = softmax_scores[:, majority_voting]
125
+ normalized_prediction = torch.max(softmax_scores).item() * sign_predictions
126
+
127
+ return normalized_prediction
128
+
129
+
130
+ def main():
131
+ # img_path
132
+ img_path = "/nas/public/exchange/semafor/eval1/stylegan2/100k-generated-images/car-512x384_cropped/stylegan2-" \
133
+ "config-f-psi-0.5/097000/097001.png"
134
+
135
+ # number of random patches to extract from images
136
+ test_N = 50
137
+
138
+ detector = Detector()
139
+ detector.synth_real_detector(img_path, test_N)
140
+
141
+ return 0
142
+
143
+
144
+ if __name__ == '__main__':
145
+ main()