Vivien Chappelier commited on
Commit
91f4aea
·
1 Parent(s): 3f4f0fe
Files changed (1) hide show
  1. detect_torchscript.py +42 -0
detect_torchscript.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import sys
5
+ import numpy as np
6
+ from scipy.special import betainc
7
+
8
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
9
+
10
+ msg_decoder_path = sys.argv[3]
11
+ img_path = sys.argv[1]
12
+ key = int(sys.argv[2])
13
+
14
+ transform_imnet = transforms.Compose([
15
+ transforms.ToTensor(),
16
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
17
+ transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])
18
+ ])
19
+
20
+ img = Image.open(sys.argv[1]).convert("RGB").resize((256, 256), Image.BICUBIC)
21
+
22
+ img = transform_imnet(img).unsqueeze(0).to(device)
23
+ print("img.min", img.min())
24
+ print("img.max", img.max())
25
+ print("img.shape", img.shape)
26
+
27
+ msg_decoder = torch.jit.load(msg_decoder_path).to(device)
28
+ msg_decoder.eval()
29
+ with torch.no_grad():
30
+ dec = msg_decoder(img)[0].cpu().numpy()
31
+ #print("dec = ", dec)
32
+ print("dec = ", dec.shape)
33
+
34
+ msg = np.random.default_rng(seed=key).standard_normal(256)
35
+ msg = msg / np.sqrt(np.dot(msg, msg))
36
+ print("dec.dec", dec.dot(dec))
37
+ print("msg.msg", msg.dot(msg))
38
+ print("dec.msg", dec.dot(msg))
39
+
40
+ cos_angle = dec.dot(msg)
41
+ pfa = betainc((256 - 1) * 0.5, 0.5, 1 - cos_angle*cos_angle)
42
+ print("pfa = ", pfa)