File size: 3,844 Bytes
b7137b1 3b3fa96 b7137b1 aed5912 b7137b1 1c53eb1 b7137b1 f264b44 b7137b1 1c53eb1 b7137b1 5408f33 b7137b1 5408f33 b7137b1 5408f33 b7137b1 5408f33 3b3fa96 b7137b1 3b3fa96 5408f33 fce5f58 0a7c967 dec3f54 aed5912 2056bb6 fce5f58 d4804d5 e32131f d4804d5 5408f33 b7137b1 5408f33 b7137b1 d4804d5 5408f33 b7137b1 32acf13 b7137b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import time
import streamlit as st
import torch
import string
from transformers import BertTokenizer, BertForMaskedLM
# Configure the browser tab title and the page layout for this app.
st.set_page_config(page_title='Qualitative pretrained model evaluation', page_icon=None, layout='centered', initial_sidebar_state='auto')
@st.cache()
def load_bert_model(model_name):
    """Load the tokenizer and masked-LM model for ``model_name``.

    Both objects are created with ``from_pretrained(model_name)``; the model
    is switched to eval mode before it is returned. The function is wrapped
    in ``st.cache()``.

    Args:
        model_name: Identifier handed to ``BertTokenizer.from_pretrained``
            and ``BertForMaskedLM.from_pretrained``.

    Returns:
        A ``(bert_tokenizer, bert_model)`` tuple.

    Raises:
        Exception: Any error raised while loading is propagated unchanged.
            Swallowing it here would return ``None`` and make the caller
            fail later with an unrelated tuple-unpacking error.
    """
    bert_tokenizer = BertTokenizer.from_pretrained(model_name)
    bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
    return bert_tokenizer, bert_model
def decode(tokenizer, pred_idx, top_clean):
    """Convert predicted token ids into a newline-separated string.

    Every id is decoded on its own with ``tokenizer.decode`` and all
    whitespace inside the decoded text is removed. No tokens are filtered
    out.

    Args:
        tokenizer: Object exposing ``decode(token_id)`` returning a string.
        pred_idx: Iterable of token ids, in the order they should appear.
        top_clean: Number of leading tokens to keep (used as a slice bound).

    Returns:
        The first ``top_clean`` decoded tokens, one per line.
    """
    tokens = [''.join(tokenizer.decode(w).split()) for w in pred_idx]
    return '\n'.join(tokens[:top_clean])
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize ``text_sentence`` and locate its first mask token.

    The placeholder ``<mask>`` is replaced with ``tokenizer.mask_token``
    before encoding, so either spelling can be used in the input.

    Args:
        tokenizer: Object exposing ``mask_token``, ``mask_token_id`` and
            ``encode(text, add_special_tokens=...)``.
        text_sentence: Input text containing a mask token.
        add_special_tokens: Forwarded to ``tokenizer.encode``.

    Returns:
        ``(input_ids, mask_idx)`` where ``input_ids`` is a tensor holding one
        row of token ids and ``mask_idx`` is the position of the first mask
        token within that row.

    Raises:
        ValueError: If the encoded text contains no mask token (this
            includes empty or whitespace-only input).
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    words = text_sentence.split()
    # If the mask is the last token, append a "." so that models don't
    # predict punctuation. Guarding on ``words`` avoids an IndexError for
    # empty input.
    if words and words[-1] == tokenizer.mask_token:
        text_sentence += ' .'
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
    if not mask_positions:
        raise ValueError(
            f"input text contains no mask token ({tokenizer.mask_token} or <mask>)"
        )
    return input_ids, mask_positions[0]
def get_all_predictions(text_sentence, top_clean=5):
    """Return the top predicted tokens at the mask position and at position 0.

    Relies on the module-level ``bert_tokenizer`` and ``bert_model`` having
    been set before the call.

    Args:
        text_sentence: Input text containing a mask token.
        top_clean: Number of tokens to return for each position.

    Returns:
        A dict with two newline-separated token strings: ``'bert'`` for the
        mask position and ``'[CLS]'`` for position 0 of the sequence.
    """
    # ========================= BERT =================================
    input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
    with torch.no_grad():
        # First element of the model output; presumably the per-position
        # vocabulary scores -- it is indexed below as (batch, position, vocab).
        predict = bert_model(input_ids)[0]

    # Size the top-k from the argument instead of an unrelated module global,
    # clamped so that topk never asks for more entries than the last
    # dimension holds.
    k = max(0, min(int(top_clean), predict.shape[-1]))

    bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(k).indices.tolist(), top_clean)
    cls = decode(bert_tokenizer, predict[0, 0, :].topk(k).indices.tolist(), top_clean)
    return {'bert': bert, '[CLS]': cls}
def get_bert_prediction(input_text, top_k):
    """Run the prediction for ``input_text`` and return the result dict.

    Args:
        input_text: Text containing a mask token.
        top_k: Number of tokens to return per position; converted with
            ``int()`` before use.

    Returns:
        The value returned by ``get_all_predictions``.

    Raises:
        Exception: Errors from ``int(top_k)`` or from the prediction itself
            are propagated so that the caller can report them, instead of
            being silently turned into a ``None`` result.
    """
    return get_all_predictions(input_text, top_clean=int(top_k))
# ----------------------------- Page body -----------------------------
st.title("Qualitative evaluation of Pretrained BERT models")
st.markdown("""
<a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><small style="font-size:18px; color: #8f8f8f">This app is used to qualitatively examine the performance of pretrained models to do NER , <b>with no fine tuning</b></small></a>
""", unsafe_allow_html=True)
st.write("Incomplete. Work in progress...")
st.write("CLS vectors as well as the model prediction for a blank position are examined")

# Number of predictions requested per position. Kept as a module-level name
# because functions defined in this file may look it up as a global.
top_k = 10
print(top_k)  # debug trace on the server console
# Set when a prediction is started; used to report the elapsed time.
start = None

try:
    model_name = st.sidebar.selectbox(
        label='Select Model to Apply',
        options=[
            'ajitrajasekharan/biomedical',
            'bert-base-cased',
            'bert-large-cased',
            'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',
            'allenai/scibert_scivocab_cased',
        ],
        index=0,
        key="model_name",
    )

    # bert_tokenizer / bert_model must be module-level names:
    # get_all_predictions reads them as globals.
    loaded = load_bert_model(model_name)
    if loaded is None:
        # A loader that reports failure by returning None would otherwise
        # surface as an obscure tuple-unpacking error.
        raise RuntimeError(f"model '{model_name}' could not be loaded")
    bert_tokenizer, bert_model = loaded

    default_text = "Imatinib is used to [MASK] nsclc"
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )

    if st.button("Submit"):
        with st.spinner("Computing"):
            start = time.time()
            try:
                res = get_bert_prediction(input_text, top_k)
                if res is None:
                    # Report a missing result instead of rendering "null".
                    raise RuntimeError(
                        "no prediction was produced; check that the text contains a [MASK] token"
                    )
                st.header("JSON:")
                st.json(res)
            except Exception as e:
                st.error("Some error occurred during prediction: " + str(e))
                st.stop()

    if start is not None:
        st.text(f"prediction took {time.time() - start:.2f}s")
except Exception as e:
    st.error("Some error occurred during loading: " + str(e))
    st.stop()

st.write("---")
|