ozyman committed on
Commit
73cb092
·
1 Parent(s): a3d3525

refactor, add thresh, save fix

Browse files
Files changed (1) hide show
  1. app.py +107 -80
app.py CHANGED
@@ -36,8 +36,9 @@ app_version = 'ddn1'
36
 
37
  device = torch.device("cpu")
38
  labels = ['Live', 'Spoof']
39
- pix_threshhold = 0.45
40
- dsdg_threshold = 0.003
 
41
  examples = [
42
  ['examples/1_1_21_2_33_scene_fake.jpg'],
43
  ['examples/frame150_real.jpg'],
@@ -78,7 +79,7 @@ class Normaliztion_valtest(object):
78
  return image_x
79
 
80
 
81
- def prepare_data(images, boxes, depths):
82
  transform = transforms.Compose([Normaliztion_valtest()])
83
  files_total = 1
84
  image_x = np.zeros((files_total, 256, 256, 3))
@@ -86,10 +87,10 @@ def prepare_data(images, boxes, depths):
86
 
87
  for i, (image, bbox, depth_img) in enumerate(
88
  zip(images, boxes, depths)):
89
- x, y, w, h = bbox
90
  depth_img = cv.cvtColor(depth_img, cv.COLOR_RGB2GRAY)
91
- image = image[y:y + h, x:x + w]
92
- depth_img = depth_img[y:y + h, x:x + w]
93
 
94
  image_x[i, :, :, :] = cv.resize(image, (256, 256))
95
  # transform to binary mask --> threshold = 0
@@ -100,89 +101,110 @@ def prepare_data(images, boxes, depths):
100
  depth_x = torch.from_numpy(depth_x.astype(float)).float()
101
  return image_x, depth_x
102
 
103
-
104
  def find_largest_face(faces):
 
105
  largest_face = None
106
  largest_area = 0
107
-
108
- for (x, y, w, h) in faces:
109
  area = w * h
110
  if area > largest_area:
111
  largest_area = area
112
- largest_face = (x, y, w, h)
113
  return largest_face
114
 
115
 
116
- def inference(img):
 
117
  if img is None:
118
- return None, {}, None, None, {}, None, None
119
  grey = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
120
  faces = faceClassifier.detectMultiScale(
121
  grey, scaleFactor=1.1, minNeighbors=4)
122
- face = find_largest_face(faces)
123
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  if face is not None:
125
  x, y, w, h = face
126
  x2 = x + w
127
  y2 = y + h
128
- faceRegion = img[y:y2, x:x2]
129
- faceRegion = tfms(faceRegion)
130
- faceRegion = faceRegion.unsqueeze(0)
131
-
132
- # if model_name == 'DeePixBiS':
133
- mask, binary = deepix_model.forward(faceRegion)
134
- res_deepix = torch.mean(mask).item()
135
- cls_deepix = 'Real' if res_deepix >= pix_threshhold else 'Spoof'
136
-
137
- confidences_deepix = {'Real confidence': res_deepix}
138
- color_deepix = (0, 255, 0) if cls_deepix == 'Real' else (255, 0, 0)
139
- img_deepix = cv.rectangle(img.copy(), (x, y), (x2, y2), color_deepix, 2)
140
- cv.putText(img_deepix, cls_deepix, (x, y2 + 30),
141
- cv.FONT_HERSHEY_COMPLEX, 1, color_deepix)
142
-
143
- # else:
144
- dense_flag = True
145
- box = [x, y, x2, y2, 1]
146
- param_lst, roi_box_lst = tddfa(img, [box])
147
-
148
- ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=dense_flag)
149
- depth_img = depth(img, ver_lst, tddfa.tri, with_bg_flag=False)
150
- with torch.no_grad():
151
- map_score_list = []
152
- image_x, map_x = prepare_data([img], [list(face)], [depth_img])
153
- # get the inputs
154
- image_x = image_x.unsqueeze(0)
155
- map_x = map_x.unsqueeze(0)
156
- inputs = image_x.to(device)
157
- test_maps = map_x.to(device)
158
- optimizer.zero_grad()
159
-
160
- map_score = 0.0
161
- for frame_t in range(inputs.shape[1]):
162
- mu, logvar, map_x, x_concat, x_Block1, x_Block2, x_Block3, x_input = cdcn_model(inputs[:, frame_t, :, :, :])
163
-
164
- score_norm = torch.sum(mu) / torch.sum(test_maps[:, frame_t, :, :])
165
- map_score += score_norm
166
- map_score = map_score / inputs.shape[1]
167
- map_score_list.append(map_score)
168
-
169
- res_dsdg = map_score_list[0].item()
170
- if res_dsdg > 10:
171
- res_dsdg = 0.0
172
- cls_dsdg = 'Real' if res_dsdg >= dsdg_threshold else 'Spoof'
173
- res_dsdg = res_dsdg * 300
174
-
175
- confidences_dsdg = {'Real confidence': res_dsdg}
176
- color_dsdg = (0, 255, 0) if cls_dsdg == 'Real' else (255, 0, 0)
177
- img_dsdg = cv.rectangle(img.copy(), (x, y), (x2, y2), color_dsdg, 2)
178
- cv.putText(img_dsdg, cls_dsdg, (x, y2 + 30),
179
- cv.FONT_HERSHEY_COMPLEX, 1, color_dsdg)
180
-
181
- cls_deepix, cls_dsdg = [1 if cls_ == 'Real' else 0 for cls_ in [cls_deepix, cls_dsdg]]
182
-
183
- return img_deepix, confidences_deepix, img_dsdg, confidences_dsdg, cls_deepix, cls_dsdg
184
  else:
