Spaces:
Runtime error
Runtime error
import os | |
from collections import OrderedDict | |
import numpy as np | |
import torch | |
torch.manual_seed(21) | |
import random | |
random.seed(21) | |
import torch.multiprocessing | |
torch.multiprocessing.set_sharing_strategy('file_system') | |
import albumentations as A | |
import albumentations.pytorch as Ap | |
from utils import architectures | |
from utils.python_patch_extractor.PatchExtractor import PatchExtractor | |
from PIL import Image | |
class Detector: | |
def __init__(self): | |
# You need to download the weight.zip file from here https://www.dropbox.com/s/g1z2u8wl6srjh6v/weigths.zip?dl=0 | |
# and uncompress it into the main folder. | |
self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE'] | |
# GPU configuration if available | |
self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') | |
self.nets = [] | |
for i in range(5): | |
# Instantiate and load network | |
network_class = getattr(architectures, 'EfficientNetB4') | |
net = network_class(n_classes=2, pretrained=False).eval().to(self.device) | |
print('Loading model...') | |
state_tmp = torch.load(self.weights_path_list[i], map_location='cpu') | |
if 'net' not in state_tmp.keys(): | |
state = OrderedDict({'net': OrderedDict()}) | |
[state['net'].update({'model.{}'.format(k): v}) for k, v in state_tmp.items()] | |
else: | |
state = state_tmp | |
incomp_keys = net.load_state_dict(state['net'], strict=True) | |
print(incomp_keys) | |
print('Model loaded!\n') | |
self.nets += [net] | |
net_normalizer = net.get_normalizer() # pick normalizer from last network | |
transform = [ | |
A.Normalize(mean=net_normalizer.mean, std=net_normalizer.std), | |
Ap.transforms.ToTensorV2() | |
] | |
self.trans = A.Compose(transform) | |
self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.) | |
self.criterion = torch.nn.CrossEntropyLoss(reduction='none') | |
def synth_real_detector(self, img_path: str, n_patch: int = 200): | |
# Load image: | |
img = np.asarray(Image.open(img_path)) | |
# Opt-out if image is non conforming | |
if img.shape == (): | |
print('{} None dimension'.format(img_path)) | |
return None | |
if img.shape[0] < 128 or img.shape[1] < 128: | |
print('Too small image') | |
return None | |
if img.ndim != 3: | |
print('RGB images only') | |
return None | |
if img.shape[2] > 3: | |
print('Omitting alpha channel') | |
img = img[:, :, :3] | |
img_net_scores = [] | |
for net_idx, net in enumerate(self.nets): | |
if net_idx == 0: | |
# only for detector A, extract N = 200 random patches per image | |
patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)] | |
else: | |
# for detectors B, C, D, E, extract patches aligned with the 8 x 8 pixel grid: | |
# we want more or less 200 patches per img | |
stride_0 = ((((img.shape[0] - 128) // 20) + 7) // 8) * 8 | |
stride_1 = (((img.shape[1] - 128) // 10 + 7) // 8) * 8 | |
pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3)) | |
patches = pe.extract(img) | |
patch_list = list(patches.reshape((patches.shape[0] * patches.shape[1], 128, 128, 3))) | |
# Normalization | |
transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list] | |
# Compute scores | |
transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device) | |
with torch.no_grad(): | |
patch_scores = net(transf_patch_tensor).cpu().numpy() | |
patch_predictions = np.argmax(patch_scores, axis=1) | |
maj_voting = np.any(patch_predictions).astype(int) | |
scores_maj_voting = patch_scores[:, maj_voting] | |
img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting)) | |
# final score is the average among the 5 scores returned by the detectors | |
print(img_net_scores) | |
img_score = np.mean(img_net_scores) | |
return img_score | |
def main(): | |
# debug img_path on fermi: | |
img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png' | |
detector = Detector() | |
score = detector.synth_real_detector(img_path) | |
print('Image Score: {}'.format(score)) | |
return 0 | |
if __name__ == '__main__': | |
main() | |