# OmAgent/tool/face_rec/face_rec.py
import pickle
from pathlib import Path
from typing import List, Optional

import face_recognition
import numpy as np
from PIL import Image
from pydantic import model_validator

from ....models.od.schemas import Target
from ....utils.registry import registry
from ...base import ArgSchema, BaseModelTool

ARGSCHEMA = {}

@registry.register_tool()
class FaceRecognition(BaseModelTool):
    args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
    description: str = (
        "This tool can recognize facial information in images/extracted frames and identify who the person is. "
        "Images/extracted frames are already loaded."
    )
    threshold: float = 0.6  # maximum face distance accepted as a match
    num_jitters: int = 1  # re-sampling passes when computing face encodings
    face_db: str = "data/face_db"
    model: str = "large"  # face_recognition encoding model: "small" or "large"
    loaded_face_db: Optional[dict] = None
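    # Expected face database layout: one sub-directory per identity, e.g.
    #   data/face_db/<person_name>/photo1.jpg
    # The directory name becomes the recognized label (see _load_face_db).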

    @model_validator(mode="after")
    def face_db_validator(self) -> "FaceRecognition":
        if self.loaded_face_db is None:
            if Path(self.face_db).exists():
                self.loaded_face_db = self._load_face_db(self.face_db)
            else:
                raise ValueError(f"Face database not found at {self.face_db}")
        elif isinstance(self.loaded_face_db, dict):
            if (
                "embeddings" not in self.loaded_face_db
                or "names" not in self.loaded_face_db
            ):
                raise ValueError(
                    "Face database must have 'embeddings' and 'names' keys."
                )
        else:
            raise ValueError("Face database must be a dictionary.")
        return self
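
    # A prebuilt database can also be injected directly, bypassing the on-disk
    # load (hypothetical example):
    #   FaceRecognition(loaded_face_db={"embeddings": [...], "names": [...]})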

    def _load_face_db(self, path: str):
        cached_model = Path(path).joinpath(f"representations_{self.model}_face.pkl")
        # Cache reading is disabled: embeddings are recomputed on every load and
        # the pickle below is simply refreshed.
        # if cached_model.exists():
        #     loaded_face_db = pickle.load(open(cached_model, "rb"))
        face_db = Path(path)
        embeddings = []
        names = []
        for known_image in face_db.rglob("*"):
            if known_image.suffix in [".jpg", ".png", ".webp"]:
                # face_recognition expects RGB arrays, so the PIL image is used
                # directly (no BGR conversion).
                loaded_image = np.array(Image.open(known_image).convert("RGB"))
                encodings = face_recognition.face_encodings(
                    loaded_image, num_jitters=self.num_jitters, model=self.model
                )
                if not encodings:
                    continue  # skip images where no face is detected
                embeddings.append(encodings[0])
                names.append(known_image.parent.name)
        loaded_face_db = {"embeddings": embeddings, "names": names}
        with open(cached_model, "wb") as f:
            pickle.dump(loaded_face_db, f)
        return loaded_face_db
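
    # Shape of the returned database (illustrative names): parallel lists, e.g.
    #   {"embeddings": [<128-d np.ndarray>, ...], "names": ["alice", "bob", ...]}
    # face_recognition encodings are 128-dimensional vectors.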

    def infer(self, image: Image.Image) -> List[Target]:
        # face_recognition works on RGB arrays, so the PIL image is used directly.
        img = np.array(image.convert("RGB"))
        face_locations = face_recognition.face_locations(img)
        face_encodings = face_recognition.face_encodings(
            img, face_locations, num_jitters=self.num_jitters, model=self.model
        )
        rec_res = []
        if not self.loaded_face_db["embeddings"]:
            return rec_res  # empty database: nothing to match against
        for (top, right, bottom, left), face_encoding in zip(
            face_locations, face_encodings
        ):
            face_distances = face_recognition.face_distance(
                self.loaded_face_db["embeddings"], face_encoding
            )
            best_match_index = np.argmin(face_distances)
            if face_distances[best_match_index] <= self.threshold:
                name = self.loaded_face_db["names"][best_match_index]
                bbox = [left, top, right, bottom]
                # Note: conf stores the face distance, so lower means a closer match.
                rec_res.append(
                    Target(
                        label=name,
                        bbox=bbox,
                        conf=face_distances[best_match_index],
                    )
                )
        return rec_res

    def _run(self):
        names = set()
        # Annotate every cached image/frame in short-term memory in place.
        for key in self.stm.image_cache.keys():
            anno = self.infer(self.stm.image_cache[key])
            self.stm.image_cache[key] = self.visual_prompting(
                self.stm.image_cache[key], anno
            )
            names.update([item.label for item in anno])
        return f"Recognized {len(names)} faces: {', '.join(names)}"

    async def _arun(self):
        return self._run()
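

# --- Usage sketch (illustrative, not part of the original tool) ---
# Assumes BaseModelTool needs no extra constructor arguments, a populated
# database at data/face_db/<person_name>/*.jpg, and a hypothetical query image
# "query.jpg". _run/_arun additionally require the framework's short-term
# memory (stm.image_cache) to be wired up, so only infer() is exercised here.
if __name__ == "__main__":
    tool = FaceRecognition()  # the validator loads data/face_db on construction
    for target in tool.infer(Image.open("query.jpg")):
        print(target.label, target.bbox, target.conf)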