185
- return img, {}, img, {}, None, None
186
 
187
 
188
  def upload_to_s3(image_array, app_version, *labels):
@@ -199,12 +221,12 @@ def upload_to_s3(image_array, app_version, *labels):
199
  s3 = boto3.client('s3')
200
 
201
  # Encode labels and app version in image file name
202
- encoded_labels = '_'.join([str(label) for label in labels])
203
  random_string = str(uuid.uuid4()).split('-')[-1]
204
  image_name = f"{folder}/{app_version}/{encoded_labels}_{random_string}.jpg"
205
 
206
  # Save image as JPEG
207
- image = Image.fromarray(np.uint8(image_array * 255))
208
  image_bytes = io.BytesIO()
209
  image.save(image_bytes, format='JPEG')
210
  image_bytes.seek(0)
@@ -222,25 +244,30 @@ demo = gr.Blocks()
222
  with demo:
223
  with gr.Row():
224
  with gr.Column():
225
- input_img = gr.Image(source='webcam', shape=None, type='numpy')
 
226
  btn_run = gr.Button(value="Run")
227
  with gr.Column():
228
  outputs=[
229
  gr.Image(label='DeePixBiS', type='numpy'),
230
  gr.Label(num_top_classes=2, label='DeePixBiS'),
 
231
  gr.Image(label='DSDG', type='numpy'),
232
- gr.Label(num_top_classes=2, label='DSDG')]
 
233
  with gr.Column():
234
  radio = gr.Radio(
235
- ["Real", "Spoof", "None"], label="True label", type='index')
236
  flag = gr.Button(value="Flag")
237
  status = gr.Textbox()
238
- example_block = gr.Examples(examples, [input_img], outputs+labels)
239
 
240
- labels = [gr.Number(visible=False, value=-1), gr.Number(visible=False, value=-1)]
241
- btn_run.click(inference, [input_img], outputs+labels)
242
  app_version_block = gr.Textbox(value=app_version, visible=False)
243
- flag.click(upload_to_s3, [input_img, app_version_block, radio]+labels, [status], show_progress=True)
 
 
 
244
 
245
 
246
  if __name__ == '__main__':
 
36
 
