import json
import copy
import pickle
import torch
from simplemma import lemmatize
from transformers import AutoTokenizer
from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens

# Code originally from data_manipulation.creation_gazetteers.
def lemmatizing(x):
    if x == "":
        return ""
    return lemmatize(x, lang="cs")

# Code originally from data_manipulation.creation_gazetteers.
def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    """
    Invert a gazetteer: map each surface form (value) to its entity label (key).

    :param dictionary: a dict mapping entity labels to lists of surface forms
    :param apply_lemmatizing: if True, also add the lemma of each surface form
    :return: a dict mapping surface forms (and optionally lemmas) to labels
    """
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary
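
# Illustrative example (not from the original file): reversing a small
# gazetteer without lemmatization:
#     build_reverse_dictionary({"LOC": ["Praha", "Brno"]})
#     -> {"Praha": "LOC", "Brno": "LOC"}
# With apply_lemmatizing=True, the simplemma lemma of each surface form is
# added as an extra key whenever it differs from the original string.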

def load_json(path):
    """
    Load gazetteers from a file
    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, 'r') as file:
        data = json.load(file)
    return data

def load_pickle(path):
    """
    Load gazetteers from a pickle file
    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, 'rb') as file:
        data = pickle.load(file)
    return data
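
# Note: based on how load() consumes it below, "gazetteers.pkl" is assumed to
# hold a dict mapping entity labels to lists of surface forms, e.g.
# {"PER": [...], "ORG": [...], "LOC": [...]}.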

def load():
    """
    Load the tokenizer, model, and gazetteers for named entity recognition.

    Returns:
        tokenizer (AutoTokenizer): The tokenizer for tokenizing input text.
        model (ExtendedEmbeddigsRobertaForTokenClassification): The pre-trained model for named entity recognition.
        gazetteers_for_matching (list): A list of reverse gazetteer dictionaries (surface form -> entity label) for matching named entities.
    """
    model_name = "ufal/robeczech-base"
    model_path = "bettystr/NerRoB-czech"
    gazetteers_path = "gazetteers.pkl"

    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    # Build one reverse dictionary per entity label so a match can report
    # which label a surface form belongs to.
    gazetteers_for_matching = load_pickle(gazetteers_path)
    temp = []
    for label, values in gazetteers_for_matching.items():
        temp.append(build_reverse_dictionary({label: values}))
    gazetteers_for_matching = temp
    return tokenizer, model, gazetteers_for_matching
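
# Illustrative usage (assumes the model, tokenizer, and "gazetteers.pkl" are
# reachable from this machine; the sentence is a made-up example):
#     tokenizer, model, gazetteers = load()
#     entities = run(tokenizer, model, gazetteers, "Jan Novák žije v Praze.")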

def add_additional_gazetteers(gazetteers_for_matching, file_names):
    """
    Add user-supplied gazetteers to the existing reverse dictionaries.

    Args:
        gazetteers_for_matching (list): The list of reverse gazetteer dictionaries to be extended.
        file_names (list): The list of JSON file names containing additional gazetteers.

    Returns:
        list: A deep copy of the reverse dictionaries, updated with the additional entries.
    """
    if file_names is None or file_names == []:
        return gazetteers_for_matching

    # Work on a deep copy so the gazetteers loaded by load() stay unchanged.
    temp = []
    for l1 in gazetteers_for_matching:
        temp.append(copy.deepcopy(l1))

    for file_name in file_names:
        with open(file_name, 'r') as file:
            data = json.load(file)
        for key, value_lst in data.items():
            key = key.upper()
            # Add each new surface form to the dictionary that already
            # contains this entity label.
            for dictionary in temp:
                if key in dictionary.values():
                    for value in value_lst:
                        dictionary[value] = key
    return temp
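
# Illustrative example: a file passed via file_names is expected to be JSON
# mapping labels to surface forms, e.g. {"loc": ["Kocourkov"]}; after the
# update, the reverse dictionary that already contains the label "LOC" also
# maps "Kocourkov" -> "LOC".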

def run(tokenizer, model, gazetteers, text, file_names=None):
    """
    Run the named entity recognition (NER) model on the given text.

    Args:
        tokenizer (Tokenizer): The tokenizer used to tokenize the input text.
        model (Model): The NER model used for prediction.
        gazetteers (list): A list of reverse gazetteer dictionaries used for matching entities in the text.
        text (str): The input text to perform NER on.
        file_names (list, optional): A list of JSON file names to be used as additional gazetteers.

    Returns:
        list: A list of dictionaries representing the predicted entities in the text. Each dictionary contains the following keys:
            - "start" (int): The starting position of the entity in the text.
            - "end" (int): The ending position of the entity in the text.
            - "entity" (str): The type of the entity.
            - "score" (float): The confidence score of the prediction, averaged over the entity's tokens.
            - "word" (str): The text span of the entity.
    """
    gazetteers_for_matching = add_additional_gazetteers(gazetteers, file_names)
    tokenized_inputs = tokenizer(
        text, truncation=True, is_split_into_words=False, return_offsets_mapping=True
    )

    # Match the raw text against the gazetteers and align the matches with the
    # tokenization, yielding one (per, org, loc) flag triple per token.
    matches = gazetteer_matching(text, gazetteers_for_matching)
    new_g = []
    word_ids = tokenized_inputs.word_ids()
    new_g.append(align_gazetteers_with_tokens(matches, word_ids))
    p, o, l = [], [], []
    for i in new_g:
        p.append([x[0] for x in i])
        o.append([x[1] for x in i])
        l.append([x[2] for x in i])

    input_ids = torch.tensor(tokenized_inputs["input_ids"], device="cpu").unsqueeze(0)
    attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device="cpu").unsqueeze(0)
    per = torch.tensor(p, device="cpu")
    org = torch.tensor(o, device="cpu")
    loc = torch.tensor(l, device="cpu")

    output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
    predictions = torch.argmax(output, dim=2).tolist()
    predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
    softmax = torch.nn.Softmax(dim=2)
    scores = softmax(output).squeeze(0).tolist()
    # Merge subword tokens into entity spans; "count" tracks the number of
    # tokens in the current span so the confidence can be averaged, and is
    # popped before the span is appended to the result.
    result = []
    temp = {
        "start": 0,
        "end": 0,
        "entity": "O",
        "score": 0,
        "word": "",
        "count": 0
    }
    for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores):
        if pos[0] == pos[1] or entity == "O":
            continue
        if "I-" + temp["entity"] == entity:  # continuation of the current entity
            temp["word"] += text[temp["end"]:pos[0]] + text[pos[0]:pos[1]]
            temp["end"] = pos[1]
            temp["count"] += 1
            temp["score"] += max(score)
        else:  # new entity: flush the previous span, then open a new one
            if temp["count"] > 0:
                temp["score"] /= temp.pop("count")
                result.append(temp)
            temp = {
                "start": pos[0],
                "end": pos[1],
                "entity": entity[2:],
                "score": max(score),  # confidence of the span's first token
                "word": text[pos[0]:pos[1]],
                "count": 1
            }
    if temp["count"] > 0:  # flush the final span
        temp["score"] /= temp.pop("count")
        result.append(temp)
    return result
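
# Minimal end-to-end sketch. The sentence below is only an illustration, and
# the model weights and "gazetteers.pkl" referenced in load() must be
# available for this to run.
if __name__ == "__main__":
    tokenizer, model, gazetteers_for_matching = load()
    for found_entity in run(tokenizer, model, gazetteers_for_matching,
                            "Barack Obama navštívil Prahu."):
        print(found_entity)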