Spaces:
Runtime error
Runtime error
File size: 4,680 Bytes
f6b58ff c5ab13c f6b58ff 6bd8735 f6b58ff c5ab13c f6b58ff c5ab13c f6b58ff c5ab13c f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 6bd8735 c5ab13c f6b58ff 6bd8735 f6b58ff 6bd8735 f6b58ff 6bd8735 c5ab13c 6bd8735 f6b58ff c5ab13c 6bd8735 f6b58ff 6bd8735 f6b58ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
from collections import OrderedDict
import numpy as np
import torch
torch.manual_seed(21)
import random
random.seed(21)
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import albumentations as A
import albumentations.pytorch as Ap
from utils import architectures
from utils.python_patch_extractor.PatchExtractor import PatchExtractor
from PIL import Image
class Detector:
    """Ensemble of five EfficientNetB4 networks that scores an image from its patches.

    The constructor loads one network per weight file
    (``weights/method_A.pth`` ... ``weights/method_E.pth``) and builds the
    normalisation pipeline shared by all of them. ``synth_real_detector``
    then returns the mean of the five per-network scores for one image.
    """

    def __init__(self):
        """Load the five networks and prepare the patch pre-processing transforms."""
        # You need to download the weight.zip file from here https://www.dropbox.com/s/g1z2u8wl6srjh6v/weigths.zip?dl=0
        # and uncompress it into the main folder.
        self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']

        # GPU configuration if available
        self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

        self.nets = []
        for weights_path in self.weights_path_list:
            # Instantiate and load network
            network_class = getattr(architectures, 'EfficientNetB4')
            net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
            print('Loading model...')
            # NOTE: torch.load unpickles the checkpoint; only load weight files
            # obtained from a trusted source.
            state_tmp = torch.load(weights_path, map_location='cpu')
            if 'net' not in state_tmp.keys():
                # Bare state dict: wrap it under 'net' and prefix every key with
                # 'model.' so that it matches what load_state_dict is given below.
                state = OrderedDict({'net': OrderedDict()})
                for k, v in state_tmp.items():
                    state['net']['model.{}'.format(k)] = v
            else:
                state = state_tmp
            incomp_keys = net.load_state_dict(state['net'], strict=True)
            print(incomp_keys)
            print('Model loaded!\n')
            self.nets += [net]

        net_normalizer = net.get_normalizer()  # pick normalizer from last network
        transform = [
            A.Normalize(mean=net_normalizer.mean, std=net_normalizer.std),
            Ap.transforms.ToTensorV2()
        ]
        self.trans = A.Compose(transform)
        self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
        self.criterion = torch.nn.CrossEntropyLoss(reduction='none')

    @staticmethod
    def _grid_strides(height, width):
        """Return the (row, column) strides used to extract 128 x 128 grid patches.

        Each stride is rounded up to a multiple of 8 so patches stay aligned
        with the 8 x 8 pixel grid. The raw formula yields 0 for images that are
        less than 20 rows / 10 columns larger than a patch, so the result is
        clamped to a minimum of 8 to keep the stride positive.
        """
        stride_0 = max(8, (((height - 128) // 20 + 7) // 8) * 8)
        stride_1 = max(8, (((width - 128) // 10 + 7) // 8) * 8)
        return stride_0, stride_1

    def synth_real_detector(self, img_path: str, n_patch: int = 200):
        """Score the image stored at ``img_path`` with the five-network ensemble.

        Args:
            img_path: path of the image file to analyse.
            n_patch: number of random 128 x 128 crops fed to the first network.

        Returns:
            The mean of the five per-network scores, or ``None`` when the image
            is rejected (smaller than 128 pixels on a side, not a 3-D array, or
            fewer than 3 channels). Each per-network score is positive when at
            least one patch is predicted as class 1 and negative otherwise
            (presumably class 1 means "synthetic" -- confirm against the
            training code).
        """
        # Load image; the context manager closes the underlying file handle
        # once the pixel data has been copied into the array.
        with Image.open(img_path) as pil_img:
            img = np.asarray(pil_img)

        # Opt-out if image is non conforming
        if img.shape == ():
            print('{} None dimension'.format(img_path))
            return None
        if img.shape[0] < 128 or img.shape[1] < 128:
            print('Too small image')
            return None
        if img.ndim != 3 or img.shape[2] < 3:
            # Also rejects 3-D arrays with fewer than 3 channels, which cannot
            # be cut into the 128 x 128 x 3 patches requested below.
            print('RGB images only')
            return None
        if img.shape[2] > 3:
            print('Omitting alpha channel')
            img = img[:, :, :3]

        img_net_scores = []
        for net_idx, net in enumerate(self.nets):
            if net_idx == 0:
                # only for detector A, extract N = 200 random patches per image
                patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)]
            else:
                # for detectors B, C, D, E, extract patches aligned with the 8 x 8 pixel grid:
                # we want more or less 200 patches per img
                stride_0, stride_1 = self._grid_strides(img.shape[0], img.shape[1])
                pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
                patches = pe.extract(img)
                patch_list = list(patches.reshape((patches.shape[0] * patches.shape[1], 128, 128, 3)))

            # Normalization
            transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]

            # Compute scores
            transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device)
            with torch.no_grad():
                patch_scores = net(transf_patch_tensor).cpu().numpy()
            patch_predictions = np.argmax(patch_scores, axis=1)
            # Despite the name, this is an "any" vote: the image-level decision
            # is 1 as soon as a single patch is predicted as class 1.
            maj_voting = np.any(patch_predictions).astype(int)
            scores_maj_voting = patch_scores[:, maj_voting]
            # Keep the strongest score of the winning class, negated when that class is 0.
            img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting))

        # final score is the average among the 5 scores returned by the detectors
        print(img_net_scores)
        img_score = np.mean(img_net_scores)
        return img_score
def main(img_path='/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'):
    """Build a Detector, score one image and print the result.

    Args:
        img_path: path of the image to score. The default is the original
            debug image path on fermi, so calling ``main()`` with no argument
            behaves as before.

    Returns:
        Always 0.
    """
    detector = Detector()
    score = detector.synth_real_detector(img_path)
    print('Image Score: {}'.format(score))
    return 0


if __name__ == '__main__':
    main()
|