beki commited on
Commit
62db53d
1 Parent(s): fec18f3

Create new file

Browse files
Files changed (1) hide show
  1. flair_recognizer.py +233 -0
flair_recognizer.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional, List, Tuple, Set
3
+
4
+ from presidio_analyzer import (
5
+ RecognizerResult,
6
+ EntityRecognizer,
7
+ AnalysisExplanation,
8
+ )
9
+ from presidio_analyzer.nlp_engine import NlpArtifacts
10
+
11
+ try:
12
+ from flair.data import Sentence
13
+ from flair.models import SequenceTagger
14
+ except ImportError:
15
+ print("Flair is not installed")
16
+
17
+
18
+ logger = logging.getLogger("presidio-analyzer")
19
+
20
+
21
+ class FlairRecognizer(EntityRecognizer):
22
+ """
23
+ Wrapper for a flair model, if needed to be used within Presidio Analyzer.
24
+ :example:
25
+ >from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
26
+ >flair_recognizer = FlairRecognizer()
27
+ >registry = RecognizerRegistry()
28
+ >registry.add_recognizer(flair_recognizer)
29
+ >analyzer = AnalyzerEngine(registry=registry)
30
+ >results = analyzer.analyze(
31
+ > "My name is Christopher and I live in Irbid.",
32
+ > language="en",
33
+ > return_decision_process=True,
34
+ >)
35
+ >for result in results:
36
+ > print(result)
37
+ > print(result.analysis_explanation)
38
+ """
39
+
40
+ ENTITIES = [
41
+ "LOCATION",
42
+ "PERSON",
43
+ "NRP",
44
+ "GPE",
45
+ "ORGANIZATION",
46
+ "MAC_ADDRESS",
47
+ "US_BANK_NUMBER",
48
+ "IMEI",
49
+ "TITLE",
50
+ "LICENSE_PLATE",
51
+ "US_PASSPORT",
52
+ "CURRENCY",
53
+ "ROUTING_NUMBER",
54
+ "US_ITIN",
55
+ "US_BANK_NUMBER",
56
+ "US_DRIVER_LICENSE",
57
+ ]
58
+
59
+ DEFAULT_EXPLANATION = "Identified as {} by Flair's Named Entity Recognition"
60
+
61
+ CHECK_LABEL_GROUPS = [
62
+ ({"LOCATION"}, {"LOC", "LOCATION", "STREET_ADDRESS", "COORDINATE"}),
63
+ ({"PERSON"}, {"PER", "PERSON"}),
64
+ ({"NRP"}, {"NORP", "NRP"}),
65
+ ({"GPE"}, {"GPE"}),
66
+ ({"ORGANIZATION"}, {"ORG"}),
67
+ ({"MAC_ADDRESS"}, {"MAC_ADDRESS"}),
68
+ ({"US_BANK_NUMBER"}, {"US_BANK_NUMBER"}),
69
+ ({"IMEI"}, {"IMEI"}),
70
+ ({"TITLE"}, {"TITLE"}),
71
+ ({"LICENSE_PLATE"}, {"LICENSE_PLATE"}),
72
+ ({"US_PASSPORT"}, {"US_PASSPORT"}),
73
+ ({"CURRENCY"}, {"CURRENCY"}),
74
+ ({"ROUTING_NUMBER"}, {"ROUTING_NUMBER"}),
75
+ # ({"US_ITIN"}, {"US_ITIN"}),
76
+ # ({"US_BANK_NUMBER"}, {"US_BANK_NUMBER"}),
77
+ # ({"US_DRIVER_LICENSE"}, {"US_DRIVER_LICENSE"}),
78
+ ]
79
+
80
+ MODEL_LANGUAGES = {
81
+ "en": "beki/flair-ner-debug-english",
82
+ # "es": "flair/ner-spanish-large",
83
+ # "de": "flair/ner-german-large",
84
+ # "nl": "flair/ner-dutch-large",
85
+ }
86
+
87
+ PRESIDIO_EQUIVALENCES = {
88
+ "PER": "PERSON",
89
+ "LOC": "LOCATION",
90
+ "ORG": "ORGANIZATION",
91
+ # 'MISC': 'MISCELLANEOUS' # - Probably not PII
92
+ }
93
+
94
+ def __init__(
95
+ self,
96
+ supported_language: str = "en",
97
+ supported_entities: Optional[List[str]] = None,
98
+ check_label_groups: Optional[Tuple[Set, Set]] = None,
99
+ model: SequenceTagger = None,
100
+ ):
101
+ self.check_label_groups = (
102
+ check_label_groups if check_label_groups else self.CHECK_LABEL_GROUPS
103
+ )
104
+
105
+ supported_entities = supported_entities if supported_entities else self.ENTITIES
106
+ self.model = (
107
+ model
108
+ if model
109
+ else SequenceTagger.load(self.MODEL_LANGUAGES.get(supported_language))
110
+ )
111
+
112
+ super().__init__(
113
+ supported_entities=supported_entities,
114
+ supported_language=supported_language,
115
+ name="Flair Analytics",
116
+ )
117
+
118
+ def load(self) -> None:
119
+ """Load the model, not used. Model is loaded during initialization."""
120
+ pass
121
+
122
+ def get_supported_entities(self) -> List[str]:
123
+ """
124
+ Return supported entities by this model.
125
+ :return: List of the supported entities.
126
+ """
127
+ return self.supported_entities
128
+
129
+ # Class to use Flair with Presidio as an external recognizer.
130
+ def analyze(
131
+ self, text: str, entities: List[str], nlp_artifacts: NlpArtifacts = None
132
+ ) -> List[RecognizerResult]:
133
+ """
134
+ Analyze text using Text Analytics.
135
+ :param text: The text for analysis.
136
+ :param entities: Not working properly for this recognizer.
137
+ :param nlp_artifacts: Not used by this recognizer.
138
+ :param language: Text language. Supported languages in MODEL_LANGUAGES
139
+ :return: The list of Presidio RecognizerResult constructed from the recognized
140
+ Flair detections.
141
+ """
142
+
143
+ results = []
144
+
145
+ sentences = Sentence(text)
146
+ self.model.predict(sentences)
147
+
148
+ # If there are no specific list of entities, we will look for all of it.
149
+ if not entities:
150
+ entities = self.supported_entities
151
+
152
+ for entity in entities:
153
+ if entity not in self.supported_entities:
154
+ continue
155
+
156
+ for ent in sentences.get_spans("ner"):
157
+ if not self.__check_label(
158
+ entity, ent.labels[0].value, self.check_label_groups
159
+ ):
160
+ continue
161
+ textual_explanation = self.DEFAULT_EXPLANATION.format(
162
+ ent.labels[0].value
163
+ )
164
+ explanation = self.build_flair_explanation(
165
+ round(ent.score, 2), textual_explanation
166
+ )
167
+ flair_result = self._convert_to_recognizer_result(ent, explanation)
168
+
169
+ results.append(flair_result)
170
+
171
+ return results
172
+
173
+ def _convert_to_recognizer_result(self, entity, explanation) -> RecognizerResult:
174
+
175
+ entity_type = self.PRESIDIO_EQUIVALENCES.get(entity.tag, entity.tag)
176
+ flair_score = round(entity.score, 2)
177
+
178
+ flair_results = RecognizerResult(
179
+ entity_type=entity_type,
180
+ start=entity.start_position,
181
+ end=entity.end_position,
182
+ score=flair_score,
183
+ analysis_explanation=explanation,
184
+ )
185
+
186
+ return flair_results
187
+
188
+ def build_flair_explanation(
189
+ self, original_score: float, explanation: str
190
+ ) -> AnalysisExplanation:
191
+ """
192
+ Create explanation for why this result was detected.
193
+ :param original_score: Score given by this recognizer
194
+ :param explanation: Explanation string
195
+ :return:
196
+ """
197
+ explanation = AnalysisExplanation(
198
+ recognizer=self.__class__.__name__,
199
+ original_score=original_score,
200
+ textual_explanation=explanation,
201
+ )
202
+ return explanation
203
+
204
+ @staticmethod
205
+ def __check_label(
206
+ entity: str, label: str, check_label_groups: Tuple[Set, Set]
207
+ ) -> bool:
208
+ return any(
209
+ [entity in egrp and label in lgrp for egrp, lgrp in check_label_groups]
210
+ )
211
+
212
+
213
+ if __name__ == "__main__":
214
+
215
+ from presidio_analyzer import AnalyzerEngine, RecognizerRegistry
216
+
217
+ flair_recognizer = (
218
+ FlairRecognizer()
219
+ ) # This would download a very large (+2GB) model on the first run
220
+
221
+ registry = RecognizerRegistry()
222
+ registry.add_recognizer(flair_recognizer)
223
+
224
+ analyzer = AnalyzerEngine(registry=registry)
225
+
226
+ results = analyzer.analyze(
227
+ "{first_name: Moustafa, sale_id: 235234}",
228
+ language="en",
229
+ return_decision_process=True,
230
+ )
231
+ for result in results:
232
+ print(result)
233
+ print(result.analysis_explanation)