deepfake-detect / gan_vs_real_detector.py
Nicolò
add more friendly printing
5105c39 unverified
raw
history blame
4.91 kB
import os
from collections import OrderedDict
import numpy as np
import torch
torch.manual_seed(21)
import random
random.seed(21)
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import albumentations as A
import albumentations.pytorch as Ap
from utils import architectures
from utils.python_patch_extractor.PatchExtractor import PatchExtractor
from PIL import Image
import argparse
class Detector:
def __init__(self):
# You need to download the weight.zip file from here https://www.dropbox.com/s/g1z2u8wl6srjh6v/weigths.zip?dl=0
# and uncompress it into the main folder.
self.weights_path_list = [os.path.join('weights', f'method_{x}.pth') for x in 'ABCDE']
# GPU configuration if available
self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
self.nets = []
for i, l in enumerate('ABCDE'):
# Instantiate and load network
network_class = getattr(architectures, 'EfficientNetB4')
net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
print(f'Loading model {l}...')
state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
if 'net' not in state_tmp.keys():
state = OrderedDict({'net': OrderedDict()})
[state['net'].update({'model.{}'.format(k): v}) for k, v in state_tmp.items()]
else:
state = state_tmp
incomp_keys = net.load_state_dict(state['net'], strict=True)
print(incomp_keys)
print('Model loaded!\n')
self.nets += [net]
net_normalizer = net.get_normalizer() # pick normalizer from last network
transform = [
A.Normalize(mean=net_normalizer.mean, std=net_normalizer.std),
Ap.transforms.ToTensorV2()
]
self.trans = A.Compose(transform)
self.cropper = A.RandomCrop(width=128, height=128, always_apply=True, p=1.)
self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
def synth_real_detector(self, img_path: str, n_patch: int = 200):
# Load image:
img = np.asarray(Image.open(img_path))
# Opt-out if image is non conforming
if img.shape == ():
print('{} None dimension'.format(img_path))
return None
if img.shape[0] < 128 or img.shape[1] < 128:
print('Too small image')
return None
if img.ndim != 3:
print('RGB images only')
return None
if img.shape[2] > 3:
print('Omitting alpha channel')
img = img[:, :, :3]
print('Computing scores...')
img_net_scores = []
for net_idx, net in enumerate(self.nets):
if net_idx == 0:
# only for detector A, extract N = 200 random patches per image
patch_list = [self.cropper(image=img)['image'] for _ in range(n_patch)]
else:
# for detectors B, C, D, E, extract patches aligned with the 8 x 8 pixel grid:
# we want more or less 200 patches per img
stride_0 = ((((img.shape[0] - 128) // 20) + 7) // 8) * 8
stride_1 = (((img.shape[1] - 128) // 10 + 7) // 8) * 8
pe = PatchExtractor(dim=(128, 128, 3), stride=(stride_0, stride_1, 3))
patches = pe.extract(img)
patch_list = list(patches.reshape((patches.shape[0] * patches.shape[1], 128, 128, 3)))
# Normalization
transf_patch_list = [self.trans(image=patch)['image'] for patch in patch_list]
# Compute scores
transf_patch_tensor = torch.stack(transf_patch_list, dim=0).to(self.device)
with torch.no_grad():
patch_scores = net(transf_patch_tensor).cpu().numpy()
patch_predictions = np.argmax(patch_scores, axis=1)
maj_voting = np.any(patch_predictions).astype(int)
scores_maj_voting = patch_scores[:, maj_voting]
img_net_scores.append(np.nanmax(scores_maj_voting) if maj_voting == 1 else -np.nanmax(scores_maj_voting))
# final score is the average among the 5 scores returned by the detectors
img_score = np.mean(img_net_scores)
return img_score
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--img_path', help='Pat to the test image', required=True)
args = parser.parse_args()
img_path = args.img_path
# debug img_path on fermi:
# img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'
detector = Detector()
score = detector.synth_real_detector(img_path)
print('Image Score: {}'.format(score))
return 0
if __name__ == '__main__':
main()