import copy
import json
import pickle

import torch
from simplemma import lemmatize
from transformers import AutoTokenizer

from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens


# code originally from data_manipulation.creation_gazetteers
def lemmatizing(x):
    if x == "":
        return ""
    return lemmatize(x, lang="cs")


# code originally from data_manipulation.creation_gazetteers
def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    """
    Invert a gazetteer mapping of label -> surface forms into surface form -> label.

    :param dictionary: a dict mapping an entity label to a list of surface forms
    :param apply_lemmatizing: if True, also map the Czech lemma of each surface form to the label
    :return: a dict mapping each surface form (and optionally its lemma) to its label
    """
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary
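
# A minimal sketch of the reverse mapping, with made-up entries (not from the
# shipped gazetteers):
#
#   >>> build_reverse_dictionary({"PER": ["Karel", "Karla"]})
#   {'Karel': 'PER', 'Karla': 'PER'}
#
# With apply_lemmatizing=True, the Czech lemma of each surface form is added
# as an extra key whenever it differs from the form itself.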


def load_json(path):
    """
    Load gazetteers from a JSON file.

    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, 'r') as file:
        data = json.load(file)
    return data


def load_pickle(path):
    """
    Load pickled gazetteers from a file.

    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, 'rb') as file:
        data = pickle.load(file)
    return data


def load():
    """
    Load the tokenizer, model, and gazetteers for named entity recognition.

    Returns:
        tokenizer (AutoTokenizer): The tokenizer for tokenizing input text.
        model (ExtendedEmbeddigsRobertaForTokenClassification): The pre-trained model for named entity recognition.
        gazetteers_for_matching (list): A list of reverse dictionaries, one per entity type, mapping surface forms to labels.
    """
    model_name = "ufal/robeczech-base"
    model_path = "bettystr/NerRoB-czech"
    gazetteers_path = "gazetteers.pkl"

    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    # Turn each label -> surface-forms gazetteer into its own reverse
    # dictionary (surface form -> label), one per entity type.
    gazetteers_for_matching = load_pickle(gazetteers_path)
    temp = []
    for label, values in gazetteers_for_matching.items():
        temp.append(build_reverse_dictionary({label: values}))
    gazetteers_for_matching = temp
    return tokenizer, model, gazetteers_for_matching


def add_additional_gazetteers(gazetteers_for_matching, file_names):
    """
    Add user-supplied gazetteers to the existing reverse dictionaries.

    Args:
        gazetteers_for_matching (list): The list of reverse dictionaries to be extended.
        file_names (list): The list of JSON file names containing additional gazetteers.

    Returns:
        list: A deep copy of the reverse dictionaries, extended with the additional entries.
    """
    if not file_names:
        return gazetteers_for_matching
    # Work on a deep copy so the gazetteers loaded at startup stay unchanged.
    temp = [copy.deepcopy(d) for d in gazetteers_for_matching]
    for file_name in file_names:
        with open(file_name, 'r') as file:
            data = json.load(file)
        for key, value_lst in data.items():
            key = key.upper()
            # Add the new surface forms to the reverse dictionary that
            # already holds this entity type.
            for dictionary in temp:
                if key in dictionary.values():
                    for value in value_lst:
                        dictionary[value] = key
    return temp
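
# The additional gazetteer files are expected to have the same JSON shape as
# the built-in ones: entity label -> list of surface forms. A hypothetical
# example file (name and entries made up for illustration):
#
#   extra_gazetteer.json:
#     {"per": ["Jan Novák"], "loc": ["Lhota"]}
#
# Labels are upper-cased before matching, so they must correspond to entity
# types already present in the loaded gazetteers (e.g. PER, ORG, LOC).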


def run(tokenizer, model, gazetteers, text, file_names=None):
    """
    Run the named entity recognition (NER) model on the given text.

    Args:
        tokenizer (Tokenizer): The tokenizer used to tokenize the input text.
        model (Model): The NER model used for prediction.
        gazetteers (list): A list of reverse-dictionary gazetteers used for matching entities in the text.
        text (str): The input text to perform NER on.
        file_names (list, optional): A list of JSON file names to be used as additional gazetteers.

    Returns:
        list: A list of dictionaries representing the predicted entities in the text. Each dictionary contains the following keys:
            - "start" (int): The starting position of the entity in the text.
            - "end" (int): The ending position of the entity in the text.
            - "entity" (str): The type of the entity.
            - "score" (float): The confidence score of the entity prediction, averaged over its tokens.
            - "word" (str): The actual word(s) representing the entity.
    """
    gazetteers_for_matching = add_additional_gazetteers(gazetteers, file_names)
    tokenized_inputs = tokenizer(
        text, truncation=True, is_split_into_words=False, return_offsets_mapping=True
    )

    # Match the raw text against the gazetteers and align the matches with the
    # sub-word tokens, producing one per-token flag sequence per entity type.
    matches = gazetteer_matching(text, gazetteers_for_matching)
    word_ids = tokenized_inputs.word_ids()
    aligned = align_gazetteers_with_tokens(matches, word_ids)
    p = [x[0] for x in aligned]
    o = [x[1] for x in aligned]
    l = [x[2] for x in aligned]

    input_ids = torch.tensor(tokenized_inputs["input_ids"], device="cpu").unsqueeze(0)
    attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device="cpu").unsqueeze(0)
    per = torch.tensor(p, device="cpu").unsqueeze(0)
    org = torch.tensor(o, device="cpu").unsqueeze(0)
    loc = torch.tensor(l, device="cpu").unsqueeze(0)

    output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
    predictions = torch.argmax(output, dim=2).tolist()
    predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
    # Softmax over the label dimension gives a per-token confidence for the
    # predicted tag.
    softmax = torch.nn.Softmax(dim=2)
    scores = softmax(output).squeeze(0).tolist()
    # Merge sub-word token predictions into entity spans. "temp" accumulates
    # the entity currently being built; "count" tracks its token count so the
    # score can be averaged, and is dropped before the span is emitted.
    result = []
    temp = {
        "start": 0,
        "end": 0,
        "entity": "O",
        "score": 0,
        "word": "",
        "count": 0
    }
    for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores):
        # Skip special tokens (empty offsets) and tokens tagged as outside.
        if pos[0] == pos[1] or entity == "O":
            continue
        if "I-" + temp["entity"] == entity:  # continuation of the current entity
            temp["word"] += text[temp["end"]:pos[1]]
            temp["end"] = pos[1]
            temp["count"] += 1
            temp["score"] += max(score)
        else:  # a new entity starts here
            if temp["count"] > 0:
                temp["score"] /= temp.pop("count")
                result.append(temp)
            temp = {
                "start": pos[0],
                "end": pos[1],
                "entity": entity[2:],
                "score": max(score),  # score of the entity's first token
                "word": text[pos[0]:pos[1]],
                "count": 1
            }
    # Flush the last open entity, if any.
    if temp["count"] > 0:
        temp["score"] /= temp.pop("count")
        result.append(temp)
    return result
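

# A minimal usage sketch, assuming gazetteers.pkl is present in the working
# directory and the extended_embeddings / data_manipulation packages are
# importable; the sample sentence is illustrative only:
if __name__ == "__main__":
    tokenizer, model, gazetteers = load()
    entities = run(tokenizer, model, gazetteers, "Petr Pavel navštívil Brno.")
    for ent in entities:
        print(ent["entity"], ent["word"], round(ent["score"], 3), ent["start"], ent["end"])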