|
import pickle |
|
from pathlib import Path |
|
from typing import List, Optional |
|
|
|
import cv2 |
|
import face_recognition |
|
import numpy as np |
|
from PIL import Image |
|
from pydantic import model_validator |
|
|
|
from ....models.od.schemas import Target |
|
from ....utils.registry import registry |
|
from ...base import ArgSchema, BaseModelTool |
|
|
|
ARGSCHEMA = {} |
|
|
|
|
|
@registry.register_tool()
class FaceRecognition(BaseModelTool):
    """Identify known people in cached images via the `face_recognition` library.

    The known-face database is read from ``face_db`` — one sub-directory per
    person, each containing example images — and every face detected in an
    input image is matched against it by Euclidean face distance.
    """

    args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
    description: str = (
        "This tool can recognize facial information in images/extracted frames and identify who the person is. "
        "Images/extracted frames are already loaded."
    )
    # Maximum accepted face distance (lower distance = more similar).
    threshold: float = 0.6
    # Re-sampling count for face_recognition.face_encodings; higher is more
    # accurate but slower.
    num_jitters: int = 1
    # Root directory of reference images, one sub-directory per person.
    face_db: str = "data/face_db"
    # Landmark/encoding model for face_recognition: "large" or "small".
    model: str = "large"
    # Pre-loaded database {"embeddings": [...], "names": [...]}; built from
    # ``face_db`` during validation when not supplied by the caller.
    loaded_face_db: Optional[dict] = None

    @model_validator(mode="after")
    def face_db_validator(self) -> "FaceRecognition":
        """Ensure a usable face database exists after model construction.

        Raises:
            ValueError: if ``face_db`` does not exist on disk, if a supplied
                ``loaded_face_db`` is not a dict, or if it lacks the required
                'embeddings'/'names' keys.
        """
        if self.loaded_face_db is None:
            if Path(self.face_db).exists():
                self.loaded_face_db = self._load_face_db(self.face_db)
            else:
                raise ValueError(f"Face database not found at {self.face_db}")
        elif isinstance(self.loaded_face_db, dict):
            if (
                "embeddings" not in self.loaded_face_db
                or "names" not in self.loaded_face_db
            ):
                raise ValueError(
                    "Face database must have 'embeddings' and 'names' keys."
                )
        else:
            raise ValueError("Face database must be a dictionary.")
        return self

    def _load_face_db(self, path: str) -> dict:
        """Build (or load a cached copy of) the known-face database.

        Walks ``path`` recursively; each image's parent directory name is used
        as the person's label. The result is pickled next to the images so
        subsequent runs can skip re-encoding.

        Args:
            path: root directory of the face database.

        Returns:
            dict with "embeddings" (list of 128-d np.ndarray) and
            "names" (list of str), index-aligned.
        """
        cached_model = Path(path).joinpath(f"representations_{self.model}_face.pkl")
        # Reuse previously computed encodings when available. NOTE: the cache
        # is not invalidated when images change — delete the .pkl to force a
        # rebuild. pickle is acceptable here only because the cache file is
        # locally generated, never untrusted input.
        if cached_model.exists():
            with open(cached_model, "rb") as f:
                return pickle.load(f)

        face_db = Path(path)
        embeddings = []
        names = []
        for known_image in face_db.rglob("*"):
            if known_image.suffix in [".jpg", ".png", ".webp"]:
                loaded_image = np.array(Image.open(known_image).convert("RGB"))
                loaded_image = cv2.cvtColor(loaded_image, cv2.COLOR_RGB2BGR)
                encodings = face_recognition.face_encodings(
                    loaded_image, num_jitters=self.num_jitters, model=self.model
                )
                # Skip reference images in which no face is detected rather
                # than crashing with an IndexError on [0].
                if not encodings:
                    continue
                embeddings.append(encodings[0])
                names.append(known_image.parent.name)
        loaded_face_db = {"embeddings": embeddings, "names": names}
        with open(cached_model, "wb") as f:
            pickle.dump(loaded_face_db, f)
        return loaded_face_db

    def infer(self, image: Image.Image) -> List[Target]:
        """Detect faces in ``image`` and identify them against the database.

        Args:
            image: PIL image to search.

        Returns:
            One Target per recognized face, labelled with the matched
            person's name; ``conf`` carries the face *distance*
            (lower is better), capped by ``threshold``.
        """
        img = np.array(image.convert("RGB"))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        face_locations = face_recognition.face_locations(img)
        # Use the same encoding model as the database build — comparing
        # distances across the "small"/"large" landmark models is unsound.
        face_encodings = face_recognition.face_encodings(
            img, face_locations, num_jitters=self.num_jitters, model=self.model
        )
        known_embeddings = self.loaded_face_db.get("embeddings")
        rec_res = []
        for (top, right, bottom, left), face_encoding in zip(
            face_locations, face_encodings
        ):
            face_distances = face_recognition.face_distance(
                known_embeddings, face_encoding
            )
            # Guard the empty-database case: argmin on an empty array raises.
            if len(face_distances) == 0:
                continue
            best_match_index = np.argmin(face_distances)
            if face_distances[best_match_index] <= self.threshold:
                name = self.loaded_face_db["names"][best_match_index]
                bbox = [left, top, right, bottom]
                rec_res.append(
                    Target(label=name, bbox=bbox, conf=face_distances[best_match_index])
                )
        return rec_res

    def _run(self) -> str:
        """Recognize faces in every cached image and annotate them in place.

        Each image in ``self.stm.image_cache`` is replaced by a copy with the
        recognition results drawn on via ``visual_prompting``.

        Returns:
            Summary string listing the distinct names recognized.
        """
        names = set()
        for key in self.stm.image_cache.keys():
            anno = self.infer(self.stm.image_cache[key])
            self.stm.image_cache[key] = self.visual_prompting(
                self.stm.image_cache[key], anno
            )
            names.update(item.label for item in anno)

        return f"Recognized {len(names)} faces: {', '.join(names)}"

    async def _arun(self) -> str:
        """Async entry point; the recognition work itself is synchronous."""
        return self._run()
|
|