Martlgap committed on
Commit
c59c6ab
·
1 Parent(s): 7c2803a

added missing file

Browse files
Files changed (2) hide show
  1. app.py +186 -167
  2. tools/gallery.py +37 -0
app.py CHANGED
@@ -3,140 +3,102 @@ import time
3
  from typing import List
4
  from streamlit_webrtc import webrtc_streamer, WebRtcMode
5
  import logging
6
- import mediapipe as mp
7
- import tflite_runtime.interpreter as tflite
8
  import av
9
- import numpy as np
10
  import queue
11
  from streamlit_toggle import st_toggle_switch
12
  import pandas as pd
13
- from tools.nametypes import Stats, Detection
14
- from pathlib import Path
15
- from tools.utils import get_ice_servers, download_file, display_match, rgb, format_dflist
16
- from tools.face_recognition import (
17
- detect_faces,
18
- align_faces,
19
- inference,
20
- draw_detections,
21
- recognize_faces,
22
- process_gallery,
23
- )
24
 
25
  # Set logging level to error (To avoid getting spammed by queue warnings etc.)
26
  logger = logging.getLogger(__name__)
27
  logging.basicConfig(level=logging.ERROR)
28
 
29
- ROOT = Path(__file__).parent
30
-
31
- MODEL_URL = (
32
- "https://github.com/Martlgap/FaceIDLight/releases/download/v.0.1/mobileNet.tflite"
33
- )
34
- MODEL_LOCAL_PATH = ROOT / "./models/mobileNet.tflite"
35
-
36
- DETECTION_CONFIDENCE = 0.5
37
- TRACKING_CONFIDENCE = 0.5
38
- MAX_FACES = 2
39
 
40
  # Set page layout for streamlit to wide
41
- st.set_page_config(
42
- layout="wide", page_title="FaceID App Demo", page_icon=":sunglasses:"
43
- )
44
  with st.sidebar:
45
- st.markdown("# Preferences")
46
  face_rec_on = st_toggle_switch(
47
- "Face Recognition",
48
  key="activate_face_rec",
49
  default_value=True,
50
  active_color=rgb(255, 75, 75),
51
  track_color=rgb(50, 50, 50),
 
52
  )
53
 
54
- st.markdown("## Webcam & Stream")
55
- resolution = st.selectbox(
56
- "Webcam Resolution",
57
- [(1920, 1080), (1280, 720), (640, 360)],
58
- index=2,
59
- )
60
- st.markdown("Note: To change the resolution, you have to restart the stream.")
 
61
 
