# Spaces: Build error
import logging | |
from typing import Optional, List, Tuple, Set | |
from presidio_analyzer import ( | |
RecognizerResult, | |
EntityRecognizer, | |
AnalysisExplanation, | |
) | |
from presidio_analyzer.nlp_engine import NlpArtifacts | |
logger = logging.getLogger("presidio-analyzer") | |
try: | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForTokenClassification, | |
pipeline, | |
models, | |
) | |
from transformers.models.bert.modeling_bert import BertForTokenClassification | |
except ImportError: | |
logger.error("transformers is not installed") | |
class TransformersRecognizer(EntityRecognizer):
    """
    Wrapper for a transformers model, if needed to be used within Presidio Analyzer.

    The wrapped model is called with the raw text, and every detection it returns
    is read through the keys ``entity_group``, ``score``, ``start`` and ``end``.

    :example:
    >from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
    >transformers_recognizer = TransformersRecognizer()
    >registry = RecognizerRegistry()
    >registry.add_recognizer(transformers_recognizer)
    >analyzer = AnalyzerEngine(registry=registry)
    >results = analyzer.analyze(
    >    "My name is Christopher and I live in Irbid.",
    >    language="en",
    >    return_decision_process=True,
    >)
    >for result in results:
    >    print(result)
    >    print(result.analysis_explanation)
    """

    # Entities this recognizer supports when the caller does not pass its own list.
    ENTITIES = [
        "LOCATION",
        "PERSON",
        "ORGANIZATION",
        "AGE",
        "ID",
        "PHONE_NUMBER",
        "EMAIL",
        "DATE",
    ]

    DEFAULT_EXPLANATION = "Identified as {} by transformers's Named Entity Recognition"

    # Pairs of (requested entity names, model labels). A detection is accepted for
    # a requested entity when both belong to the same pair.
    CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "HOSP"}),
        ({"PERSON"}, {"PER", "PERSON", "STAFF", "PATIENT"}),
        ({"ORGANIZATION"}, {"ORGANIZATION", "ORG", "PATORG"}),
        ({"AGE"}, {"AGE"}),
        ({"ID"}, {"ID"}),
        ({"EMAIL"}, {"EMAIL"}),
        ({"DATE"}, {"DATE"}),
        ({"PHONE_NUMBER"}, {"PHONE"}),
    ]

    # Model label -> entity type written on the emitted RecognizerResult.
    # Labels missing from this mapping are passed through unchanged.
    PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "AGE": "AGE",
        "ID": "ID",
        "EMAIL": "EMAIL",
        "PATIENT": "PERSON",
        "STAFF": "PERSON",
        "HOSP": "LOCATION",
        "PATORG": "ORGANIZATION",
        "DATE": "DATE_TIME",
        "PHONE": "PHONE_NUMBER",
    }

    DEFAULT_MODEL_PATH = "obi/deid_roberta_i2b2"

    def __init__(
        self,
        supported_entities: Optional[List[str]] = None,
        check_label_groups: Optional[List[Tuple[Set, Set]]] = None,
        # Quoted so the class can still be defined when the guarded
        # ``transformers`` import at the top of the module has failed.
        model: Optional["BertForTokenClassification"] = None,
        model_path: Optional[str] = None,
    ):
        """
        Create the recognizer and, unless a model is supplied, build an NER pipeline.

        :param supported_entities: Entities to support; defaults to ``ENTITIES``.
        :param check_label_groups: Pairs of (entity names, model labels) used to
            match detections to requested entities; defaults to
            ``CHECK_LABEL_GROUPS``.
        :param model: Ready-to-use model. It is called as ``model(text)`` and must
            return an iterable of mappings with the keys ``entity_group``,
            ``score``, ``start`` and ``end``.
        :param model_path: Name or path handed to ``from_pretrained`` when no
            ``model`` is given; defaults to ``DEFAULT_MODEL_PATH``.
        """
        if not model and not model_path:
            model_path = self.DEFAULT_MODEL_PATH
            logger.warning(
                f"Both 'model' and 'model_path' arguments are None. Using default model_path={model_path}"
            )

        if model and model_path:
            logger.warning(
                "Both 'model' and 'model_path' arguments were provided. Ignoring the model_path"
            )

        self.check_label_groups = (
            check_label_groups if check_label_groups else self.CHECK_LABEL_GROUPS
        )

        supported_entities = supported_entities if supported_entities else self.ENTITIES

        self.model = (
            model
            if model
            else pipeline(
                "ner",
                model=AutoModelForTokenClassification.from_pretrained(model_path),
                tokenizer=AutoTokenizer.from_pretrained(model_path),
                aggregation_strategy="simple",
            )
        )

        super().__init__(
            supported_entities=supported_entities,
            name="transformers Analytics",
        )

    def load(self) -> None:
        """Load the model, not used. Model is loaded during initialization."""
        pass

    def get_supported_entities(self) -> List[str]:
        """
        Return supported entities by this model.

        :return: List of the supported entities.
        """
        return self.supported_entities

    def analyze(
        self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts = None
    ) -> List[RecognizerResult]:
        """
        Analyze text with the wrapped transformers model.

        The model runs once on ``text``. Each requested entity is then compared
        against every detection through ``check_label_groups``.

        Note that the emitted ``entity_type`` comes from ``PRESIDIO_EQUIVALENCES``
        and can differ from the requested name: with the default tables, asking
        for ``DATE`` yields results typed ``DATE_TIME``.

        :param text: The text for analysis.
        :param entities: Entities to look for. Names that are not in
            ``supported_entities`` are skipped; an empty or ``None`` value means
            all supported entities.
        :param nlp_artifacts: Not used by this recognizer.
        :return: The list of Presidio RecognizerResult constructed from the recognized
            transformers detections.
        """
        results = []
        ner_results = self.model(text)

        # If there are no specific list of entities, we will look for all of it.
        if not entities:
            entities = self.supported_entities

        for entity in entities:
            if entity not in self.supported_entities:
                continue

            for res in ner_results:
                if not self.__check_label(
                    entity, res["entity_group"], self.check_label_groups
                ):
                    continue

                textual_explanation = self.DEFAULT_EXPLANATION.format(
                    res["entity_group"]
                )
                explanation = self.build_transformers_explanation(
                    round(res["score"], 2), textual_explanation
                )
                results.append(self._convert_to_recognizer_result(res, explanation))

        return results

    def _convert_to_recognizer_result(self, res, explanation) -> RecognizerResult:
        """
        Turn one model detection into a Presidio RecognizerResult.

        :param res: Detection mapping with ``entity_group``, ``score``, ``start``
            and ``end``.
        :param explanation: AnalysisExplanation attached to the result.
        :return: RecognizerResult whose score is rounded to two decimals and whose
            entity type is translated through ``PRESIDIO_EQUIVALENCES``.
        """
        entity_type = self.PRESIDIO_EQUIVALENCES.get(
            res["entity_group"], res["entity_group"]
        )
        transformers_score = round(res["score"], 2)

        return RecognizerResult(
            entity_type=entity_type,
            start=res["start"],
            end=res["end"],
            score=transformers_score,
            analysis_explanation=explanation,
        )

    def build_transformers_explanation(
        self, original_score: float, explanation: str
    ) -> AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation: Explanation string
        :return: AnalysisExplanation naming this class as the recognizer.
        """
        return AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )

    # Must be a staticmethod: it takes no ``self`` but is invoked through the
    # instance, which would otherwise inject ``self`` as an extra argument.
    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: List[Tuple[Set, Set]]
    ) -> bool:
        """
        Tell whether a model label counts as a detection of the requested entity.

        :param entity: Requested entity name.
        :param label: Label produced by the model.
        :param check_label_groups: Pairs of (entity names, model labels).
        :return: True when some pair contains both ``entity`` and ``label``.
        """
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )
if __name__ == "__main__":
    # Demo: register the recognizer with a Presidio analyzer and print every
    # finding together with its decision-process explanation.
    from presidio_analyzer import AnalyzerEngine, RecognizerRegistry

    # This would download a large (~500Mb) model on the first run
    recognizer = TransformersRecognizer()

    recognizer_registry = RecognizerRegistry()
    recognizer_registry.add_recognizer(recognizer)
    engine = AnalyzerEngine(registry=recognizer_registry)

    for finding in engine.analyze(
        "My name is Christopher and I live in Irbid.",
        language="en",
        return_decision_process=True,
    ):
        print(finding)
        print(finding.analysis_explanation)