Runtime error
Runtime error
File size: 4,759 Bytes
f6b58ff c5ab13c f6b58ff 6bd8735 f6b58ff 5105c39 f6b58ff c5ab13c f6b58ff 5105c39 f6b58ff 5105c39 f6b58ff c5ab13c f6b58ff c5ab13c f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 5105c39 6bd8735 f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 6bd8735 c5ab13c f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 5105c39 5505fbb 5105c39 f6b58ff 6bd8735 f6b58ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
from collections import OrderedDict
import numpy as np
import torch
import random
import torch.multiprocessing
import albumentations as A
import albumentations.pytorch as Ap
from utils import architectures
from utils.python_patch_extractor.PatchExtractor import PatchExtractor
from PIL import Image
import argparse
class Detector:
    """Ensemble of five EfficientNetB4 classifiers that scores one image.

    Each network (methods A to E) is loaded from ``weights/method_<X>.pth``.
    An image is scored by running 128 x 128 patches through every network and
    averaging the five per-network scores.
    """

    def __init__(self):
        """Load the five networks and build the patch pre-processing pipeline."""
        # The weight files must be downloaded separately and uncompressed into
        # the main folder, so that weights/method_A.pth ... method_E.pth exist.
        self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']

        # GPU configuration if available
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

        self.nets = []
        network_class = getattr(architectures, 'EfficientNetB4')
        for letter, weights_path in zip('ABCDE', self.weights_path_list):
            # Instantiate and load network
            net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
            print(f'Loading model {letter}...')
            # NOTE: torch.load unpickles the file; only load weights obtained
            # from a trusted source.
            state_tmp = torch.load(weights_path, map_location='cpu')
            if 'net' not in state_tmp:
                # Checkpoint is a bare state dict: wrap it under 'net' and
                # prefix every key with 'model.' before loading.
                state = OrderedDict({
                    'net': OrderedDict(
                        ('model.{}'.format(k), v) for k, v in state_tmp.items()
                    )
                })
            else:
                state = state_tmp
            net.load_state_dict(state['net'], strict=True)
            print('Model loaded!\n')
            self.nets.append(net)

        net_normalizer = net.get_normalizer()  # pick normalizer from last network
        # NOTE(review): the tensor-conversion step is reconstructed; the
        # scoring code stacks the transformed patches with torch.stack, which
        # needs tensors. Confirm against the upstream file.
        transform = [
            A.Normalize(mean=net_normalizer.mean, std=net_normalizer.std),
            Ap.transforms.ToTensorV2(),
        ]
        self.trans = A.Compose(transform)
        self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
        self.criterion = torch.nn.CrossEntropyLoss(reduction='none')

    @staticmethod
    def _grid_stride(extent, divisor):
        """Return the patch stride along one image axis.

        The stride is ``(extent - 128) // divisor`` rounded up to a multiple
        of 8 so that patches stay aligned with the 8 x 8 pixel grid. The
        rounding yields 0 when ``(extent - 128) // divisor`` is 0, so the
        result is clamped to a minimum of 8.
        """
        stride = (((extent - 128) // divisor + 7) // 8) * 8
        # NOTE(review): assumes PatchExtractor cannot work with a zero stride
        # - confirm against its implementation.
        return max(stride, 8)

    @staticmethod
    def _score_from_patches(patch_scores):
        """Collapse per-patch two-class scores into a single signed score.

        If at least one patch has class 1 as its arg-max, the result is the
        largest class-1 score over all patches. Otherwise it is the negated
        largest class-0 score.
        """
        patch_scores = np.asarray(patch_scores)
        patch_predictions = np.argmax(patch_scores, axis=1)
        voted_class = int(np.any(patch_predictions))
        top_score = np.nanmax(patch_scores[:, voted_class])
        return top_score if voted_class == 1 else -top_score

    def _to_batch(self, patch_list):
        """Normalize every patch and stack them into one tensor on the device."""
        transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
        return torch.stack(transf_patch_list, dim=0).to(self.device)

    def _grid_patches(self, img):
        """Extract 128 x 128 patches aligned with the 8 x 8 pixel grid."""
        # we want more or less 200 patches per img (about 20 rows x 10 columns)
        stride_0 = self._grid_stride(img.shape[0], 20)
        stride_1 = self._grid_stride(img.shape[1], 10)
        pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
        patches = pe.extract(img)
        return list(patches.reshape((patches.shape[0] * patches.shape[1], 128, 128, 3)))

    def synth_real_detector(self, img_path: str, n_patch: int = 200):
        """Score the image stored at ``img_path``.

        Args:
            img_path: Path of the image file to analyse.
            n_patch: Number of random 128 x 128 crops drawn for the first
                network (detector A). The other networks use a fixed grid.

        Returns:
            The mean of the five per-network scores, or ``None`` when the
            image is rejected (no shape, smaller than 128 pixels on a side,
            or not a three-channel image).
        """
        # Load image. NOTE(review): the loading call is reconstructed from a
        # truncated `np.asarray(` line; PIL's Image is the only image loader
        # imported by the file.
        with Image.open(img_path) as pil_img:
            img = np.asarray(pil_img)

        # Opt-out if image is non conforming
        if img.shape == ():
            print('{} None dimension'.format(img_path))
            return None
        if img.shape[0] < 128 or img.shape[1] < 128:
            print('Too small image')
            return None
        if img.ndim != 3 or img.shape[2] < 3:
            print('RGB images only')
            return None
        if img.shape[2] > 3:
            print('Omitting alpha channel')
            img = img[:, :, :3]

        print('Computing scores...')
        img_net_scores = []
        grid_batch = None  # shared by detectors B, C, D, E; built on first use
        for net_idx, net in enumerate(self.nets):
            if net_idx == 0:
                # only for detector A, extract N = 200 random patches per image
                patch_batch = self._to_batch(
                    [self.cropper(image=img)['image'] for _ in range(n_patch)]
                )
            else:
                # for detectors B, C, D, E, extract patches aligned with the
                # 8 x 8 pixel grid; the grid does not depend on the network,
                # so it is computed only once.
                if grid_batch is None:
                    grid_batch = self._to_batch(self._grid_patches(img))
                patch_batch = grid_batch

            # Compute scores
            with torch.no_grad():
                patch_scores = net(patch_batch).cpu().numpy()
            img_net_scores.append(self._score_from_patches(patch_scores))

        # final score is the average among the 5 scores returned by the detectors
        img_score = np.mean(img_net_scores)
        return img_score
def main():
    """Score the image given on the command line and print the result.

    Reads the required ``--img_path`` argument, builds a ``Detector``, runs
    ``synth_real_detector`` on the image and prints ``Image Score: <score>``.

    Returns:
        Always 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_path', help='Path to the test image', required=True)
    args = parser.parse_args()

    detector = Detector()
    score = detector.synth_real_detector(args.img_path)
    print('Image Score: {}'.format(score))
    return 0
if __name__ == '__main__':