37
  device = torch.device("cpu")
38
  labels = ['Live', 'Spoof']
39
+ PIX_THRESHOLD = 0.45
40
+ DSDG_THRESHOLD = 80
41
+ MIN_FACE_WIDTH_THRESHOLD = 210
42
  examples = [
43
  ['examples/1_1_21_2_33_scene_fake.jpg'],
44
  ['examples/frame150_real.jpg'],
 
79
  return image_x
80
 
81
 
82
+ def prepare_data_dsdg(images, boxes, depths):
83
  transform = transforms.Compose([Normaliztion_valtest()])
84
  files_total = 1
85
  image_x = np.zeros((files_total, 256, 256, 3))
 
87
 
88
  for i, (image, bbox, depth_img) in enumerate(
89
  zip(images, boxes, depths)):
90
+ x, y, x2, y2 = bbox
91
  depth_img = cv.cvtColor(depth_img, cv.COLOR_RGB2GRAY)
92
+ image = image[y:y2, x:x2]
93
+ depth_img = depth_img[y:y2, x:x2]
94
 
95
  image_x[i, :, :, :] = cv.resize(image, (256, 256))
96
  # transform to binary mask --> threshold = 0
 
101
  depth_x = torch.from_numpy(depth_x.astype(float)).float()
102
  return image_x, depth_x
103
 
 
def find_largest_face(faces):
    """Return the detection with the largest area.

    Args:
        faces: Iterable of ``(x, y, w, h)`` boxes.

    Returns:
        The element of ``faces`` whose ``w * h`` is greatest (the first one
        wins ties), or ``None`` when ``faces`` is empty or no box has a
        positive area.
    """
    largest_face = None
    largest_area = 0
    for face in faces:
        _, _, w, h = face
        area = w * h
        # Strict comparison: zero-area boxes are never selected and the
        # earliest box wins a tie.
        if area > largest_area:
            largest_area = area
            largest_face = face
    return largest_face
115
 
116
 
def extract_face(img):
    """Detect faces in ``img`` and return the largest one.

    Args:
        img: Image array, or ``None`` when no image was supplied. It is
            converted to greyscale with ``cv.COLOR_RGB2GRAY`` before
            detection.

    Returns:
        The largest ``(x, y, w, h)`` detection reported by
        ``faceClassifier.detectMultiScale``, or ``None`` when ``img`` is
        ``None`` or nothing was detected.
    """
    if img is None:
        return None
    grey = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
    faces = faceClassifier.detectMultiScale(
        grey, scaleFactor=1.1, minNeighbors=4)
    # Use len() rather than truthiness: the detector result may be an array,
    # whose truth value is ambiguous.
    if len(faces) == 0:
        return None
    return find_largest_face(faces)
127
+
128
+
def deepix_model_inference(img, bbox):
    """Classify the face inside ``bbox`` with the DeePixBiS model.

    Args:
        img: Image array; cropped by slicing and annotated on a copy.
        bbox: ``(x, y, x2, y2)`` corner coordinates of the face.

    Returns:
        Tuple ``(img_deepix, confidences_deepix, cls_deepix)``:
        a copy of ``img`` with the box and the predicted label drawn on it,
        ``{'Real confidence': score}`` where ``score`` is the mean of the
        model's output mask, and ``1`` for 'Real' or ``0`` for 'Spoof'
        (``score >= PIX_THRESHOLD`` means 'Real').
    """
    x, y, x2, y2 = bbox
    face_region = img[y:y2, x:x2]
    face_region = tfms(face_region)
    face_region = face_region.unsqueeze(0)
    # Inference only: skip autograd bookkeeping, as the DSDG path does.
    with torch.no_grad():
        mask, _binary = deepix_model.forward(face_region)
    res_deepix = torch.mean(mask).item()
    is_real = res_deepix >= PIX_THRESHOLD
    label = 'Real' if is_real else 'Spoof'
    confidences_deepix = {'Real confidence': res_deepix}
    # Green box for 'Real', red for 'Spoof' (the image is handled as RGB).
    color_deepix = (0, 255, 0) if is_real else (255, 0, 0)
    img_deepix = cv.rectangle(img.copy(), (x, y), (x2, y2), color_deepix, 2)
    cv.putText(img_deepix, label, (x, y2 + 30),
               cv.FONT_HERSHEY_COMPLEX, 1, color_deepix)
    cls_deepix = 1 if is_real else 0
    return img_deepix, confidences_deepix, cls_deepix
144
+
145
+
def dsdg_model_inference(img, bbox, dsdg_thresh):
    """Classify the face inside ``bbox`` with the DSDG model.

    Args:
        img: Image array; annotated on a copy.
        bbox: ``(x, y, x2, y2)`` corner coordinates of the face.
        dsdg_thresh: Decision threshold on the UI scale. The reported
            confidence is the raw score times 300, and this value is
            divided by 30000 before comparison, so it acts as a percentage
            of the reported confidence.

    Returns:
        Tuple ``(img_dsdg, confidences_dsdg, cls_dsdg)``:
        a copy of ``img`` with the box and a text label drawn on it,
        ``{'Real confidence': raw_score * 300}`` (an empty dict when the face
        is too small), and ``1`` for 'Real', ``0`` for 'Spoof' or ``2`` when
        the face is narrower than ``MIN_FACE_WIDTH_THRESHOLD``.
    """
    import math  # local import: only the finiteness check below needs it

    raw_thresh = dsdg_thresh / 30000
    x, y, x2, y2 = bbox
    w = x2 - x
    h = y2 - y

    if w < MIN_FACE_WIDTH_THRESHOLD:
        # Face too small to score: report it instead of running the model.
        color_dsdg = (0, 0, 0)
        text = f'Small res ({w}*{h})'
        img_dsdg = cv.rectangle(img.copy(), (x, y), (x2, y2), color_dsdg, 2)
        cv.putText(img_dsdg, text, (x, y2 + 30),
                   cv.FONT_HERSHEY_COMPLEX, 1, color_dsdg)
        return img_dsdg, {}, 2

    # tddfa is called with the box plus a trailing 1 as a fifth element.
    bbox_conf = list(bbox)
    bbox_conf.append(1)
    param_lst, roi_box_lst = tddfa(img, [bbox_conf])
    ver_lst = tddfa.recon_vers(param_lst, roi_box_lst, dense_flag=True)
    depth_img = depth(img, ver_lst, tddfa.tri, with_bg_flag=False)

    with torch.no_grad():
        image_x, map_x = prepare_data_dsdg([img], [list(bbox)], [depth_img])
        inputs = image_x.unsqueeze(0).to(device)
        test_maps = map_x.unsqueeze(0).to(device)
        # NOTE(review): zeroing gradients looks unnecessary under no_grad();
        # kept to preserve existing behaviour - confirm and remove.
        optimizer.zero_grad()
        map_score = 0.0
        for frame_t in range(inputs.shape[1]):
            # Only the first model output (mu) is used for scoring.
            mu, *_unused = cdcn_model(inputs[:, frame_t, :, :, :])
            score_norm = torch.sum(mu) / torch.sum(test_maps[:, frame_t, :, :])
            map_score += score_norm
        map_score = map_score / inputs.shape[1]

    res_dsdg = map_score.item()
    # A depth map that sums to zero makes the division above yield inf or
    # nan; treat any non-finite or implausibly large score as 0.0.
    if not math.isfinite(res_dsdg) or res_dsdg > 10:
        res_dsdg = 0.0

    is_real = res_dsdg >= raw_thresh
    label = 'Real' if is_real else 'Spoof'
    text = f'{label} {w}*{h}'
    res_dsdg = res_dsdg * 300
    confidences_dsdg = {'Real confidence': res_dsdg}
    # Green box for 'Real', red for 'Spoof' (the image is handled as RGB).
    color_dsdg = (0, 255, 0) if is_real else (255, 0, 0)
    img_dsdg = cv.rectangle(img.copy(), (x, y), (x2, y2), color_dsdg, 2)
    cv.putText(img_dsdg, text, (x, y2 + 30),
               cv.FONT_HERSHEY_COMPLEX, 1, color_dsdg)
    cls_dsdg = 1 if is_real else 0
    return img_dsdg, confidences_dsdg, cls_dsdg
194
+
195
+
def inference(img, dsdg_thresh):
    """Run both anti-spoofing models on the largest face found in ``img``.

    Args:
        img: Input image array, or ``None``.
        dsdg_thresh: Threshold forwarded to ``dsdg_model_inference``.

    Returns:
        Six values, in the order of the UI ``outputs`` list:
        ``(img_deepix, confidences_deepix, cls_deepix,
        img_dsdg, confidences_dsdg, cls_dsdg)``.
        When no face is found the input image is returned unannotated for
        both models, with empty confidences and ``-1`` as the class.
    """
    face = extract_face(img)
    if face is None:
        # -1 matches the initial value of the hidden Number outputs and keeps
        # upload_to_s3's int(label) conversion from failing on None.
        return img, {}, -1, img, {}, -1

    # The detector reports (x, y, w, h); both models expect corner points.
    x, y, w, h = face
    bbox = (x, y, x + w, y + h)
    img_deepix, confidences_deepix, cls_deepix = deepix_model_inference(img, bbox)
    img_dsdg, confidences_dsdg, cls_dsdg = dsdg_model_inference(img, bbox, dsdg_thresh)
    return img_deepix, confidences_deepix, cls_deepix, img_dsdg, confidences_dsdg, cls_dsdg
208
 
209
 
210
  def upload_to_s3(image_array, app_version, *labels):
 
221
  s3 = boto3.client('s3')
222
 
223
  # Encode labels and app version in image file name
224
+ encoded_labels = '_'.join([str(int(label)) for label in labels])
225
  random_string = str(uuid.uuid4()).split('-')[-1]
226
  image_name = f"{folder}/{app_version}/{encoded_labels}_{random_string}.jpg"
227
 
228
  # Save image as JPEG
229
+ image = Image.fromarray(image_array)
230
  image_bytes = io.BytesIO()
231
  image.save(image_bytes, format='JPEG')
232
  image_bytes.seek(0)
 
244
  with demo:
245
  with gr.Row():
246
  with gr.Column():
247
+ input_img = gr.Image(source='webcam', shape=None, type='numpy', streaming=False)
248
+ dsdg_thresh = gr.Slider(value=DSDG_THRESHOLD, label='DSDG threshold')
249
  btn_run = gr.Button(value="Run")
250
  with gr.Column():
251
  outputs=[
252
  gr.Image(label='DeePixBiS', type='numpy'),
253
  gr.Label(num_top_classes=2, label='DeePixBiS'),
254
+ gr.Number(visible=False, value=-1),
255
  gr.Image(label='DSDG', type='numpy'),
256
+ gr.Label(num_top_classes=2, label='DSDG'),
257
+ gr.Number(visible=False, value=-1)]
258
  with gr.Column():
259
  radio = gr.Radio(
260
+ ["Spoof", "Real", "None"], label="True label", type='index')
261
  flag = gr.Button(value="Flag")
262
  status = gr.Textbox()
263
+ example_block = gr.Examples(examples, [input_img], outputs)
264
 
265
+ btn_run.click(inference, [input_img, dsdg_thresh], outputs)
 
266
  app_version_block = gr.Textbox(value=app_version, visible=False)
267
+ flag.click(
268
+ upload_to_s3,
269
+ [input_img, app_version_block, radio]+[outputs[2], outputs[5]],
270
+ [status], show_progress=True)
271
 
272
 
273
  if __name__ == '__main__':