import json 
import copy
import pickle

import torch
from simplemma import lemmatize
from transformers import AutoTokenizer

from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens

# code originally from data_manipulation.creation_gazetteers
def lemmatizing(x):
    if x == "":
        return ""
    return lemmatize(x, lang="cs")

# code originally from data_manipulation.creation_gazetteers
def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary
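# Illustrative example (hypothetical data, not taken from the shipped gazetteers):
#   >>> build_reverse_dictionary({"LOC": ["Praha", "Brno"]})
#   {'Praha': 'LOC', 'Brno': 'LOC'}
# With apply_lemmatizing=True, the lemmatized form of each surface string is
# added as an extra key mapping to the same label.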

def load_json(path):
    """
    Load gazetteers from a JSON file
    :param path: path to the gazetteer JSON file
    :return: a dict of gazetteers
    """
    with open(path, 'r') as file:
        data = json.load(file)
    return data

def load_pickle(path):
    """
    Load pickled gazetteers from a file
    :param path: path to the gazetteer pickle file
    :return: a dict of gazetteers
    """
    with open(path, 'rb') as file:
        data = pickle.load(file)
    return data


def load():
    """
    Load the tokenizer, model, and gazetteers for named entity recognition.

    Returns:
        tokenizer (AutoTokenizer): The tokenizer for tokenizing input text.
        model (ExtendedEmbeddigsRobertaForTokenClassification): The pre-trained model for named entity recognition.
        gazetteers_for_matching (list): A list of reverse-lookup gazetteer dictionaries (surface form -> entity type) used for matching named entities.
    """
    model_name = "ufal/robeczech-base"
    model_path = "bettystr/NerRoB-czech"
    gazetteers_path = "gazetteers.pkl"

    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    gazetteers_for_matching = load_pickle(gazetteers_path)
    # Build one reverse-lookup dictionary (surface form -> entity type) per entity category.
    temp = []
    for key in gazetteers_for_matching.keys():
        temp.append(build_reverse_dictionary({key: gazetteers_for_matching[key]}))
    gazetteers_for_matching = temp
    return tokenizer, model, gazetteers_for_matching


def add_additional_gazetteers(gazetteers_for_matching, file_names):
    """
    Adds additional gazetteers to the existing list of gazetteer dictionaries.

    Args:
        gazetteers_for_matching (list): The list of reverse-lookup gazetteer dictionaries to be extended.
        file_names (list): The list of JSON file names containing additional gazetteers.

    Returns:
        list: A deep copy of the gazetteer dictionaries, extended with the entries from the given files.

    """
    if file_names is None or file_names == []:
        return gazetteers_for_matching
    # Work on deep copies so the originally loaded gazetteers are not mutated.
    temp = [copy.deepcopy(dictionary) for dictionary in gazetteers_for_matching]
    for file_name in file_names:
        with open(file_name, 'r') as file:
            data = json.load(file)
        # Each JSON key is an entity type and each value is a list of surface forms.
        for key, value_lst in data.items():
            key = key.upper()
            # Add the new surface forms to the reverse dictionary that already
            # maps some surface form to this entity type.
            for dictionary in temp:
                if key in dictionary.values():
                    for value in value_lst:
                        dictionary[value] = key
    return temp
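# Note: add_additional_gazetteers expects each extra file to be a JSON object
# mapping an entity type to a list of surface forms, e.g. (hypothetical file
# contents) {"per": ["Jan Novák"], "loc": ["Horní Dolní"]}. Keys are upper-cased
# and the surface forms are merged into whichever reverse dictionary already
# contains that entity type as a value.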


def run(tokenizer, model, gazetteers, text, file_names=None):
    """
    Runs the named entity recognition (NER) model on the given text.

    Args:
        tokenizer (Tokenizer): The tokenizer used to tokenize the input text.
        model (Model): The NER model used for prediction.
        gazetteers (list): A list of gazetteers used for matching entities in the text.
        text (str): The input text to perform NER on.
        file_names (list, optional): A list of file names to be used as additional gazetteers.

    Returns:
        list: A list of dictionaries representing the predicted entities in the text. Each dictionary contains the following keys:
            - "start" (int): The starting character offset of the entity in the text.
            - "end" (int): The ending character offset of the entity in the text.
            - "entity" (str): The type of the entity.
            - "score" (float): The average confidence score of the entity prediction.
            - "word" (str): The surface form of the entity in the text.

    """
    gazetteers_for_matching = add_additional_gazetteers(gazetteers, file_names)

    tokenized_inputs = tokenizer(
        text, truncation=True, is_split_into_words=False, return_offsets_mapping=True
    )
    matches = gazetteer_matching(text, gazetteers_for_matching)
    # Align the gazetteer matches with the tokenized input and split them into
    # per-token PER/ORG/LOC gazetteer features.
    word_ids = tokenized_inputs.word_ids()
    aligned = [align_gazetteers_with_tokens(matches, word_ids)]
    p, o, l = [], [], []
    for sentence in aligned:
        p.append([x[0] for x in sentence])
        o.append([x[1] for x in sentence])
        l.append([x[2] for x in sentence])

    input_ids = torch.tensor(tokenized_inputs["input_ids"], device="cpu").unsqueeze(0)
    attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device="cpu").unsqueeze(0)
    per = torch.tensor(p, device="cpu")
    org = torch.tensor(o, device="cpu")
    loc = torch.tensor(l, device="cpu")
    output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
    predictions = torch.argmax(output, dim=2).tolist()
    predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
    
    softmax = torch.nn.Softmax(dim=2)
    scores = softmax(output).squeeze(0).tolist()
    result = []
    temp = {
        "start": 0,
        "end": 0,
        "entity": "O",
        "score": 0,
        "word": "",
        "count": 0
    }
    # Merge the B-/I- token-level predictions into character-level entity spans
    # and average the per-token confidence scores over each span.
    for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores):
        if pos[0] == pos[1] or entity == "O":  # skip special tokens and non-entity tokens
            continue
        if "I-" + temp["entity"] == entity:  # continuation of the current entity
            temp["word"] += text[temp["end"]:pos[0]] + text[pos[0]:pos[1]]
            temp["end"] = pos[1]
            temp["count"] += 1
            temp["score"] += max(score)
        else:  # start of a new entity
            if temp["count"] > 0:
                temp["score"] /= temp.pop("count")
                result.append(temp)
            temp = {
                "start": pos[0],
                "end": pos[1],
                "entity": entity[2:],
                "score": max(score),
                "word": text[pos[0]:pos[1]],
                "count": 1
            }
    if temp["count"] > 0:
        temp["score"] /= temp.pop("count")
        result.append(temp)
    return result
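

# Minimal usage sketch (assumes the model, tokenizer, and "gazetteers.pkl"
# referenced in load() are available locally; the example sentence is made up):
if __name__ == "__main__":
    tokenizer, model, gazetteers = load()
    entities = run(tokenizer, model, gazetteers, "Jan Novák navštívil Prahu.")
    for entity in entities:
        print(entity["entity"], entity["word"], round(entity["score"], 3))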