Martlgap commited on
Commit
32d37f5
·
1 Parent(s): ed39f77

initial trial

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit_toggle as tog
3
+ import time
4
+ import numpy as np
5
+ import cv2
6
+ from tools.annotation import draw_mesh, draw_landmarks, draw_bounding_box, draw_text
7
+ from tools.alignment import align_faces
8
+ from tools.identification import load_identification_model, inference, identify
9
+ from tools.utils import show_images, show_faces, rgb
10
+ from tools.detection import load_detection_model, detect_faces
11
+ from tools.webcam import init_webcam
12
+ import logging
13
+
14
+
15
+ # Set logging level to error (To avoid getting spammed by queue warnings etc.)
16
+ logging.basicConfig(level=logging.ERROR)
17
+
18
+
19
+ # Set page layout for streamlit to wide
20
+ st.set_page_config(layout="wide")
21
+
22
+
23
+ # Initialize the Face Detection and Identification Models
24
+ detection_model = load_detection_model(max_faces=2, detection_confidence=0.5, tracking_confidence=0.9)
25
+ identification_model = load_identification_model(name="MobileNet")
26
+
27
+
28
+ # Gallery Processing
29
+ @st.cache_data
30
+ def gallery_processing(gallery_files):
31
+ """Process the gallery images (Complete Face Recognition Pipeline)
32
+
33
+ Args:
34
+ gallery_files (_type_): Files uploaded by the user
35
+
36
+ Returns:
37
+ _type_: Gallery Images, Gallery Embeddings, Gallery Names
38
+ """
39
+ gallery_images, gallery_embs, gallery_names = [], [], []
40
+ if gallery_files is not None:
41
+ for file in gallery_files:
42
+ file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
43
+ img = cv2.cvtColor(
44
+ cv2.imdecode(file_bytes, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB
45
+ )
46
+ gallery_names.append(
47
+ file.name.split(".jpg")[0].split(".png")[0].split(".jpeg")[0]
48
+ )
49
+ detections = detect_faces(img, detection_model)
50
+ aligned_faces = align_faces(img, np.asarray([detections[0]]))
51
+ gallery_images.append(aligned_faces[0])
52
+ gallery_embs.append(inference(aligned_faces, identification_model)[0])
53
+ return gallery_images, gallery_embs, gallery_names
54
+
55
+
56
+ class SideBar:
57
+ """A class to handle the sidebar
58
+ """
59
+ def __init__(self):
60
+ with st.sidebar:
61
+ st.markdown("# Preferences")
62
+ self.on_face_recognition = tog.st_toggle_switch(
63
+ "Face Recognition", key="activate_face_rec", default_value=True, active_color=rgb(255, 75, 75), track_color=rgb(50, 50, 50)
64
+ )
65
+
66
+ st.markdown("---")
67
+
68
+ st.markdown("## Webcam")
69
+ self.resolution = st.selectbox(
70
+ "Webcam Resolution",
71
+ [(1920, 1080), (1280, 720), (640, 360)],
72
+ index=2,
73
+ )
74
+ st.markdown("To change webcam resolution: Please refresh page and select resolution before starting webcam stream.")
75
+
76
+ st.markdown("---")
77
+ st.markdown("## Face Detection")
78
+ self.max_faces = st.number_input(
79
+ "Maximum Number of Faces", value=2, min_value=1
80
+ )
81
+ self.detection_confidence = st.slider(
82
+ "Min Detection Confidence", min_value=0.0, max_value=1.0, value=0.5
83
+ )
84
+ self.tracking_confidence = st.slider(
85
+ "Min Tracking Confidence", min_value=0.0, max_value=1.0, value=0.9
86
+ )
87
+ switch1, switch2 = st.columns(2)
88
+ with switch1:
89
+ self.on_bounding_box = tog.st_toggle_switch(
90
+ "Show Bounding Box", key="show_bounding_box", default_value=True, active_color=rgb(255, 75, 75), track_color=rgb(50, 50, 50)
91
+ )
92
+ with switch2:
93
+ self.on_five_landmarks = tog.st_toggle_switch(
94
+ "Show Five Landmarks", key="show_five_landmarks", default_value=True, active_color=rgb(255, 75, 75),
95
+ track_color=rgb(50, 50, 50)
96
+ )
97
+ switch3, switch4 = st.columns(2)
98
+ with switch3:
99
+ self.on_mesh = tog.st_toggle_switch(
100
+ "Show Mesh", key="show_mesh", default_value=True, active_color=rgb(255, 75, 75),
101
+ track_color=rgb(50, 50, 50)
102
+ )
103
+ with switch4:
104
+ self.on_text = tog.st_toggle_switch(
105
+ "Show Text", key="show_text", default_value=True, active_color=rgb(255, 75, 75),
106
+ track_color=rgb(50, 50, 50)
107
+ )
108
+ st.markdown("---")
109
+
110
+ st.markdown("## Face Recognition")
111
+ self.similarity_threshold = st.slider(
112
+ "Similarity Threshold", min_value=0.0, max_value=2.0, value=0.67
113
+ )
114
+
115
+ self.on_show_faces = tog.st_toggle_switch(
116
+ "Show Recognized Faces", key="show_recognized_faces", default_value=True, active_color=rgb(255, 75, 75), track_color=rgb(50, 50, 50)
117
+ )
118
+
119
+ self.model_name = st.selectbox(
120
+ "Model",
121
+ ["MobileNet", "ResNet"],
122
+ index=0,
123
+ )
124
+ st.markdown("---")
125
+
126
+ st.markdown("## Gallery")
127
+ self.uploaded_files = st.file_uploader(
128
+ "Choose multiple images to upload", accept_multiple_files=True
129
+ )
130
+
131
+ self.gallery_images, self.gallery_embs, self.gallery_names= gallery_processing(self.uploaded_files)
132
+
133
+ st.markdown("**Gallery Faces**")
134
+ show_images(self.gallery_images, self.gallery_names, 3)
135
+ st.markdown("---")
136
+
137
+
138
+ class KPI:
139
+ """Class for displaying KPIs in a row
140
+ Args:
141
+ keys (list): List of KPI names
142
+ """
143
+ def __init__(self, keys):
144
+ self.kpi_texts = []
145
+ row = st.columns(len(keys))
146
+ for kpi, key in zip(row, keys):
147
+ with kpi:
148
+ item_row = st.columns(2)
149
+ item_row[0].markdown(f"**{key}**:")
150
+ self.kpi_texts.append(item_row[1].markdown("-"))
151
+
152
+ def update_kpi(self, kpi_values):
153
+ for kpi_text, kpi_value in zip(self.kpi_texts, kpi_values):
154
+ kpi_text.write(
155
+ f"<h5 style='text-align: center; color: red;'>{kpi_value:.2f}</h5>"
156
+ if isinstance(kpi_value, float)
157
+ else f"<h5 style='text-align: center; color: red;'>{kpi_value}</h5>",
158
+ unsafe_allow_html=True,
159
+ )
160
+
161
+ # -----------------------------------------------------------------------------------------------
162
+ # Streamlit App
163
+ st.title("FaceID App Demonstration")
164
+
165
+ # Sidebar
166
+ sb = SideBar()
167
+
168
+ # Get Access to Webcam
169
+ webcam = init_webcam(width=sb.resolution[0])
170
+
171
+ # KPI Section
172
+ st.markdown("**Stats**")
173
+ kpi = KPI([
174
+ "**FrameRate**",
175
+ "**Detected Faces**",
176
+ "**Image Dims**",
177
+ "**Detection [ms]**",
178
+ "**Normalization [ms]**",
179
+ "**Inference [ms]**",
180
+ "**Recognition [ms]**",
181
+ "**Annotations [ms]**",
182
+ "**Show Faces [ms]**",
183
+ ])
184
+ st.markdown("---")
185
+
186
+ # Live Stream Display
187
+ stream_display = st.empty()
188
+ st.markdown("---")
189
+
190
+ # Display Detected Faces
191
+ st.markdown("**Detected Faces**")
192
+ face_window = st.empty()
193
+ st.markdown("---")
194
+
195
+
196
+ if webcam:
197
+ prevTime = 0
198
+ while True:
199
+ # Init times to "-" to show something if face recognition is turned off
200
+ time_detection = "-"
201
+ time_alignment = "-"
202
+ time_inference = "-"
203
+ time_identification = "-"
204
+ time_annotations = "-"
205
+ time_show_faces = "-"
206
+
207
+ try:
208
+ # Get Frame from Webcam
209
+ frame = webcam.get_frame(timeout=1)
210
+
211
+ # Convert to OpenCV Image
212
+ frame = frame.to_ndarray(format="rgb24")
213
+ except:
214
+ continue
215
+
216
+ # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
217
+ # FACE RECOGNITION PIPELINE
218
+ if sb.on_face_recognition:
219
+ # FACE DETECTION ---------------------------------------------------------
220
+ start_time = time.time()
221
+ detections = detect_faces(frame, detection_model)
222
+ time_detection = (time.time() - start_time) * 1000
223
+
224
+ # FACE ALIGNMENT ---------------------------------------------------------
225
+ start_time = time.time()
226
+ aligned_faces = align_faces(frame, detections)
227
+ time_alignment = (time.time() - start_time) * 1000
228
+
229
+ # INFERENCE --------------------------------------------------------------
230
+ start_time = time.time()
231
+ if len(sb.gallery_embs) > 0:
232
+ faces_embs = inference(aligned_faces, identification_model)
233
+ else:
234
+ faces_embs = []
235
+ time_inference = (time.time() - start_time) * 1000
236
+
237
+ # FACE IDENTIFCATION -----------------------------------------------------
238
+ start_time = time.time()
239
+ if len(faces_embs) > 0 and len(sb.gallery_embs) > 0:
240
+ ident_names, ident_dists, ident_imgs = identify(faces_embs, sb.gallery_embs, sb.gallery_names, sb.gallery_images, thresh=sb.similarity_threshold)
241
+ else:
242
+ ident_names, ident_dists, ident_imgs = [], [], []
243
+ time_identification = (time.time() - start_time) * 1000
244
+
245
+ # ANNOTATIONS ------------------------------------------------------------
246
+ start_time = time.time()
247
+ frame = cv2.resize(frame, (1920, 1080)) # to make annotation in HD
248
+ frame.flags.writeable = True # (hack to make annotations faster)
249
+ if sb.on_mesh:
250
+ frame = draw_mesh(frame, detections)
251
+ if sb.on_five_landmarks:
252
+ frame = draw_landmarks(frame, detections)
253
+ if sb.on_bounding_box:
254
+ frame = draw_bounding_box(frame, detections, ident_names)
255
+ if sb.on_text:
256
+ frame = draw_text(frame, detections, ident_names)
257
+ time_annotations = (time.time() - start_time) * 1000
258
+
259
+ # DISPLAY DETECTED FACES -------------------------------------------------
260
+ start_time = time.time()
261
+ if sb.on_show_faces:
262
+ show_faces(
263
+ aligned_faces,
264
+ ident_names,
265
+ ident_dists,
266
+ ident_imgs,
267
+ num_cols=3,
268
+ channels="RGB",
269
+ display=face_window,
270
+ )
271
+ time_show_faces = (time.time() - start_time) * 1000
272
+ # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
273
+
274
+
275
+
276
+ # DISPLAY THE LIVE STREAM --------------------------------------------------
277
+ stream_display.image(
278
+ frame, channels="RGB", caption="Live-Stream", use_column_width=True
279
+ )
280
+
281
+ # CALCULATE FPS -----------------------------------------------------------
282
+ currTime = time.time()
283
+ fps = 1 / (currTime - prevTime)
284
+ prevTime = currTime
285
+
286
+ # UPDATE KPIS -------------------------------------------------------------
287
+ kpi.update_kpi(
288
+ [
289
+ fps,
290
+ len(detections),
291
+ sb.resolution,
292
+ time_detection,
293
+ time_alignment,
294
+ time_inference,
295
+ time_identification,
296
+ time_annotations,
297
+ time_show_faces,
298
+ ]
299
+ )
models/mobileNet.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c19b789f661caa8da735566490bfd8895beffb2a1ec97a56b126f0539991aa6
3
+ size 8210384
models/resNet.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4d8b0194957a3ad766135505fc70a91343660151a8103bbb6c3b8ac34dbb4e2
3
+ size 40946048
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ scikit-image
3
+ scikit-learn
4
+ mediapipe
5
+ opencv-python-headless
6
+ watchdog
7
+ streamlit-webrtc
8
+ matplotlib
9
+ streamlit-toggle-switch
10
+ tflite-runtime
tools/__init__.py ADDED
File without changes
tools/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (155 Bytes). View file
 
tools/__pycache__/alignment.cpython-38.pyc ADDED
Binary file (1.38 kB). View file
 
tools/__pycache__/annotation.cpython-38.pyc ADDED
Binary file (2.83 kB). View file
 
tools/__pycache__/detection.cpython-38.pyc ADDED
Binary file (1.5 kB). View file
 
tools/__pycache__/identification.cpython-38.pyc ADDED
Binary file (1.68 kB). View file
 
tools/__pycache__/normalization.cpython-38.pyc ADDED
Binary file (1.64 kB). View file
 
tools/__pycache__/recognition.cpython-38.pyc ADDED
Binary file (2.52 kB). View file
 
tools/__pycache__/utils.cpython-38.pyc ADDED
Binary file (1.55 kB). View file
 
tools/__pycache__/webcam.cpython-38.pyc ADDED
Binary file (686 Bytes). View file
 
tools/alignment.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from skimage.transform import SimilarityTransform
4
+
5
+
6
+ FIVE_LANDMARKS = [470, 475, 1, 57, 287]
7
+
8
+
9
+ def align(img, landmarks, target_size=(112, 112)):
10
+ dst = np.array(
11
+ [
12
+ [
13
+ landmarks.landmark[i].x * img.shape[1],
14
+ landmarks.landmark[i].y * img.shape[0],
15
+ ]
16
+ for i in FIVE_LANDMARKS
17
+ ],
18
+ )
19
+
20
+ src = np.array(
21
+ [
22
+ [38.2946, 51.6963],
23
+ [73.5318, 51.5014],
24
+ [56.0252, 71.7366],
25
+ [41.5493, 92.3655],
26
+ [70.7299, 92.2041],
27
+ ],
28
+ dtype=np.float32,
29
+ )
30
+ tform = SimilarityTransform()
31
+ tform.estimate(dst, src)
32
+ tmatrix = tform.params[0:2, :]
33
+ return cv2.warpAffine(img, tmatrix, target_size, borderValue=0.0)
34
+
35
+
36
+
37
+ def align_faces(img, detections):
38
+ aligned_faces = [align(img, detection.multi_face_landmarks) for detection in detections]
39
+ return aligned_faces
tools/annotation.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import mediapipe as mp
3
+ import streamlit as st
4
+
5
+
6
+ FIVE_LANDMARKS = [470, 475, 1, 57, 287]
7
+ FACE_CONNECTIONS = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
8
+
9
+
10
+
11
+ def draw_bounding_box(img, detections, ident_names, margin=10):
12
+ # Draw the bounding box on the original frame
13
+ for detection, name in zip(detections, ident_names):
14
+
15
+ color = (255, 0, 0) if name == "Unknown" else (0, 255, 0)
16
+
17
+ x_coords = [
18
+ landmark.x * img.shape[1] for landmark in detection.multi_face_landmarks.landmark
19
+ ]
20
+ y_coords = [
21
+ landmark.y * img.shape[0] for landmark in detection.multi_face_landmarks.landmark
22
+ ]
23
+
24
+ x_min, x_max = int(min(x_coords) - margin), int(max(x_coords) + margin)
25
+ y_min, y_max = int(min(y_coords) - margin), int(max(y_coords) + margin)
26
+
27
+ cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color, 2)
28
+ cv2.rectangle(img, (x_min, y_min - img.shape[0] // 25), (x_max, y_min), color, -1)
29
+
30
+ return img
31
+
32
+
33
+ def draw_text(
34
+ img,
35
+ detections,
36
+ ident_names,
37
+ margin=10,
38
+ font_scale=1,
39
+ font_color=(0, 0, 0),
40
+ font=cv2.FONT_HERSHEY_SIMPLEX,
41
+ ):
42
+
43
+ font_scale = img.shape[0] / 1000
44
+ for detection, name in zip(detections, ident_names):
45
+ x_coords = [
46
+ landmark.x * img.shape[1] for landmark in detection.multi_face_landmarks.landmark
47
+ ]
48
+ y_coords = [
49
+ landmark.y * img.shape[0] for landmark in detection.multi_face_landmarks.landmark
50
+ ]
51
+
52
+ x_min = int(min(x_coords) - margin)
53
+ y_min = int(min(y_coords) - margin)
54
+
55
+ cv2.putText(
56
+ img,
57
+ name,
58
+ (x_min + img.shape[0] // 400, y_min - img.shape[0] // 100),
59
+ font,
60
+ font_scale,
61
+ font_color,
62
+ 2,
63
+ )
64
+
65
+ return img
66
+
67
+
68
+ def draw_mesh(img, detections):
69
+ for detection in detections:
70
+ # Draw the connections
71
+ for connection in FACE_CONNECTIONS:
72
+ cv2.line(
73
+ img,
74
+ (
75
+ int(detection.multi_face_landmarks.landmark[connection[0]].x * img.shape[1]),
76
+ int(detection.multi_face_landmarks.landmark[connection[0]].y * img.shape[0]),
77
+ ),
78
+ (
79
+ int(detection.multi_face_landmarks.landmark[connection[1]].x * img.shape[1]),
80
+ int(detection.multi_face_landmarks.landmark[connection[1]].y * img.shape[0]),
81
+ ),
82
+ (255, 255, 255),
83
+ 1,
84
+ )
85
+
86
+ # Draw the landmarks
87
+ for points in detection.multi_face_landmarks.landmark:
88
+ cv2.circle(
89
+ img,
90
+ (
91
+ int(points.x * img.shape[1]),
92
+ int(points.y * img.shape[0]),
93
+ ),
94
+ 1,
95
+ (0, 255, 0),
96
+ -1,
97
+ )
98
+ return img
99
+
100
+
101
+ def draw_landmarks(img, detections):
102
+ # Draw the face landmarks on the original frame
103
+ for points in FIVE_LANDMARKS:
104
+ for detection in detections:
105
+ cv2.circle(
106
+ img,
107
+ (
108
+ int(
109
+ detection.multi_face_landmarks.landmark[points].x
110
+ * img.shape[1]
111
+ ),
112
+ int(
113
+ detection.multi_face_landmarks.landmark[points].y
114
+ * img.shape[0]
115
+ ),
116
+ ),
117
+ 5,
118
+ (0, 0, 255),
119
+ -1,
120
+ )
121
+ return img
tools/detection.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mediapipe as mp
2
+ import streamlit as st
3
+
4
+
5
+ class Detection:
6
+ multi_face_bboxes = []
7
+ multi_face_landmarks = []
8
+
9
+
10
+ #@st.cache_resource
11
+ def load_detection_model(max_faces=2, detection_confidence=0.5, tracking_confidence=0.5):
12
+ model = mp.solutions.face_mesh.FaceMesh(
13
+ refine_landmarks=True,
14
+ min_detection_confidence=detection_confidence,
15
+ min_tracking_confidence=tracking_confidence,
16
+ max_num_faces=max_faces,
17
+ )
18
+ return model
19
+
20
+
21
+ def detect_faces(frame, model):
22
+
23
+ # Process the frame with MediaPipe Face Mesh
24
+ results = model.process(frame)
25
+
26
+ # Get the Bounding Boxes from the detected faces
27
+ detections = []
28
+ if results.multi_face_landmarks:
29
+ for landmarks in results.multi_face_landmarks:
30
+ x_coords = [
31
+ landmark.x * frame.shape[1] for landmark in landmarks.landmark
32
+ ]
33
+ y_coords = [
34
+ landmark.y * frame.shape[0] for landmark in landmarks.landmark
35
+ ]
36
+
37
+ x_min, x_max = int(min(x_coords)), int(max(x_coords))
38
+ y_min, y_max = int(min(y_coords)), int(max(y_coords))
39
+
40
+ detection = Detection()
41
+ detection.multi_face_bboxes=[x_min, y_min, x_max, y_max]
42
+ detection.multi_face_landmarks=landmarks
43
+ detections.append(detection)
44
+ return detections
tools/identification.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tflite_runtime.interpreter as tflite
3
+ from sklearn.metrics.pairwise import cosine_distances
4
+ import streamlit as st
5
+ import time
6
+
7
+
8
+ MODEL_PATHS = {
9
+ "MobileNet": "./models/mobileNet.tflite",
10
+ "ResNet": "./models/resNet.tflite",
11
+ }
12
+
13
+
14
+ #@st.cache_resource
15
+ def load_identification_model(name="MobileNet"):
16
+ model = tflite.Interpreter(model_path=MODEL_PATHS[name])
17
+ return model
18
+
19
+
20
+ def inference(imgs, model):
21
+ if len(imgs) > 0:
22
+ imgs = np.asarray(imgs).astype(np.float32) / 255
23
+ model.resize_tensor_input(model.get_input_details()[0]["index"], imgs.shape)
24
+ model.allocate_tensors()
25
+ model.set_tensor(model.get_input_details()[0]["index"], imgs)
26
+ model.invoke()
27
+ embs = [model.get_tensor(elem["index"]) for elem in model.get_output_details()]
28
+ return embs[0]
29
+ else:
30
+ return []
31
+
32
+
33
+ def identify(embs_src, embs_gal, labels_gal, imgs_gal, thresh=None):
34
+ all_dists = cosine_distances(embs_src, embs_gal)
35
+ ident_names, ident_dists, ident_imgs = [], [], []
36
+ for dists in all_dists:
37
+ idx_min = np.argmin(dists)
38
+ if thresh and dists[idx_min] > thresh:
39
+ dist = dists[idx_min]
40
+ pred = None
41
+ else:
42
+ dist = dists[idx_min]
43
+ pred = idx_min
44
+ ident_names.append(labels_gal[pred] if pred is not None else "Unknown")
45
+ ident_dists.append(dist)
46
+ ident_imgs.append(imgs_gal[pred] if pred is not None else None)
47
+ return ident_names, ident_dists, ident_imgs
tools/utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+
4
+ def rgb(r, g, b):
5
+ return '#{:02x}{:02x}{:02x}'.format(r, g, b)
6
+
7
+
8
+ def show_images(images, names, num_cols, channels="RGB"):
9
+ num_images = len(images)
10
+
11
+ # Calculate the number of rows and columns
12
+ num_rows = -(
13
+ -num_images // num_cols
14
+ ) # This also handles the case when num_images is not a multiple of num_cols
15
+
16
+ for row in range(num_rows):
17
+ # Create the columns
18
+ cols = st.sidebar.columns(num_cols)
19
+
20
+ for i, col in enumerate(cols):
21
+ idx = row * num_cols + i
22
+
23
+ if idx < num_images:
24
+ img = images[idx]
25
+ if len(names) == 0:
26
+ names = ["Unknown"] * len(images)
27
+ name = names[idx]
28
+ col.image(img, caption=name, channels=channels, width=112)
29
+
30
+
31
+ def show_faces(images, names, distances, gal_images, num_cols, channels="RGB", display=st):
32
+ if len(images) == 0 or len(names) == 0:
33
+ display.write("No faces detected, or gallery empty!")
34
+ return
35
+ # Calculate the number of rows and columns
36
+ num_rows = -(
37
+ -len(images) // num_cols
38
+ ) # This also handles the case when num_images is not a multiple of num_cols
39
+
40
+ for row in range(num_rows):
41
+ # Create the columns
42
+ cols = display.columns(num_cols)
43
+
44
+ for i, col in enumerate(cols):
45
+ idx = row * num_cols + i
46
+
47
+ if idx < len(images):
48
+ img = images[idx]
49
+ name = names[idx]
50
+ dist = distances[idx]
51
+ col.image(img, channels=channels, width=112)
52
+
53
+ if gal_images[idx] is not None:
54
+ col.text(" ⬍ matching ⬍")
55
+ col.image(gal_images[idx], caption=name, channels=channels, width=112)
56
+ else:
57
+ col.markdown("")
58
+ col.write("No match found")
59
+ col.markdown(
60
+ f"**Distance: {dist:.4f}**" if dist else f"**Distance: -**"
61
+ )
62
+ else:
63
+ col.empty()
64
+ col.markdown("")
65
+ col.empty()
66
+ col.markdown("")
tools/webcam.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode
3
+
4
+
5
+ @st.cache_resource(experimental_allow_widgets=True)
6
+ def init_webcam(width=680):
7
+ ctx = webrtc_streamer(
8
+ key="FaceIDAppDemo",
9
+ mode=WebRtcMode.SENDONLY,
10
+ media_stream_constraints={
11
+ "video": {
12
+ "width": {
13
+ "min": width,
14
+ "ideal": width,
15
+ "max": width,
16
+ },
17
+ },
18
+ "audio": False,
19
+ },
20
+ video_receiver_size=1,
21
+ async_processing=True,
22
+ )
23
+ return ctx.video_receiver