beki committed on
Commit
1f21ea2
·
1 Parent(s): 3967b64

Upload transformers_recognizer.py

Browse files
Files changed (1) hide show
  1. transformers_recognizer.py +252 -0
transformers_recognizer.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional, List, Tuple, Set
3
+
4
+ from presidio_analyzer import (
5
+ RecognizerResult,
6
+ EntityRecognizer,
7
+ AnalysisExplanation,
8
+ )
9
+ from presidio_analyzer.nlp_engine import NlpArtifacts
10
+
11
+ logger = logging.getLogger("presidio-analyzer")
12
+
13
+ try:
14
+ from transformers import (
15
+ AutoTokenizer,
16
+ AutoModelForTokenClassification,
17
+ pipeline,
18
+ models,
19
+ )
20
+ from transformers.models.bert.modeling_bert import BertForTokenClassification
21
+ except ImportError:
22
+ logger.error("transformers is not installed")
23
+
24
+
25
+
26
class TransformersRecognizer(EntityRecognizer):
    """
    Presidio recognizer that delegates detection to a transformers NER pipeline.

    The wrapped pipeline is called with the raw text; every detection whose
    label belongs to one of the requested entities is converted into a
    Presidio ``RecognizerResult``.

    :example:
    >from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
    >from transformers_recognizer import TransformersRecognizer

    >transformers_recognizer = TransformersRecognizer()

    >registry = RecognizerRegistry()
    >registry.add_recognizer(transformers_recognizer)

    >analyzer = AnalyzerEngine(registry=registry)

    >results = analyzer.analyze(
    >    "My name is Christopher and I live in Irbid.",
    >    language="en",
    >    return_decision_process=True,
    >)
    >for result in results:
    >    print(result)
    >    print(result.analysis_explanation)
    """

    # Entity names callers may request from this recognizer.
    ENTITIES = [
        "LOCATION",
        "PERSON",
        "ORGANIZATION",
        "AGE",
        "ID",
        "PHONE_NUMBER",
        "EMAIL",
        # Kept for backward compatibility with callers that already ask for it.
        "DATE",
        # The "DATE" label is emitted as "DATE_TIME" (see PRESIDIO_EQUIVALENCES),
        # so the emitted name has to be requestable as well.
        "DATE_TIME",
    ]

    DEFAULT_EXPLANATION = "Identified as {} by transformers's Named Entity Recognition"

    # Pairs of (requested entity names, model labels that satisfy them).
    CHECK_LABEL_GROUPS = [
        ({"LOCATION"}, {"LOC", "HOSP"}),
        ({"PERSON"}, {"PER", "PERSON", "STAFF", "PATIENT"}),
        ({"ORGANIZATION"}, {"ORGANIZATION", "ORG", "PATORG"}),
        ({"AGE"}, {"AGE"}),
        ({"ID"}, {"ID"}),
        ({"EMAIL"}, {"EMAIL"}),
        ({"DATE", "DATE_TIME"}, {"DATE"}),
        ({"PHONE_NUMBER"}, {"PHONE"}),
    ]

    # Model label -> entity type written on the result. Labels missing from
    # this table are passed through unchanged.
    PRESIDIO_EQUIVALENCES = {
        "PER": "PERSON",
        "LOC": "LOCATION",
        "ORG": "ORGANIZATION",
        "AGE": "AGE",
        "ID": "ID",
        "EMAIL": "EMAIL",
        "PATIENT": "PERSON",
        "STAFF": "PERSON",
        "HOSP": "LOCATION",
        "PATORG": "ORGANIZATION",
        "DATE": "DATE_TIME",
        "PHONE": "PHONE_NUMBER",
    }

    DEFAULT_MODEL_PATH = "obi/deid_roberta_i2b2"

    def __init__(
        self,
        supported_entities: Optional[List[str]] = None,
        check_label_groups: Optional[List[Tuple[Set, Set]]] = None,
        model: Optional["BertForTokenClassification"] = None,
        model_path: Optional[str] = None,
    ):
        """
        Create the recognizer and, unless one is supplied, build the NER pipeline.

        :param supported_entities: Entity names this instance answers for.
            Falls back to ``ENTITIES`` when empty or None.
        :param check_label_groups: List of ``(entity names, model labels)``
            pairs used to match detections to requested entities. Falls back
            to ``CHECK_LABEL_GROUPS`` when empty or None.
        :param model: Ready-to-use object that is stored as-is and later
            invoked as ``model(text)``. NOTE(review): the annotation is kept
            from the original signature, but ``analyze`` needs a callable
            returning dicts with "entity_group", "score", "start" and "end"
            keys, i.e. an NER pipeline rather than a bare model - confirm
            with callers.
        :param model_path: Name or path handed to ``from_pretrained`` for both
            the model and the tokenizer. Ignored when ``model`` is given.
        """
        if model is None and not model_path:
            model_path = self.DEFAULT_MODEL_PATH
            logger.warning(
                "Both 'model' and 'model_path' arguments are None. "
                "Using default model_path=%s",
                model_path,
            )
        elif model is not None and model_path:
            logger.warning(
                "Both 'model' and 'model_path' arguments were provided. "
                "Ignoring the model_path"
            )

        self.check_label_groups = check_label_groups or self.CHECK_LABEL_GROUPS

        # The model is assigned before the base-class initializer runs, as in
        # the original ordering.
        if model is not None:
            self.model = model
        else:
            self.model = pipeline(
                "ner",
                model=AutoModelForTokenClassification.from_pretrained(model_path),
                tokenizer=AutoTokenizer.from_pretrained(model_path),
                aggregation_strategy="simple",
            )

        super().__init__(
            supported_entities=supported_entities or self.ENTITIES,
            name="transformers Analytics",
        )

    def load(self) -> None:
        """Load the model, not used. Model is loaded during initialization."""
        pass

    def get_supported_entities(self) -> List[str]:
        """
        Return supported entities by this model.

        :return: List of the supported entities.
        """
        return self.supported_entities

    def analyze(
        self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts = None
    ) -> List[RecognizerResult]:
        """
        Run the wrapped pipeline on the text and keep the requested detections.

        :param text: The text for analysis.
        :param entities: Entity names to look for. Names this recognizer does
            not support are ignored; an empty or None value means every
            supported entity.
        :param nlp_artifacts: Not used by this recognizer.
        :return: One RecognizerResult per pipeline detection whose label
            matches at least one requested entity, in the order the pipeline
            returned them.
        """
        # If there are no specific list of entities, we will look for all of it.
        if not entities:
            entities = self.supported_entities

        requested = [
            entity for entity in entities if entity in self.supported_entities
        ]
        if not requested:
            # Nothing this recognizer can answer for: skip model inference.
            return []

        results = []
        # Walk the detections once so that a detection matching several
        # requested entities is reported a single time.
        for res in self.model(text):
            label = res["entity_group"]
            if not any(
                self.__check_label(entity, label, self.check_label_groups)
                for entity in requested
            ):
                continue

            explanation = self.build_transformers_explanation(
                self._normalized_score(res),
                self.DEFAULT_EXPLANATION.format(label),
            )
            results.append(self._convert_to_recognizer_result(res, explanation))

        return results

    @staticmethod
    def _normalized_score(res) -> float:
        """Return the detection score as a built-in float rounded to 2 decimals."""
        # Presumably the pipeline can report the score as a NumPy scalar
        # (TODO confirm); converting first keeps the stored score a plain float.
        return round(float(res["score"]), 2)

    def _convert_to_recognizer_result(self, res, explanation) -> RecognizerResult:
        """
        Build a RecognizerResult from a single pipeline detection.

        :param res: Detection dict; its "entity_group", "score", "start" and
            "end" keys are read.
        :param explanation: AnalysisExplanation attached to the result.
        :return: The result, with the model label translated through
            ``PRESIDIO_EQUIVALENCES`` (unknown labels are kept as-is).
        """
        label = res["entity_group"]
        return RecognizerResult(
            entity_type=self.PRESIDIO_EQUIVALENCES.get(label, label),
            start=res["start"],
            end=res["end"],
            score=self._normalized_score(res),
            analysis_explanation=explanation,
        )

    def build_transformers_explanation(
        self, original_score: float, explanation: str
    ) -> AnalysisExplanation:
        """
        Create explanation for why this result was detected.

        :param original_score: Score given by this recognizer
        :param explanation: Explanation string
        :return: AnalysisExplanation naming this class as the recognizer.
        """
        return AnalysisExplanation(
            recognizer=self.__class__.__name__,
            original_score=original_score,
            textual_explanation=explanation,
        )

    @staticmethod
    def __check_label(
        entity: str, label: str, check_label_groups: List[Tuple[Set, Set]]
    ) -> bool:
        """
        Tell whether a model label satisfies a requested entity.

        :param entity: Requested entity name.
        :param label: Label reported by the model for a detection.
        :param check_label_groups: ``(entity names, model labels)`` pairs.
        :return: True when some pair contains both the entity and the label.
        """
        return any(
            entity in egrp and label in lgrp for egrp, lgrp in check_label_groups
        )
230
+
231
+
232
if __name__ == "__main__":
    # Demo: register the recognizer with a Presidio analyzer and print what
    # it finds in a sample sentence, together with each decision explanation.
    from presidio_analyzer import AnalyzerEngine, RecognizerRegistry

    # No arguments, so the recognizer falls back to its default model path.
    # This would download a large (~500Mb) model on the first run
    recognizer = TransformersRecognizer()

    recognizer_registry = RecognizerRegistry()
    recognizer_registry.add_recognizer(recognizer)
    engine = AnalyzerEngine(registry=recognizer_registry)

    sample_text = "My name is Christopher and I live in Irbid."
    findings = engine.analyze(
        sample_text,
        language="en",
        return_decision_process=True,
    )
    for finding in findings:
        print(finding)
        print(finding.analysis_explanation)