Fix model namings
Browse files
app.py
CHANGED
@@ -41,9 +41,9 @@ tfms = transforms.Compose([
|
|
41 |
transforms.ToTensor(),
|
42 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
43 |
])
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
|
48 |
|
49 |
depth_config_path = 'tddfa/configs/mb05_120x120.yml' # 'tddfa/configs/mb1_120x120.yml
|
@@ -51,12 +51,12 @@ cfg = yaml.load(open(depth_config_path), Loader=yaml.SafeLoader)
|
|
51 |
tddfa = TDDFA(gpu_mode=False, **cfg)
|
52 |
|
53 |
|
54 |
-
|
55 |
-
|
56 |
weights = torch.load('./DSDG/DUM/checkpoint/CDCN_U_P1_updated.pkl', map_location=device)
|
57 |
-
|
58 |
-
optimizer = optim.Adam(
|
59 |
-
|
60 |
|
61 |
|
62 |
class Normaliztion_valtest(object):
|
@@ -119,7 +119,7 @@ def inference(img, model_name):
|
|
119 |
faceRegion = faceRegion.unsqueeze(0)
|
120 |
|
121 |
if model_name == 'DeePixBiS':
|
122 |
-
mask, binary =
|
123 |
res = torch.mean(mask).item()
|
124 |
cls = 'Real' if res >= pix_threshhold else 'Spoof'
|
125 |
res = 1 - res
|
@@ -144,7 +144,7 @@ def inference(img, model_name):
|
|
144 |
|
145 |
map_score = 0.0
|
146 |
for frame_t in range(inputs.shape[1]):
|
147 |
-
mu, logvar, map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input =
|
148 |
|
149 |
score_norm = torch.sum(mu) / torch.sum(test_maps[:, frame_t, :, :])
|
150 |
map_score += score_norm
|
|
|
41 |
transforms.ToTensor(),
|
42 |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
43 |
])
|
44 |
+
deepix_model = DeePixBiS(pretrained=False)
|
45 |
+
deepix_model.load_state_dict(torch.load('./DeePixBiS/DeePixBiS.pth'))
|
46 |
+
deepix_model.eval()
|
47 |
|
48 |
|
49 |
depth_config_path = 'tddfa/configs/mb05_120x120.yml' # 'tddfa/configs/mb1_120x120.yml
|
|
|
51 |
tddfa = TDDFA(gpu_mode=False, **cfg)
|
52 |
|
53 |
|
54 |
+
cdcn_model = CDCN_u(basic_conv=Conv2d_cd, theta=0.7)
|
55 |
+
cdcn_model = cdcn_model.to(device)
|
56 |
weights = torch.load('./DSDG/DUM/checkpoint/CDCN_U_P1_updated.pkl', map_location=device)
|
57 |
+
cdcn_model.load_state_dict(weights)
|
58 |
+
optimizer = optim.Adam(cdcn_model.parameters(), lr=0.001, weight_decay=0.00005)
|
59 |
+
cdcn_model.eval()
|
60 |
|
61 |
|
62 |
class Normaliztion_valtest(object):
|
|
|
119 |
faceRegion = faceRegion.unsqueeze(0)
|
120 |
|
121 |
if model_name == 'DeePixBiS':
|
122 |
+
mask, binary = deepix_model.forward(faceRegion)
|
123 |
res = torch.mean(mask).item()
|
124 |
cls = 'Real' if res >= pix_threshhold else 'Spoof'
|
125 |
res = 1 - res
|
|
|
144 |
|
145 |
map_score = 0.0
|
146 |
for frame_t in range(inputs.shape[1]):
|
147 |
+
mu, logvar, map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input = cdcn_model(inputs[:, frame_t, :, :, :])
|
148 |
|
149 |
score_norm = torch.sum(mu) / torch.sum(test_maps[:, frame_t, :, :])
|
150 |
map_score += score_norm
|