# Kazakh Named Entity Recognition — Hugging Face Space app (captured page status: "Sleeping").
import re

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
# Hugging Face Hub checkpoint for Kazakh named-entity recognition.
MODEL_NAME = "yeshpanovrustem/xlm-roberta-large-ner-kazakh"

# Streamlit re-runs this whole script on every page load and widget interaction,
# so the large model must be cached instead of being reloaded on each run.
# ``st.cache_resource`` (newer Streamlit) keeps one shared instance across reruns
# and sessions; older versions only offer ``st.cache(allow_output_mutation=True)``,
# where the flag merely disables Streamlit's check that the cached object was
# not mutated.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model(model_name: str = MODEL_NAME):
    """Load the NER tokenizer and token-classification model.

    Args:
        model_name: Hugging Face Hub id (or local path) of the checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForTokenClassification.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


# load model and tokenizer (cached after the first run)
tokenizer, model = load_model()
# Entity types predicted by the model, in label-id order. Id 0 is ``O``; after
# that every type T occupies two consecutive ids, ``B-T`` then ``I-T`` (BIO tags).
# NOTE(review): must stay in sync with the checkpoint's ``config.id2label`` —
# confirm if the model is ever swapped.
_ENTITY_TYPES = (
    "ADAGE", "ART", "CARDINAL", "CONTACT", "DATE",
    "DISEASE", "EVENT", "FACILITY", "GPE", "LANGUAGE",
    "LAW", "LOCATION", "MISCELLANEOUS", "MONEY", "NON_HUMAN",
    "NORP", "ORDINAL", "ORGANISATION", "PERSON", "PERCENTAGE",
    "POSITION", "PRODUCT", "PROJECT", "QUANTITY", "TIME",
)

# Label id -> tag name: {0: 'O', 1: 'B-ADAGE', 2: 'I-ADAGE', ..., 50: 'I-TIME'}.
labels_dict = dict(
    enumerate(
        ["O"] + [f"{prefix}-{entity}" for entity in _ENTITY_TYPES for prefix in ("B", "I")]
    )
)
# Word and punctuation splitter: a run of Unicode word characters is one token,
# every other non-space character is a token of its own.
_WORD_RE = re.compile(r"\w+|[^\w\s]")


def simple_word_tokenize(text):
    """Split *text* into word and punctuation tokens.

    Stand-in for the ``word_tokenize`` (presumably NLTK's) that this script
    called without importing. Cyrillic/Kazakh letters count as word characters.

    >>> simple_word_tokenize("Қазақстан — мемлекет.")
    ['Қазақстан', '—', 'мемлекет', '.']
    """
    return _WORD_RE.findall(text)


def align_word_labels(word_ids, predicted_ids, id2label):
    """Collapse sub-token predictions to one label per input word.

    Args:
        word_ids: Word index for each sub-token, as returned by
            ``BatchEncoding.word_ids()``; ``None`` marks special tokens.
        predicted_ids: Predicted label id for each sub-token, same order.
        id2label: Mapping from label id to label name.

    Returns:
        A list of label names, one per word. Each word takes the label of its
        first sub-token; special tokens and continuation sub-tokens are skipped.
    """
    labels = []
    previous_word_id = None
    for word_id, label_id in zip(word_ids, predicted_ids):
        if word_id is None or word_id == previous_word_id:
            continue
        labels.append(id2label[label_id])
        previous_word_id = word_id
    return labels


def label_sentence(words, tokenizer, model, id2label):
    """Predict one NER label for each word in *words*.

    Args:
        words: The sentence, already split into words.
        tokenizer: Tokenizer matching *model*.
        model: Token-classification model returning ``.logits``.
        id2label: Mapping from label id to label name.

    Returns:
        A list of label names, the same length as *words*.

    Raises:
        ValueError: If the number of predicted labels differs from the number
            of input words.
    """
    if not words:
        return []
    encoded = tokenizer(words, is_split_into_words=True, return_tensors="pt")
    # Inference only: no need to build the autograd graph.
    with torch.no_grad():
        logits = model(**encoded).logits
    predicted_ids = torch.argmax(logits, dim=2)[0].tolist()
    labels = align_word_labels(encoded.word_ids(batch_index=0), predicted_ids, id2label)
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if len(labels) != len(words):
        raise ValueError("Mismatch between input token and label sizes!")
    return labels


# ``streamlit run`` and ``python`` both execute this file as ``__main__``; the
# guard only keeps the demo from running when the helpers are imported.
if __name__ == "__main__":
    example = "Қазақстан Республикасы — Шығыс Еуропа мен Орталық Азияда орналасқан мемлекет."
    single_sentence_tokens = simple_word_tokenize(example)
    labels = label_sentence(single_sentence_tokens, tokenizer, model, labels_dict)
    # NOTE(review): print() goes to the server console, not the Streamlit page;
    # switch to st.write(...) to show results to users.
    for token, label in zip(single_sentence_tokens, labels):
        print(token, label)
# st.markdown("# Hello") | |
# # st.set_page_config(page_title = "Kazakh Named Entity Recognition", page_icon = "🔍") | |
# # st.title("🔍 Kazakh Named Entity Recognition") | |
# x = st.slider('Select a value') | |
# st.write(x, 'squared is', x * x) |