asdasdasdasd committed
Commit 04dca3b · 1 Parent(s): e1bfa3e

Upload detect_from_videos.py

Files changed (1)
  1. detect_from_videos.py +233 -0
detect_from_videos.py ADDED
@@ -0,0 +1,233 @@
# coding: utf-8
import os
import argparse
from os.path import join
import cv2
import dlib
import torch
import torch.nn as nn
from PIL import Image as pil_image
from tqdm import tqdm
from model_core import Two_Stream_Net
from torchvision import transforms

xception_default_data_transforms_256 = {
    'train': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ]),
    'val': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ]),
    'test': transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ]),
}

def get_boundingbox(face, width, height, scale=1.3, minsize=None):
    """
    Expects a dlib face detection and generates a square bounding box.

    :param face: dlib face rectangle
    :param width: frame width
    :param height: frame height
    :param scale: bounding box size multiplier to get a bigger face region
    :param minsize: minimum bounding box size
    :return: x, y, bounding_box_size in OpenCV form
    """
    x1 = face.left()
    y1 = face.top()
    x2 = face.right()
    y2 = face.bottom()
    size_bb = int(max(x2 - x1, y2 - y1) * scale)
    if minsize:
        if size_bb < minsize:
            size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2

    # Clamp the top-left corner so it stays inside the frame
    x1 = max(int(center_x - size_bb // 2), 0)
    y1 = max(int(center_y - size_bb // 2), 0)
    # Shrink the box if it would extend past the frame borders
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)

    return x1, y1, size_bb


def preprocess_image(image, cuda=True):
    """
    Preprocesses an image so that it can be fed into the network.
    The OpenCV array is converted to a PIL image along the way.

    :param image: numpy image in OpenCV form (i.e., BGR, shape [H, W, 3])
    :param cuda: move the resulting tensor to the GPU
    :return: pytorch tensor of shape [1, 3, image_size, image_size]
    """
    # Convert from BGR to RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # Apply the same preprocessing that was used during training,
    # casting to a PIL image first
    preprocess = xception_default_data_transforms_256['test']
    preprocessed_image = preprocess(pil_image.fromarray(image))
    # Add a batch dimension, as the network expects a batch
    preprocessed_image = preprocessed_image.unsqueeze(0)
    if cuda:
        preprocessed_image = preprocessed_image.cuda()
    return preprocessed_image


def predict_with_model(image, model, post_function=nn.Softmax(dim=1),
                       cuda=True):
    """
    Predicts the label of an input image. Preprocesses the input image and
    casts it to cuda if required.

    :param image: numpy image
    :param model: torch model with a linear layer at the end
    :param post_function: e.g., softmax (not applied here; the argmax below
        is unaffected by it)
    :param cuda: enables cuda; must match how the model was loaded
    :return: predicted class index (0 = fake, 1 = real, as used in the
        drawing code below) and the raw model output
    """
    # Preprocess; preprocess_image already moves the tensor to the GPU
    # when cuda=True
    preprocessed_image = preprocess_image(image, cuda)

    # Model prediction; the first element of the model output holds the
    # class scores
    output = model(preprocessed_image)

    # Argmax over the class scores
    _, prediction = torch.max(output[0], 1)
    prediction = float(prediction.cpu().numpy())

    return int(prediction), output


def test_full_image_network(video_path, model_path, output_path,
                            start_frame=0, end_frame=None, cuda=True):
    """
    Reads a video and evaluates a subset of frames with a detection network
    that takes in a full frame. Outputs are only given if a face is present,
    and the face is highlighted using dlib.

    :param video_path: path to video file
    :param model_path: path to model file (should expect the full-sized image)
    :param output_path: path where the output video is stored
    :param start_frame: first frame to evaluate
    :param end_frame: last frame to evaluate
    :param cuda: enable cuda
    :return: path of the annotated output video
    """
    print('Starting: {}'.format(video_path))

    os.makedirs(output_path, exist_ok=True)

    # Read and write
    reader = cv2.VideoCapture(video_path)

    video_fn = 'output_video.avi'
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    fps = reader.get(cv2.CAP_PROP_FPS)
    num_frames = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
    writer = None

    # Face detector
    face_detector = dlib.get_frontal_face_detector()

    # Load model
    model = Two_Stream_Net()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    if cuda:
        model = model.cuda()

    # Text variables
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    thickness = 2
    font_scale = 1

    frame_num = 0
    assert start_frame < num_frames - 1
    end_frame = end_frame if end_frame else num_frames
    pbar = tqdm(total=end_frame - start_frame)

    while reader.isOpened():
        _, image = reader.read()
        if image is None:
            break
        frame_num += 1

        if frame_num < start_frame:
            continue
        pbar.update(1)

        # Image size
        height, width = image.shape[:2]

        # Init the output writer once the frame size is known
        if writer is None:
            writer = cv2.VideoWriter(join(output_path, video_fn), fourcc, fps,
                                     (height, width)[::-1])

        # Detect faces with dlib
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        faces = face_detector(gray, 1)
        if len(faces):
            # For now, only take the first detected face
            face = faces[0]

            # --- Prediction ---------------------------------------------------
            # Face crop with dlib and bounding box scale enlargement
            x, y, size = get_boundingbox(face, width, height)
            cropped_face = image[y:y + size, x:x + size]

            # Actual prediction using our model
            prediction, output = predict_with_model(cropped_face, model,
                                                    cuda=cuda)
            # ------------------------------------------------------------------

            # Text and bounding box
            x = face.left()
            y = face.top()
            w = face.right() - x
            h = face.bottom() - y
            label = 'fake' if prediction == 0 else 'real'
            color = (0, 255, 0) if prediction == 1 else (0, 0, 255)
            output_list = ['{0:.2f}'.format(float(x)) for x in
                           output[0].detach().cpu().numpy()[0]]
            cv2.putText(image, str(output_list) + '=>' + label, (x, y + h + 30),
                        font_face, font_scale,
                        color, thickness, 2)
            # Draw box over face
            cv2.rectangle(image, (x, y), (x + w, y + h), color, 2)

        if frame_num >= end_frame:
            break

        writer.write(image)
    pbar.close()
    if writer is not None:
        writer.release()
        print('Finished! Output saved under {}'.format(output_path))
    else:
        print('Input video file was empty')
    return join(output_path, video_fn)
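
A note on usage: the file imports argparse but never builds a parser, so the script has no command-line entry point as committed. Below is a minimal sketch of one; the flag names (--video_path, --model_path, --output_path, --start_frame, --end_frame, --cuda) are assumptions, not taken from the original.

if __name__ == '__main__':
    # Hypothetical CLI sketch; flag names are assumed, since the original
    # file imports argparse but never defines a parser.
    parser = argparse.ArgumentParser(
        description='Run the two-stream deepfake detector on a video.')
    parser.add_argument('--video_path', '-i', required=True)
    parser.add_argument('--model_path', '-m', required=True)
    parser.add_argument('--output_path', '-o', default='.')
    parser.add_argument('--start_frame', type=int, default=0)
    parser.add_argument('--end_frame', type=int, default=None)
    parser.add_argument('--cuda', action='store_true')
    args = parser.parse_args()
    test_full_image_network(args.video_path, args.model_path, args.output_path,
                            start_frame=args.start_frame,
                            end_frame=args.end_frame,
                            cuda=args.cuda)

Example invocation, assuming the flags above: python detect_from_videos.py -i input.mp4 -m weights.pth -o results --cuda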