NerRoB-czech / website_script.py
AlzbetaStrompova
fix param name
d09e4cf
import json
import copy
import pickle
import torch
from simplemma import lemmatize
from transformers import AutoTokenizer
from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens
# Code originally from data_manipulation.creation_gazetteers
def lemmatizing(x):
    """
    Lemmatize a single word with simplemma using the Czech model.
    :param x: the word to lemmatize
    :return: the lemma of x, or "" when x is the empty string
    """
    # The empty string is returned as-is without calling simplemma.
    return "" if x == "" else lemmatize(x, lang="cs")
# Code originally from data_manipulation.creation_gazetteers
def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    """
    Invert a {label: [entries]} mapping into {entry: label}.
    When an entry appears under several labels, the label processed last wins.
    :param dictionary: mapping of a label to an iterable of entries
    :param apply_lemmatizing: if True, also map each entry's lemma (when it
        differs from the entry) to the same label
    :return: the reverse dictionary
    """
    def _entry_label_pairs():
        # Yield pairs in the same order they would be assigned one by one,
        # so later pairs overwrite earlier ones exactly like item assignment.
        for label, entries in dictionary.items():
            for entry in entries:
                yield entry, label
                if not apply_lemmatizing:
                    continue
                lemma = lemmatizing(entry)
                if lemma != entry:
                    yield lemma, label

    return dict(_entry_label_pairs())
