File size: 1,805 Bytes
654c5de
3ca50ac
654c5de
 
 
 
 
 
3ca50ac
654c5de
3ca50ac
654c5de
70e8d2e
654c5de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import Any, Dict, List
import os
from flair.data import Sentence
from flair.models import SequenceTagger

class EndpointHandler:
    """Endpoint handler that runs a Flair ``SequenceTagger`` over input text.

    The tagger is loaded once at construction time. Each call tags one input
    and returns the predicted spans as plain, JSON-friendly dictionaries.
    """

    # Annotation layer the tagger writes its predictions to and that spans are
    # read back from; both uses below must refer to the same name.
    _LABEL_NAME = "predicted"

    def __init__(
        self,
        path: str,
    ):
        """Load the tagger from ``<path>/pytorch_model.bin``.

        Args:
            path: Directory containing the serialized ``pytorch_model.bin``.
        """
        self.tagger = SequenceTagger.load(os.path.join(path, "pytorch_model.bin"))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Tag the text in ``data`` and return the recognized entity spans.

        Args:
            data: Request payload. The text to tag is read from
                ``data["inputs"]``; if that key is absent, ``data`` itself is
                handed to ``Sentence``. ``data`` is not modified.

        Returns:
            A list with one dict per predicted span, for example
            ``[{"entity_group": "XXX", "word": "some word", "start": 3, "end": 12, "score": 0.82}]``:
                - "entity_group": the tag predicted for the span.
                - "word": the span's text.
                - "start": ``start_position`` of the span's first token
                  (a character offset into the input text).
                - "end": ``end_position`` of the span's last token.
                - "score": the confidence score Flair reports for the span.
        """
        # Read without popping so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        sentence = Sentence(inputs)

        # predict() annotates `sentence` in place under the given label name;
        # the spans are read back from that same layer below.
        self.tagger.predict(sentence, label_name=self._LABEL_NAME)

        entities = []
        for span in sentence.get_spans(self._LABEL_NAME):
            if not span.tokens:
                # Offsets come from the first/last token, so skip empty spans.
                continue
            entities.append(
                {
                    "entity_group": span.tag,
                    "word": span.text,
                    "start": span.tokens[0].start_position,
                    "end": span.tokens[-1].end_position,
                    "score": span.score,
                }
            )

        return entities