"""Score an image against a key-derived secret direction.

Usage: python <this script> IMAGE KEY MSG_DECODER

Loads IMAGE, resizes it to 256x256, scales its pixels to [-1, 1] and runs it
through the TorchScript model stored at MSG_DECODER.  The model's output vector
is compared with a unit vector drawn from ``numpy.random.default_rng(KEY)``.
The script prints some diagnostics and ``pfa``: the probability that a direction
drawn uniformly at random would be at least as (anti-)aligned with that vector.
Small values mean the alignment is unlikely to be chance.
NOTE(review): presumably this detects a watermark embedded with the same key —
confirm against the embedding side.
"""
import sys

import numpy as np
import torch
from scipy.special import betainc


def secret_message(key, dim=256):
    """Return the unit-norm secret direction of length ``dim`` derived from ``key``.

    ``key`` seeds ``numpy.random.default_rng`` and must be a non-negative int.
    For ``dim=256`` this reproduces the original script's vector exactly.
    """
    msg = np.random.default_rng(seed=key).standard_normal(dim)
    return msg / np.sqrt(np.dot(msg, msg))


def detection_pfa(dec, msg):
    """Return ``(cos_angle, pfa)`` for the decoder output ``dec`` against ``msg``.

    Both vectors are flattened and normalised here, so neither needs to be
    unit-norm.  ``pfa = I_{1 - cos^2}((d - 1) / 2, 1 / 2)`` (regularised
    incomplete beta, ``d`` = vector length).  This is the probability that a
    uniformly random direction in R^d has ``|cos|`` with ``msg`` at least this
    large, so the test is two-sided.

    A zero vector has no direction; it yields ``cos_angle = 0`` and ``pfa = 1``.

    Raises:
        ValueError: if the two vectors differ in length or have fewer than
            2 elements (the formula is undefined for d < 2).
    """
    dec = np.asarray(dec, dtype=np.float64).ravel()
    msg = np.asarray(msg, dtype=np.float64).ravel()
    if dec.shape != msg.shape:
        raise ValueError(
            f"dimension mismatch: dec has {dec.size} elements, msg has {msg.size}"
        )
    dim = dec.size
    if dim < 2:
        raise ValueError(f"need at least 2 dimensions, got {dim}")

    denom = np.linalg.norm(dec) * np.linalg.norm(msg)
    cos_angle = float(dec.dot(msg) / denom) if denom > 0 else 0.0
    # Rounding can push |cos| a hair above 1, which would make the beta
    # argument negative and the result NaN.
    cos_angle = min(1.0, max(-1.0, cos_angle))
    pfa = float(betainc((dim - 1) * 0.5, 0.5, 1.0 - cos_angle * cos_angle))
    return cos_angle, pfa


def load_image(path, size=256):
    """Load ``path`` as RGB, resize to ``size``x``size`` (bicubic) and scale to [-1, 1].

    Returns a float tensor of shape (1, 3, size, size).
    NOTE(review): assumes the decoder was trained on inputs normalised with
    mean=std=0.5 rather than ImageNet statistics — confirm.
    """
    # Imported here so the scoring helpers above can be used without PIL/torchvision.
    import torchvision.transforms as transforms
    from PIL import Image

    transform = transforms.Compose([
        transforms.ToTensor(),  # [0, 255] -> [0, 1]
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # -> [-1, 1]
    ])
    with Image.open(path) as src:  # close the file handle once pixels are read
        rgb = src.convert("RGB").resize((size, size), Image.BICUBIC)
    return transform(rgb).unsqueeze(0)


def _parse_args(argv):
    """Return ``(img_path, key, msg_decoder_path)`` from a ``sys.argv``-style list."""
    prog = argv[0] if argv else "detect"
    if len(argv) < 4:
        sys.exit(f"usage: {prog} IMAGE KEY MSG_DECODER")
    try:
        key = int(argv[2])
    except ValueError:
        sys.exit(f"{prog}: KEY must be an integer, got {argv[2]!r}")
    if key < 0:
        sys.exit(f"{prog}: KEY must be non-negative (numpy seeds cannot be negative)")
    return argv[1], key, argv[3]


def main(argv=None):
    """Run the CLI; ``argv`` defaults to ``sys.argv``. Returns the computed pfa."""
    img_path, key, msg_decoder_path = _parse_args(sys.argv if argv is None else argv)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    img = load_image(img_path).to(device)
    print("img.min", img.min())
    print("img.max", img.max())
    print("img.shape", img.shape)

    # map_location lets a model saved on GPU load on a CPU-only machine.
    msg_decoder = torch.jit.load(msg_decoder_path, map_location=device)
    msg_decoder.eval()
    with torch.no_grad():
        # First (only) batch element; flattened so the vector math below is 1-D.
        dec = msg_decoder(img)[0].cpu().numpy().ravel()
    print("dec = ", dec.shape)

    # Derive the carrier length from the model output instead of assuming 256.
    msg = secret_message(key, dim=dec.size)
    print("dec.dec", dec.dot(dec))
    print("msg.msg", msg.dot(msg))
    print("dec.msg", dec.dot(msg))

    cos_angle, pfa = detection_pfa(dec, msg)
    print("cos_angle = ", cos_angle)
    print("pfa = ", pfa)
    return pfa


if __name__ == "__main__":
    main()