ozyman commited on
Commit
fa7b864
·
1 Parent(s): f3b7b20

Fix model namings

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -41,9 +41,9 @@ tfms = transforms.Compose([
41
  transforms.ToTensor(),
42
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
43
  ])
44
- model = DeePixBiS(pretrained=False)
45
- model.load_state_dict(torch.load('./DeePixBiS/DeePixBiS.pth'))
46
- model.eval()
47
 
48
 
49
  depth_config_path = 'tddfa/configs/mb05_120x120.yml' # 'tddfa/configs/mb1_120x120.yml
@@ -51,12 +51,12 @@ cfg = yaml.load(open(depth_config_path), Loader=yaml.SafeLoader)
51
  tddfa = TDDFA(gpu_mode=False, **cfg)
52
 
53
 
54
- model = CDCN_u(basic_conv=Conv2d_cd, theta=0.7)
55
- model = model.to(device)
56
  weights = torch.load('./DSDG/DUM/checkpoint/CDCN_U_P1_updated.pkl', map_location=device)
57
- model.load_state_dict(weights)
58
- optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.00005)
59
- model.eval()
60
 
61
 
62
  class Normaliztion_valtest(object):
@@ -119,7 +119,7 @@ def inference(img, model_name):
119
  faceRegion = faceRegion.unsqueeze(0)
120
 
121
  if model_name == 'DeePixBiS':
122
- mask, binary = model.forward(faceRegion)
123
  res = torch.mean(mask).item()
124
  cls = 'Real' if res >= pix_threshhold else 'Spoof'
125
  res = 1 - res
@@ -144,7 +144,7 @@ def inference(img, model_name):
144
 
145
  map_score = 0.0
146
  for frame_t in range(inputs.shape[1]):
147
- mu, logvar, map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input = model(inputs[:, frame_t, :, :, :])
148
 
149
  score_norm = torch.sum(mu) / torch.sum(test_maps[:, frame_t, :, :])
150
  map_score += score_norm
 
41
  transforms.ToTensor(),
42
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
43
  ])
44
+ deepix_model = DeePixBiS(pretrained=False)
45
+ deepix_model.load_state_dict(torch.load('./DeePixBiS/DeePixBiS.pth'))
46
+ deepix_model.eval()
47
 
48
 
49
  depth_config_path = 'tddfa/configs/mb05_120x120.yml' # 'tddfa/configs/mb1_120x120.yml
 
51
  tddfa = TDDFA(gpu_mode=False, **cfg)
52
 
53
 
54
+ cdcn_model = CDCN_u(basic_conv=Conv2d_cd, theta=0.7)
55
+ cdcn_model = cdcn_model.to(device)
56
  weights = torch.load('./DSDG/DUM/checkpoint/CDCN_U_P1_updated.pkl', map_location=device)
57
+ cdcn_model.load_state_dict(weights)
58
+ optimizer = optim.Adam(cdcn_model.parameters(), lr=0.001, weight_decay=0.00005)
59
+ cdcn_model.eval()
60
 
61
 
62
  class Normaliztion_valtest(object):
 
119
  faceRegion = faceRegion.unsqueeze(0)
120
 
121
  if model_name == 'DeePixBiS':
122
+ mask, binary = deepix_model.forward(faceRegion)
123
  res = torch.mean(mask).item()
124
  cls = 'Real' if res >= pix_threshhold else 'Spoof'
125
  res = 1 - res
 
144
 
145
  map_score = 0.0
146
  for frame_t in range(inputs.shape[1]):
147
+ mu, logvar, map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input = cdcn_model(inputs[:, frame_t, :, :, :])
148
 
149
  score_norm = torch.sum(mu) / torch.sum(test_maps[:, frame_t, :, :])
150
  map_score += score_norm