Spaces:
Runtime error
Runtime error
Nicolò
commited on
Commit
•
5105c39
1
Parent(s):
7d6d113
add more friendly printing
Browse files- gan_vs_real_detector.py +12 -3
gan_vs_real_detector.py
CHANGED
@@ -17,6 +17,8 @@ from utils import architectures
|
|
17 |
from utils.python_patch_extractor.PatchExtractor import PatchExtractor
|
18 |
from PIL import Image
|
19 |
|
|
|
|
|
20 |
|
21 |
class Detector:
|
22 |
def __init__(self):
|
@@ -29,11 +31,11 @@ class Detector:
|
|
29 |
self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
30 |
|
31 |
self.nets = []
|
32 |
-
for i in
|
33 |
# Instantiate and load network
|
34 |
network_class = getattr(architectures, 'EfficientNetB4')
|
35 |
net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
|
36 |
-
print('Loading model...')
|
37 |
state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
|
38 |
|
39 |
if 'net' not in state_tmp.keys():
|
@@ -75,6 +77,7 @@ class Detector:
|
|
75 |
print('Omitting alpha channel')
|
76 |
img = img[:, :, :3]
|
77 |
|
|
|
78 |
img_net_scores = []
|
79 |
for net_idx, net in enumerate(self.nets):
|
80 |
|
@@ -113,8 +116,14 @@ class Detector:
|
|
113 |
|
114 |
|
115 |
def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
# debug img_path on fermi:
|
117 |
-
img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'
|
118 |
|
119 |
detector = Detector()
|
120 |
score = detector.synth_real_detector(img_path)
|
|
|
17 |
from utils.python_patch_extractor.PatchExtractor import PatchExtractor
|
18 |
from PIL import Image
|
19 |
|
20 |
+
import argparse
|
21 |
+
|
22 |
|
23 |
class Detector:
|
24 |
def __init__(self):
|
|
|
31 |
self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
32 |
|
33 |
self.nets = []
|
34 |
+
for i, l in enumerate('ABCDE'):
|
35 |
# Instantiate and load network
|
36 |
network_class = getattr(architectures, 'EfficientNetB4')
|
37 |
net = network_class(n_classes=2, pretrained=False).eval().to(self.device)
|
38 |
+
print(f'Loading model {l}...')
|
39 |
state_tmp = torch.load(self.weights_path_list[i], map_location='cpu')
|
40 |
|
41 |
if 'net' not in state_tmp.keys():
|
|
|
77 |
print('Omitting alpha channel')
|
78 |
img = img[:, :, :3]
|
79 |
|
80 |
+
print('Computing scores...')
|
81 |
img_net_scores = []
|
82 |
for net_idx, net in enumerate(self.nets):
|
83 |
|
|
|
116 |
|
117 |
|
118 |
def main():
|
119 |
+
|
120 |
+
parser = argparse.ArgumentParser()
|
121 |
+
parser.add_argument('--img_path', help='Pat to the test image', required=True)
|
122 |
+
args = parser.parse_args()
|
123 |
+
|
124 |
+
img_path = args.img_path
|
125 |
# debug img_path on fermi:
|
126 |
+
# img_path = '/home/nbonettini/nvidia_temp/nvidia-alias-free-gan/faces/alias-free-r-afhqv2-512x512/seed40000.png'
|
127 |
|
128 |
detector = Detector()
|
129 |
score = detector.synth_real_detector(img_path)
|