def load_json(path):
    """
    Load gazetteers from a JSON file.
    The file is always decoded as UTF-8 so that non-ASCII (e.g. Czech)
    entries load correctly regardless of the platform's default encoding.
    :param path: path to the gazetteer file
    :return: the parsed JSON content (a dict of gazetteers)
    """
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)
def load_pickle(path):
    """
    Load pickled gazetteers from a file.
    NOTE: unpickling can execute arbitrary code, so only pass trusted files.
    :param path: path to the gazetteer file
    :return: the unpickled object (a dict of gazetteers)
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)
def load():
    """
    Load the tokenizer, model, and gazetteers for named entity recognition.
    Returns:
        tokenizer (AutoTokenizer): The tokenizer for tokenizing input text.
        model (ExtendedEmbeddigsRobertaForTokenClassification): The pre-trained model
            (moved to CPU and switched to eval mode).
        gazetteers_for_matching (list): One reverse dictionary (entry -> label) per
            label found in the pickled gazetteers.
    """
    model_name = "ufal/robeczech-base"
    model_path = "bettystr/NerRoB-czech"
    gazetteers_path = "gazetteers.pkl"

    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    raw_gazetteers = load_pickle(gazetteers_path)
    # Build a separate reverse dictionary for every label in the pickle.
    gazetteers_for_matching = [
        build_reverse_dictionary({label: entries})
        for label, entries in raw_gazetteers.items()
    ]
    return tokenizer, model, gazetteers_for_matching
def add_additional_gazetteers(gazetteers_for_matching, file_names):
    """
    Return copies of the gazetteers extended with entries from JSON files.
    Each JSON file maps a label to a list of entries. The label is upper-cased,
    and every entry is added (mapped to that label) to each gazetteer
    dictionary that already has the label among its values. Labels matching
    no gazetteer are ignored. The caller's dictionaries are not modified.
    Args:
        gazetteers_for_matching (list[dict]): Reverse gazetteers mapping an
            entry to its label.
        file_names (list | None): Paths of UTF-8 encoded JSON files containing
            additional gazetteers.
    Returns:
        list[dict]: The extended gazetteers, or the input list itself when no
        file names are given.
    """
    if not file_names:
        return gazetteers_for_matching
    # Only top-level entries are added below (values are never mutated), so a
    # shallow copy is enough to keep the caller's dicts intact and is far
    # cheaper than a deep copy of large gazetteers on every call.
    extended = [copy.copy(gazetteer) for gazetteer in gazetteers_for_matching]
    for file_name in file_names:
        with open(file_name, "r", encoding="utf-8") as file:
            data = json.load(file)
        for key, value_lst in data.items():
            label = key.upper()
            for dictionary in extended:
                if label in dictionary.values():
                    for value in value_lst:
                        dictionary[value] = label
    return extended
def _aggregate_entities(text, offsets, tags, token_scores):
    """
    Merge per-token tags into entity spans (helper for ``run``).
    Tags are assumed to carry a two-character prefix such as "B-"/"I-"; the
    prefix is sliced off to get the entity type. A tag equal to "I-" + the
    open entity's type extends that entity; any other non-"O" tag closes the
    open entity and starts a new one. Tokens with an empty character span
    (start == end) and "O" tokens are skipped without closing the open
    entity, so an "I-X" after an "O" token still extends the preceding X.
    :param text: the text the character offsets refer to
    :param offsets: per-token (start, end) character offsets into ``text``
    :param tags: per-token predicted label strings
    :param token_scores: per-token lists of label probabilities
    :return: list of {"start", "end", "entity", "score", "word"} dicts, where
        "score" is the mean over the entity's tokens of each token's highest
        probability and "word" is ``text[start:end]``
    """
    entities = []
    current = None  # open entity: start, end, entity type, summed score, token count

    def close(span):
        entities.append({
            "start": span["start"],
            "end": span["end"],
            "entity": span["entity"],
            "score": span["score"] / span["count"],
            # Slice the text so "word" always agrees with start/end.
            "word": text[span["start"]:span["end"]],
        })

    for (start, end), tag, probs in zip(offsets, tags, token_scores):
        if start == end or tag == "O":
            continue
        confidence = max(probs)
        if current is not None and tag == "I-" + current["entity"]:
            current["end"] = end
            current["score"] += confidence
            current["count"] += 1
            continue
        if current is not None:
            close(current)
        # The new entity's own first-token confidence starts its score.
        current = {"start": start, "end": end, "entity": tag[2:], "score": confidence, "count": 1}
    if current is not None:
        close(current)
    return entities


def run(tokenizer, model, gazetteers, text, file_names=None):
    """
    Runs the named entity recognition (NER) model on the given text.
    Args:
        tokenizer (Tokenizer): The tokenizer used to tokenize the input text.
        model (Model): The NER model used for prediction.
        gazetteers (list): A list of reverse gazetteers (entry -> label) used for matching entities in the text.
        text (str): The input text to perform NER on.
        file_names (list, optional): A list of file names to be used as additional gazetteers.
    Returns:
        list: A list of dictionaries representing the predicted entities in the text. Each dictionary contains the following keys:
            - "start" (int): The character offset where the entity starts in the text.
            - "end" (int): The character offset where the entity's last token ends.
            - "entity" (str): The type of the entity (the predicted tag without its two-character prefix).
            - "score" (float): The mean, over the entity's tokens, of each token's highest label probability.
            - "word" (str): The entity's text, i.e. text[start:end].
    """
    gazetteers_for_matching = add_additional_gazetteers(gazetteers, file_names)
    # truncation=True cuts the input at the tokenizer's maximum length, so
    # entities beyond that point are not reported.
    tokenized_inputs = tokenizer(
        text, truncation=True, is_split_into_words=False, return_offsets_mapping=True
    )
    matches = gazetteer_matching(text, gazetteers_for_matching)
    aligned = align_gazetteers_with_tokens(matches, tokenized_inputs.word_ids())

    # Wrap every per-token sequence as a batch of one.
    input_ids = torch.tensor([tokenized_inputs["input_ids"]], device="cpu")
    attention_mask = torch.tensor([tokenized_inputs["attention_mask"]], device="cpu")
    # Each aligned entry is indexed as (per, org, loc) gazetteer features.
    per = torch.tensor([[x[0] for x in aligned]], device="cpu")
    org = torch.tensor([[x[1] for x in aligned]], device="cpu")
    loc = torch.tensor([[x[2] for x in aligned]], device="cpu")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(
            input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc
        ).logits

    tags = [model.config.id2label[idx] for idx in torch.argmax(logits, dim=2)[0].tolist()]
    token_scores = torch.softmax(logits, dim=2)[0].tolist()
    return _aggregate_entities(text, tokenized_inputs.offset_mapping, tags, token_scores)