62
- ice_server = st.selectbox("ICE Server", ["twilio", "metered"], index=0)
63
- st.markdown(
64
- "Note: metered is a free server with limited bandwidth, and can take a while to connect. Twilio is a paid service and is payed by me, so please don't abuse it."
65
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- st.markdown("## Face Detection")
68
- max_faces = st.number_input("Maximum Number of Faces", value=2, min_value=1)
69
- detection_confidence = st.slider(
70
- "Min Detection Confidence", min_value=0.0, max_value=1.0, value=0.5
71
- )
72
- tracking_confidence = st.slider(
73
- "Min Tracking Confidence", min_value=0.0, max_value=1.0, value=0.9
74
- )
75
- st.markdown("## Face Recognition")
76
- similarity_threshold = st.slider(
77
- "Similarity Threshold", min_value=0.0, max_value=2.0, value=0.67
78
  )
79
- st.markdown(
80
- "This sets a maximum distance for the cosine similarity between the embeddings of the detected face and the gallery images. If the distance is below the threshold, the face is recognized as the gallery image with the lowest distance. If the distance is above the threshold, the face is not recognized."
81
- )
82
-
83
- download_file(
84
- MODEL_URL,
85
- MODEL_LOCAL_PATH,
86
- file_hash="6c19b789f661caa8da735566490bfd8895beffb2a1ec97a56b126f0539991aa6",
87
- )
88
 
89
- # Session-specific caching of the face recognition model
90
- cache_key = "face_id_model"
91
- if cache_key in st.session_state:
92
- face_recognition_model = st.session_state[cache_key]
93
- else:
94
- face_recognition_model = tflite.Interpreter(model_path=MODEL_LOCAL_PATH.as_posix())
95
- st.session_state[cache_key] = face_recognition_model
96
 
97
- # Session-specific caching of the face recognition model
98
- cache_key = "face_id_model_gal"
99
- if cache_key in st.session_state:
100
- face_recognition_model_gal = st.session_state[cache_key]
101
- else:
102
- face_recognition_model_gal = tflite.Interpreter(
103
- model_path=MODEL_LOCAL_PATH.as_posix()
104
- )
105
- st.session_state[cache_key] = face_recognition_model_gal
106
 
107
- # Session-specific caching of the face detection model
108
- cache_key = "face_detection_model"
109
- if cache_key in st.session_state:
110
- face_detection_model = st.session_state[cache_key]
111
- else:
112
- face_detection_model = mp.solutions.face_mesh.FaceMesh(
113
- refine_landmarks=True,
114
- min_detection_confidence=detection_confidence,
115
- min_tracking_confidence=tracking_confidence,
116
- max_num_faces=max_faces,
117
- )
118
- st.session_state[cache_key] = face_detection_model
119
 
120
- # Session-specific caching of the face detection model
121
- cache_key = "face_detection_model_gal"
122
- if cache_key in st.session_state:
123
- face_detection_model_gal = st.session_state[cache_key]
124
- else:
125
- face_detection_model_gal = mp.solutions.face_mesh.FaceMesh(
126
- refine_landmarks=True,
127
- min_detection_confidence=detection_confidence,
128
- min_tracking_confidence=tracking_confidence,
129
- max_num_faces=max_faces,
130
- )
131
- st.session_state[cache_key] = face_detection_model_gal
132
 
133
- stats_queue: "queue.Queue[Stats]" = queue.Queue()
134
- detections_queue: "queue.Queue[List[Detection]]" = queue.Queue()
135
 
136
 
137
  def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
138
  # Initialize detections
139
- detections = []
140
 
141
  # Initialize stats
142
  stats = Stats()
@@ -154,29 +116,24 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
154
  if face_rec_on:
155
  # Run face detection
156
  start = time.time()
157
- detections = detect_faces(frame, face_detection_model)
158
  stats = stats._replace(num_faces=len(detections) if detections else 0)
159
  stats = stats._replace(detection=(time.time() - start) * 1000)
160
 
161
- # Run face alignment
162
- start = time.time()
163
- detections = align_faces(frame, detections)
164
- stats = stats._replace(alignment=(time.time() - start) * 1000)
165
-
166
- # Run inference
167
- start = time.time()
168
- detections = inference(detections, face_recognition_model)
169
- stats = stats._replace(inference=(time.time() - start) * 1000)
170
-
171
  # Run face recognition
172
  start = time.time()
173
- detections = recognize_faces(detections, gallery, similarity_threshold)
174
  stats = stats._replace(recognition=(time.time() - start) * 1000)
175
 
176
- # Draw detections
 
 
 
 
 
177
  start = time.time()
178
- frame = draw_detections(frame, detections)
179
- stats = stats._replace(drawing=(time.time() - start) * 1000)
180
 
181
  # Convert frame back to av.VideoFrame
182
  frame = av.VideoFrame.from_ndarray(frame, format="rgb24")
@@ -185,30 +142,16 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
185
  stats = stats._replace(fps=1 / (time.time() - frame_start))
186
 
187
  # Send data to other thread
188
- detections_queue.put_nowait(detections)
189
- stats_queue.put_nowait(stats)
190
 
191
  return frame
192
 
193
 
194
  # Streamlit app
195
- st.title("FaceID App Demonstration")
196
 
197
- st.sidebar.markdown("**Gallery**")
198
- gallery = st.sidebar.file_uploader(
199
- "Upload images to gallery", type=["png", "jpg", "jpeg"], accept_multiple_files=True
200
- )
201
- if gallery:
202
- gallery = process_gallery(gallery, face_detection_model_gal, face_recognition_model_gal)
203
- st.sidebar.markdown("**Gallery Images**")
204
- st.sidebar.image(
205
- [identity.image for identity in gallery],
206
- caption=[identity.name for identity in gallery],
207
- width=112,
208
- )
209
-
210
- st.markdown("**Stats**")
211
- stats = st.empty()
212
 
213
  ctx = webrtc_streamer(
214
  key="FaceIDAppDemo",
@@ -233,41 +176,117 @@ ctx = webrtc_streamer(
233
  async_processing=True,
234
  )
235
 
236
- st.markdown("**Identified Faces**")
237
- identified_faces = st.empty()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
- st.markdown("**Detections**")
240
- detections = st.empty()
241
 
242
  # Display Live Stats
243
  if ctx.state.playing:
244
  while True:
245
- # Get stats
246
- stats_data = stats_queue.get()
247
- stats_dataframe = pd.DataFrame([stats_data])
248
- stats_dataframe.style.format(thousands=" ", precision=2)
249
-
250
- # Write stats to streamlit
251
- stats.dataframe(stats_dataframe)
252
-
253
- # Get detections
254
- detections_data = detections_queue.get()
255
- detections_dataframe = (
256
- pd.DataFrame(detections_data)
257
- .drop(columns=["face", "face_match"], errors="ignore")
258
- .applymap(lambda x: (format_dflist(x)))
259
- )
260
 
261
- # Write detections to streamlit
262
- detections.dataframe(detections_dataframe)
263
-
264
- # Write identified faces to streamlit
265
- identified_faces.image(
266
- [display_match(d) for d in detections_data if d.name is not None],
267
- caption=[
268
- d.name + f"({d.distance:2f})"
269
- for d in detections_data
270
- if d.name is not None
271
- ],
272
- width=112,
273
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from typing import List
4
  from streamlit_webrtc import webrtc_streamer, WebRtcMode
5
  import logging
 
 
6
  import av
 
7
  import queue
8
  from streamlit_toggle import st_toggle_switch
9
  import pandas as pd
10
+ from tools.nametypes import Stats, Detection, Identity, Match
11
+ from tools.utils import get_ice_servers, rgb, format_dflist
12
+ from tools.face_detection import FaceDetection
13
+ from tools.face_recognition import FaceRecognition
14
+ from tools.annotation import Annotation
15
+ from tools.gallery import init_gallery
16
+ from tools.pca import pca
17
+
 
 
 
18
 
19
  # Set logging level to error (To avoid getting spammed by queue warnings etc.)
20
  logger = logging.getLogger(__name__)
21
  logging.basicConfig(level=logging.ERROR)
22
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Set page layout for streamlit to wide
25
+ st.set_page_config(layout="wide", page_title="FaceID App Demo", page_icon=":sunglasses:")
 
 
26
  with st.sidebar:
27
+ st.markdown("# Settings")
28
  face_rec_on = st_toggle_switch(
29
+ "Live Face Recognition",
30
  key="activate_face_rec",
31
  default_value=True,
32
  active_color=rgb(255, 75, 75),
33
  track_color=rgb(50, 50, 50),
34
+ label_after=True,
35
  )
36
 
37
+ with st.expander("Advanced Settings", expanded=False):
38
+ st.markdown("## Webcam & Stream")
39
+ resolution = st.selectbox(
40
+ "Webcam Resolution",
41
+ [(1920, 1080), (1280, 720), (640, 360)],
42
+ index=2,
43
+ )
44
+ st.markdown("Note: To change the resolution, you have to restart the stream.")
45
 
46
+ ice_server = st.selectbox("ICE Server", ["twilio", "metered"], index=1)
47
+ st.markdown(
48
+ "Note: metered is a free server with limited bandwidth, and can take a while to connect. Twilio is a paid service and is payed by me, so please don't abuse it."
49
+ )
50
+ st.markdown("---")
51
+ st.markdown("## Face Detection")
52
+ detection_min_face_size = st.slider("Min Face Size", min_value=5, max_value=120, value=40)
53
+ detection_scale_factor = st.slider("Scale Factor", min_value=0.1, max_value=1.0, value=0.7)
54
+ detection_confidence = st.slider("Min Detection Confidence", min_value=0.5, max_value=1.0, value=0.9)
55
+ st.markdown("---")
56
+ st.markdown("## Face Recognition")
57
+ similarity_threshold = st.slider("Similarity Threshold", min_value=0.0, max_value=2.0, value=0.67)
58
+ st.markdown(
59
+ "This sets a maximum distance for the cosine similarity between the embeddings of the detected face and the gallery images. If the distance is below the threshold, the face is recognized as the gallery image with the lowest distance. If the distance is above the threshold, the face is not recognized."
60
+ )
61
+ model_name = st.selectbox("Model", ["mobileNet", "resNet"], index=0)
62
+ st.markdown(
63
+ "Note: The mobileNet model is smaller and faster, but less accurate. The resNet50 model is bigger and slower, but more accurate."
64
+ )
65
 
66
+ st.markdown("# Face Gallery")
67
+ files = st.sidebar.file_uploader(
68
+ "Upload images to gallery",
69
+ type=["png", "jpg", "jpeg"],
70
+ accept_multiple_files=True,
71
+ label_visibility="collapsed",
 
 
 
 
 
72
  )
 
 
 
 
 
 
 
 
 
73
 
74
+ with st.expander("Uploaded Images", expanded=True):
75
+ if files:
76
+ st.image(files, width=112, caption=files)
77
+ else:
78
+ st.info("No images uploaded yet.")
 
 
79
 
 
 
 
 
 
 
 
 
 
80
 
81
+ gallery = init_gallery(
82
+ files,
83
+ min_detections_conf=detection_confidence,
84
+ min_similarity=similarity_threshold,
85
+ model_name=model_name,
86
+ )
 
 
 
 
 
 
87
 
88
+ face_detector = FaceDetection(
89
+ min_detections_conf=detection_confidence,
90
+ min_face_size=detection_min_face_size,
91
+ scale_factor=detection_scale_factor,
92
+ )
93
+ face_recognizer = FaceRecognition(model_name=model_name, min_similarity=similarity_threshold)
94
+ annotator = Annotation()
 
 
 
 
 
95
 
96
+ transfer_queue: "queue.Queue[Stats, List[Detection], List[Identity], List[Match]]" = queue.Queue()
 
97
 
98
 
99
  def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
100
  # Initialize detections
101
+ detections, identities, matches = [], [], []
102
 
103
  # Initialize stats
104
  stats = Stats()
 
116
  if face_rec_on:
117
  # Run face detection
118
  start = time.time()
119
+ frame, detections = face_detector(frame)
120
  stats = stats._replace(num_faces=len(detections) if detections else 0)
121
  stats = stats._replace(detection=(time.time() - start) * 1000)
122
 
 
 
 
 
 
 
 
 
 
 
123
  # Run face recognition
124
  start = time.time()
125
+ identities = face_recognizer(frame, detections)
126
  stats = stats._replace(recognition=(time.time() - start) * 1000)
127
 
128
+ # Do matching
129
+ start = time.time()
130
+ matches = face_recognizer.find_matches(identities, gallery)
131
+ stats = stats._replace(matching=(time.time() - start) * 1000)
132
+
133
+ # Draw annotations
134
  start = time.time()
135
+ frame = annotator(frame, detections, identities, matches, gallery)
136
+ stats = stats._replace(annotation=(time.time() - start) * 1000)
137
 
138
  # Convert frame back to av.VideoFrame
139
  frame = av.VideoFrame.from_ndarray(frame, format="rgb24")
 
142
  stats = stats._replace(fps=1 / (time.time() - frame_start))
143
 
144
  # Send data to other thread
145
+ transfer_queue.put_nowait([stats, detections, identities, matches])
 
146
 
147
  return frame
148
 
149
 
150
  # Streamlit app
151
+ st.title("Live Webcam Face Recognition")
152
 
153
+ st.markdown("**Stream Stats**")
154
+ disp_stats = st.info("No streaming statistics yet, please start the stream.")
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  ctx = webrtc_streamer(
157
  key="FaceIDAppDemo",
 
176
  async_processing=True,
177
  )
178
 
179
+ tab_recognition, tab_metrics, tab_pca = st.tabs(["Recognized Identities", "Recognition Metrics", "Live PCAs"])
180
+
181
+
182
+ with tab_recognition:
183
+ # Display Gallery and Recognized Identities
184
+ col1, col2 = st.columns(2)
185
+ col1.markdown("**Gallery Identities**")
186
+ disp_identities_gal = col1.info("No gallery images uploaded yet ...")
187
+ col2.markdown("**Recognized Identities**")
188
+ disp_identities_rec = col2.info("No recognized identities yet ...")
189
+
190
+ with tab_metrics:
191
+ # Display Detections and Identities
192
+ st.markdown("**Detection Metrics**")
193
+ disp_detection_metrics = st.info("No detected faces yet ...")
194
+
195
+ # Display Recognition Metrics
196
+ st.markdown("**Recognition Metrics**")
197
+ disp_recognition_metrics = st.info("No recognized identities yet ...")
198
+
199
+ with tab_pca:
200
+ # Display 2D and 3D PCA
201
+ col1, col2 = st.columns(2)
202
+ col1.markdown("**PCA 2D**")
203
+ disp_pca3d = col1.info("Only available if more than 1 recognized face ...")
204
+ col2.markdown("**PCA 3D**")
205
+ disp_pca2d = col2.info("Only available if more than 1 recognized face ...")
206
+ freeze_pcas = st.button("Freeze PCAs for Interaction", key="reset_pca")
207
+
208
+ # Show PCAs
209
+ if freeze_pcas and gallery:
210
+ col1, col2 = st.columns(2)
211
+ if len(st.session_state.matches) > 1:
212
+ col1.plotly_chart(
213
+ pca(
214
+ st.session_state.matches,
215
+ st.session_state.identities,
216
+ gallery,
217
+ dim=3,
218
+ ),
219
+ use_container_width=True,
220
+ )
221
+ col2.plotly_chart(
222
+ pca(
223
+ st.session_state.matches,
224
+ st.session_state.identities,
225
+ gallery,
226
+ dim=2,
227
+ ),
228
+ use_container_width=True,
229
+ )
230
+
231
+
232
+ # Show Gallery Identities
233
+ if gallery:
234
+ disp_identities_gal.image(
235
+ image=[identity.face_aligned for identity in gallery],
236
+ caption=[match.name for match in gallery],
237
+ )
238
+ else:
239
+ disp_identities_gal.info("No gallery images uploaded yet ...")
240
 
 
 
241
 
242
  # Display Live Stats
243
  if ctx.state.playing:
244
  while True:
245
+ # Retrieve data from other thread
246
+ stats, detections, identities, matches = transfer_queue.get()
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
+ # Save for PCA Snapshot
249
+ st.session_state.identities = identities
250
+ st.session_state.matches = matches
251
+
252
+ # Show Stats
253
+ disp_stats.dataframe(
254
+ pd.DataFrame([stats]).applymap(lambda x: (format_dflist(x))),
255
+ use_container_width=True,
 
 
 
 
256
  )
257
+
258
+ # Show Detections Metrics
259
+ if detections:
260
+ disp_detection_metrics.dataframe(
261
+ pd.DataFrame(detections).applymap(lambda x: (format_dflist(x))),
262
+ use_container_width=True,
263
+ )
264
+ else:
265
+ disp_detection_metrics.info("No detected faces yet ...")
266
+
267
+ # Show Match Metrics
268
+ if matches:
269
+ disp_recognition_metrics.dataframe(
270
+ pd.DataFrame(matches).applymap(lambda x: (format_dflist(x))),
271
+ use_container_width=True,
272
+ )
273
+ else:
274
+ disp_recognition_metrics.info("No recognized identities yet ...")
275
+
276
+ if len(matches) > 1:
277
+ disp_pca3d.plotly_chart(pca(matches, identities, gallery, dim=3), use_container_width=True)
278
+ disp_pca2d.plotly_chart(pca(matches, identities, gallery, dim=2), use_container_width=True)
279
+ else:
280
+ disp_pca3d.info("Only available if more than 1 recognized face ...")
281
+ disp_pca2d.info("Only available if more than 1 recognized face ...")
282
+
283
+ # Show Recognized Identities
284
+ if matches:
285
+ disp_identities_rec.image(
286
+ image=[identities[match.identity_idx].face_aligned for match in matches],
287
+ caption=[gallery[match.gallery_idx].name for match in matches],
288
+ )
289
+ else:
290
+ disp_identities_rec.info("No recognized identities yet ...")
291
+
292
+ # BUG Recognized Identity Image is not updating on cloud version? (works on local!!!)
tools/gallery.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .face_detection import FaceDetection
2
+ from .face_recognition import FaceRecognition
3
+ from .nametypes import Identity
4
+ import cv2
5
+ import os
6
+ import numpy as np
7
+
8
+
9
def init_gallery(files, min_detections_conf=0.8, min_similarity=0.67, model_name="mobileNet"):
    """Build the face gallery from a collection of uploaded image files.

    Each file is decoded to an RGB image, run through face detection and then
    through face recognition; the first recognised face of every image becomes
    one ``Identity`` entry named after the file (without its extension).

    Args:
        files: Iterable of file-like objects exposing ``read()`` and a ``name``
            attribute (app.py passes the result of ``st.sidebar.file_uploader``).
            ``None`` or an empty collection yields an empty gallery.
        min_detections_conf: Forwarded to ``FaceDetection`` as
            ``min_detections_conf``.
        min_similarity: Forwarded to ``FaceRecognition`` as ``min_similarity``.
        model_name: Forwarded to ``FaceRecognition`` as ``model_name``.

    Returns:
        A list of ``Identity`` objects, one per file in which a face was found.
        Files that are empty, cannot be decoded as an image, or contain no
        detected face are skipped.
    """
    # Nothing uploaded: avoid constructing the detector/recognizer at all
    # (this function is called on every Streamlit rerun).
    if not files:
        return []

    face_detector = FaceDetection(min_detections_conf=min_detections_conf)
    face_recognizer = FaceRecognition(model_name=model_name, min_similarity=min_similarity)

    gallery = []
    for file in files:
        file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
        # cv2.imdecode raises on an empty buffer, so skip empty uploads up front.
        if file_bytes.size == 0:
            continue

        # cv2.imdecode returns None when the data is not a decodable image;
        # passing that to cvtColor would raise, so skip the file instead.
        decoded = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        if decoded is None:
            continue
        img = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

        # Face Detection
        img, detections = face_detector(img)
        if detections is None or len(detections) == 0:
            continue

        # A gallery image contributes a single identity: keep only the first detection.
        detections = detections[:1]

        # Face Recognition
        identities = face_recognizer(img, detections)
        # NOTE(review): presumably one identity is returned per detection; guard
        # anyway so an empty result skips the file instead of raising IndexError.
        if identities is None or len(identities) == 0:
            continue

        # Add to gallery
        best = identities[0]
        gallery.append(
            Identity(
                name=os.path.splitext(file.name)[0],
                embedding=best.embedding,
                face_aligned=best.face_aligned,
            )
        )

    return gallery