File size: 4,759 Bytes
f6b58ff
 
 
 
 
 
 
c5ab13c
 
 
f6b58ff
 
 
 
 
 
6bd8735
f6b58ff
 
5105c39
 
f6b58ff
 
 
 
c5ab13c
 
 
f6b58ff
 
 
 
 
5105c39
f6b58ff
 
 
5105c39
f6b58ff
c5ab13c
f6b58ff
 
 
 
 
 
 
c5ab13c
f6b58ff
 
 
 
 
 
 
 
 
 
 
 
6bd8735
f6b58ff
 
 
 
6bd8735
f6b58ff
 
 
 
 
 
 
 
 
 
 
 
 
5105c39
6bd8735
 
f6b58ff
6bd8735
f6b58ff
6bd8735
 
 
 
f6b58ff
6bd8735
 
 
 
 
 
c5ab13c
f6b58ff
6bd8735
 
f6b58ff
6bd8735
 
 
 
 
f6b58ff
6bd8735
 
 
 
 
 
 
 
f6b58ff
 
 
5105c39
5505fbb
5105c39
 
 
f6b58ff
 
6bd8735
 
 
f6b58ff
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
from collections import OrderedDict

import numpy as np
import torch

# Fix the torch RNG up front so everything downstream (e.g. any torch-side
# randomness) is reproducible across runs.
torch.manual_seed(21)
import random

# Seed Python's RNG as well — presumably for reproducible random crops in
# albumentations, which draws from it. TODO(review): confirm.
random.seed(21)
import torch.multiprocessing

# NOTE(review): 'file_system' sharing is the usual workaround for
# "too many open files" errors with torch multiprocessing workers.
torch.multiprocessing.set_sharing_strategy('file_system')
import albumentations as A
import albumentations.pytorch as Ap
from utils import architectures
from utils.python_patch_extractor.PatchExtractor import PatchExtractor
from PIL import Image

import argparse


class Detector:
    """Ensemble synthetic-vs-real image detector.

    Loads five EfficientNetB4 binary classifiers (methods A-E) from
    ``weights/method_{A..E}.pth`` and averages their per-image scores in
    :meth:`synth_real_detector`. Higher scores indicate "synthetic".
    """

    def __init__(self):

        # You need to download the weight.zip file from here https://www.dropbox.com/s/g1z2u8wl6srjh6v/weigths.zip?dl=0
        # and uncompress it into the main folder.
        self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']

        # GPU configuration if available
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

        self.nets = []
        for i, l in enumerate('ABCDE'):
            # Instantiate and load network
            network_class = getattr(architectures, 'EfficientNetB4')
            net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
            print(f'Loading model {l}...')
            state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')

            # Checkpoints saved as a bare state_dict are wrapped under a 'net'
            # key, with every parameter name prefixed by 'model.' so it matches
            # the attribute layout of the network class.
            if 'net' not in state_tmp.keys():
                state = OrderedDict({'net': OrderedDict()})
                [state['net'].update({'model.{}'.format(k): v}) for k, v in state_tmp.items()]
            else:
                state = state_tmp
            # strict=True: any mismatch between checkpoint and model raises.
            incomp_keys = net.load_state_dict(state['net'], strict=True)
            print(incomp_keys)
            print('Model loaded!\n')

            self.nets += [net]

        net_normalizer = net.get_normalizer()  # pick normalizer from last network
        transform = [
            A.Normalize(mean=net_normalizer.mean, std=net_normalizer.std),
            Ap.transforms.ToTensorV2()
        ]
        self.trans = A.Compose(transform)
        # 128x128 random-crop used by detector A's patch sampling.
        self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
        # NOTE(review): criterion is defined but never used in this file — dead
        # attribute unless referenced elsewhere; confirm before removing.
        self.criterion = torch.nn.CrossEntropyLoss(reduction='none')

    def synth_real_detector(self, img_path: str, n_patch: int = 200) -> 'float | None':
        """Score one image with the five-detector ensemble.

        :param img_path: path to the image file to analyze.
        :param n_patch: number of random 128x128 patches for detector A.
        :return: mean of the five per-detector scores (higher = more likely
            synthetic), or None if the image is unreadable, smaller than
            128x128 on either side, or not a 3-channel (RGB) array.
        """

        # Load image:
        img = np.asarray(Image.open(img_path))

        # Opt-out if image is non conforming
        if img.shape == ():
            print('{} None dimension'.format(img_path))
            return None
        if img.shape[0] < 128 or img.shape[1] < 128:
            print('Too small image')
            return None
        if img.ndim != 3:
            print('RGB images only')
            return None
        if img.shape[2] > 3:
            print('Omitting alpha channel')
            img = img[:, :, :3]

        print('Computing scores...')
        img_net_scores = []
        for net_idx, net in enumerate(self.nets):

            if net_idx == 0:

                # only for detector A, extract N = 200 random patches per image
                patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)]

            else:

                # for detectors B, C, D, E, extract patches aligned with the 8 x 8 pixel grid:
                # we want more or less 200 patches per img
                # (~20 rows x ~10 cols of patches; strides rounded up to multiples of 8)
                stride_0 = ((((img.shape[0] - 128) // 20) + 7) // 8) * 8
                stride_1 = (((img.shape[1] - 128) // 10 + 7) // 8) * 8
                pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
                patches = pe.extract(img)
                # Flatten the (rows, cols, 128, 128, 3) grid into a flat patch list.
                patch_list = list(patches.reshape((patches.shape[0] * patches.shape[1], 128, 128, 3)))

            # Normalization
            transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]

            # Compute scores
            transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device)
            with torch.no_grad():
                patch_scores = net(transf_patch_tensor).cpu().numpy()
                # Per-patch class prediction: 0 = real, 1 = synthetic (argmax over 2 logits).
                patch_predictions = np.argmax(patch_scores, axis=1)

            # NOTE(review): despite the name, np.any implements "any" voting,
            # not a strict majority — one synthetic-voting patch flips the
            # image-level decision to class 1. Confirm this is intended.
            maj_voting = np.any(patch_predictions).astype(int)
            scores_maj_voting = patch_scores[:, maj_voting]
            # Image score for this detector: max score of the winning class,
            # negated when the winning class is "real" (0).
            img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting))

        # final score is the average among the 5 scores returned by the detectors
        img_score = np.mean(img_net_scores)

        return img_score


def main():
    """CLI entry point: score a single image and print the result.

    Expects ``--img_path`` on the command line; returns 0 on completion.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--img_path', help='Path to the test image', required=True)
    cli_args = arg_parser.parse_args()

    # Build the ensemble and score the requested image.
    score = Detector().synth_real_detector(cli_args.img_path)
    print('Image Score: {}'.format(score))

    return 0


# Run the CLI entry point only when executed as a script (not on import).
if __name__ == '__main__':
    main()