File size: 8,897 Bytes
b7137b1 afd47cf b72b8d9 167f69d b7137b1 68f2100 3b3fa96 b7137b1 6f4ba26 b7137b1 b9fa3c7 b7137b1 08b9f95 ecce248 b7137b1 a03c359 b7137b1 424d29e 1c53eb1 c6d5fcb 1f4b5d0 b7137b1 f8dc81b a03c359 424d29e a03c359 b7137b1 1f4b5d0 b7137b1 0d25a6d 293e817 2406613 b9fa3c7 2406613 b7137b1 f8dc81b b7137b1 1c53eb1 f8dc81b b7137b1 b9f419a eee0c36 d1b63cc eee0c36 d1b63cc 9d7bba2 b5aa429 eee0c36 b5aa429 eee0c36 b9f419a eee0c36 3f2b07b eee0c36 3f2b07b 6b14017 caf6c21 24b81ce b9f419a 77d733c eee0c36 672678f eee0c36 6f5d2d2 d1b63cc b7137b1 6f5d2d2 d1b63cc eb9cce7 a340b6b 68f2100 6f5d2d2 220d4ac b7137b1 8a3b8f4 f0a51d8 cae21f4 6f5d2d2 1e5d7d3 24b81ce d1b63cc 167f69d 6f5d2d2 a7a3b55 fde541a eee0c36 24b81ce eee0c36 6244fd0 eee0c36 24b81ce eee0c36 24b81ce eee0c36 b9f419a b7137b1 6f5d2d2 b7137b1 6421c87 1e5d7d3 eb960a9 6421c87 6f5d2d2 316bafd 6f5d2d2 b7137b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import time
import streamlit as st
import torch
import string
from transformers import BertTokenizer, BertForMaskedLM
# Configure page chrome. NOTE(review): Streamlit requires set_page_config to be
# the first st.* call executed in the script — keep it immediately after imports.
st.set_page_config(page_title='Compare pretrained BERT models qualitatively', page_icon=None, layout='centered', initial_sidebar_state='auto')
@st.cache()
def load_bert_model(model_name):
    """Load a BERT tokenizer/model pair for fill-mask inference.

    Args:
        model_name: Hugging Face model identifier (e.g. "bert-base-cased").

    Returns:
        (tokenizer, model) on success; (None, None) if loading fails.
        The model is put in eval() mode since it is only used for inference.
    """
    try:
        tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
        model = BertForMaskedLM.from_pretrained(model_name).eval()
        return tokenizer, model
    except Exception as e:
        # Fix: the original `except: pass` returned a bare None, and callers
        # unpack the result into two names — that raised a confusing
        # TypeError far from the real cause. Report the failure and return a
        # pair so the unpacking still succeeds.
        st.error("Failed to load model '" + str(model_name) + "': " + str(e))
        return None, None
def decode(tokenizer, pred_idx, top_clean):
    """Convert predicted token ids into a newline-separated candidate list.

    Filters out punctuation, single-character tokens, and tokens that start
    with '.' or '[' (special tokens such as [MASK]/[CLS]), then keeps at most
    top_clean survivors, one per line.
    """
    ignore_tokens = string.punctuation
    kept = []
    for idx in pred_idx:
        # Strip any whitespace the tokenizer's decode may introduce.
        word = ''.join(tokenizer.decode(idx).split())
        # Guard clause: skip punctuation, one-char tokens, and special markers.
        if (word in ignore_tokens or len(word) <= 1
                or word.startswith('.') or word.startswith('[')):
            continue
        kept.append(word)
    return '\n'.join(kept[:top_clean])
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence and locate the position of the mask token.

    '<mask>' is first normalized to the tokenizer's own mask token. Returns
    (input_ids tensor of shape [1, seq_len], mask index, tokenized text).
    When the sentence contains no mask token, the mask index falls back to 0
    (the [CLS] position).
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    tokenized_text = tokenizer.tokenize(text_sentence)
    encoded = tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)
    input_ids = torch.tensor([encoded])
    # NOTE: the mask is only detected when it stands alone as a whitespace-
    # separated word; otherwise position 0 ([CLS]) is used.
    if tokenizer.mask_token in text_sentence.split():
        mask_idx = int(torch.where(input_ids == tokenizer.mask_token_id)[1][0])
    else:
        mask_idx = 0
    return input_ids, mask_idx, tokenized_text
def get_all_predictions(text_sentence, model_name, top_clean=5):
    """Run the cached model on a sentence and collect top predictions.

    Reads the tokenizer, model and top_k from st.session_state. Returns a
    dict with the input, its tokenization, the prediction count, the model
    name, the fill-mask predictions (only when the sentence contains a mask
    token) and the predictions at the [CLS] position.
    """
    tokenizer = st.session_state['bert_tokenizer']
    model = st.session_state['bert_model']
    top_k = st.session_state['top_k']
    input_ids, mask_idx, tokenized_text = encode(tokenizer, text_sentence)
    with torch.no_grad():
        logits = model(input_ids)[0]
    # Ask for 2*top_k candidates: decode() filters out punctuation and
    # special tokens, so some candidates are dropped before display.
    masked_preds = decode(tokenizer, logits[0, mask_idx, :].topk(top_k * 2).indices.tolist(), top_clean)
    cls_preds = decode(tokenizer, logits[0, 0, :].topk(top_k * 2).indices.tolist(), top_clean)
    result = {
        'Input sentence': text_sentence,
        'Tokenized text': tokenized_text,
        'results_count': top_k,
        'Model': model_name,
    }
    if "[MASK]" in text_sentence or "<mask>" in text_sentence:
        result['Masked position'] = masked_preds
    result['[CLS]'] = cls_preds
    return result
def get_bert_prediction(input_text, top_k, model_name):
    """Run a fill-mask/[CLS] prediction and return the result dict.

    Args:
        input_text: sentence, optionally containing [MASK] or <mask>.
        top_k: number of candidate predictions to keep (coerced to int).
        model_name: model identifier, echoed into the result dict.

    Returns:
        The result dict built by get_all_predictions().
    """
    # Fix: the original wrapped this call in `except Exception: pass`, which
    # silently returned None — run_test()'s own error handler could never
    # fire, the UI reported "Fill-mask prediction complete", and the result
    # rendered as JSON null. Let exceptions propagate so run_test() reports
    # them via st.error.
    return get_all_predictions(input_text, model_name, top_clean=int(top_k))
def run_test(sent, top_k, model_name, display_area):
    """Lazily load the session's model, then predict for `sent`.

    Progress is shown in `display_area`. The top_k and model_name parameters
    are accepted but the values held in st.session_state are what is actually
    used. Returns the prediction dict, or {} after reporting an error.
    """
    state = st.session_state
    # Lazy load: the tokenizer is None until the first prediction request.
    if state['bert_tokenizer'] is None:
        display_area.text("Loading model:" + state['model_name'])
        state['bert_tokenizer'], state['bert_model'] = load_bert_model(state['model_name'])
        display_area.text("Model " + str(state['model_name']) + " load complete")
    try:
        display_area.text("Computing fill-mask prediction...")
        result = get_bert_prediction(sent, state['top_k'], state['model_name'])
        display_area.text("Fill-mask prediction complete")
        return result
    except Exception as err:
        st.error("Some error occurred during prediction" + str(err))
        st.stop()
    return {}
def on_text_change(text, display_area):
    """Run a prediction for `text` using the settings held in session state."""
    state = st.session_state
    return run_test(text, state['top_k'], state['model_name'], display_area)
def on_model_change(model_name):
    """Switch the active model, reloading tokenizer/model when it changed."""
    # Guard clause: nothing to do when the selection is unchanged.
    if model_name == st.session_state['model_name']:
        return
    st.session_state['model_name'] = model_name
    st.session_state['bert_tokenizer'], st.session_state['bert_model'] = load_bert_model(model_name)
def init_selectbox():
    """Render the example-sentence pull-down and return the chosen sentence.

    The examples mix biomedical, PHI (person/location/org) and legal phrases
    so each listed model can be probed on in- and out-of-domain input.
    """
    example_sentences = (
        "[MASK] who lives in New York and works for XCorp suffers from Parkinson's",
        "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's",
        "Lou Gehrig who lives in New York and works for [MASK] suffers from Parkinson's",
        "Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]",
        "[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's",
        "Parkinson who lives in [MASK] and works for XCorp suffers from Lou Gehrig's",
        "Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's",
        "Parkinson who lives in New York and works for XCorp suffers from [MASK]",
        "Lou Gehrig",
        "Parkinson",
        "Lou Gehrigh's is a [MASK]",
        "Parkinson is a [MASK]",
        "New York is a [MASK]",
        "New York",
        "XCorp",
        "XCorp is a [MASK]",
        "acute lymphoblastic leukemia",
        "acute lymphoblastic leukemia is a [MASK]",
        "eGFR is a [MASK]",
        "EGFR is a [MASK]",
        "Trileptal is a [MASK]",
        "no bond or se curity of any kind will be required of any [MASK] of this will",
        "habeas corpus is a [MASK]",
        "modus operandi is a [MASK]",
        "the volunteers were instructed to buy specific systems using our usual [MASK] —anonymously and with cash",
    )
    return st.selectbox(
        'Choose any of the sentences in pull-down below',
        example_sentences,
        key='my_choice')
def init_session_states():
    """Seed st.session_state with defaults on first run.

    Existing values are left untouched so user selections survive reruns.
    """
    defaults = {
        'top_k': 20,
        'bert_tokenizer': None,   # None until the first model load
        'bert_model': None,
        'model_name': "ajitrajasekharan/biomedical",
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
def main():
    """Render the Streamlit page: intro text, a submit form with sentence and
    model pickers, and the fill-mask / [CLS] prediction results as JSON."""
    init_session_states()
    st.markdown("<h3 style='text-align: center;'>Compare pretrained BERT models qualitatively</h3>", unsafe_allow_html=True)
    st.markdown("""
<small style="font-size:20px; color: #2f2f2f"><br/>Why compare pretrained models <b>before fine-tuning</b>?</small><br/><small style="font-size:16px; color: #7f7f7f">Pretrained BERT models can be used as is, <a href="https://huggingface.co/spaces/ajitrajasekharan/self-supervised-ner-biomedical" target='_blank'><b>with no fine tuning to perform tasks like NER.</b><br/></a>This can be done ideally by using both fill-mask and CLS predictions, or just using fill-mask predictions if CLS predictions are poor</small>
""", unsafe_allow_html=True)
    st.write("This app can be used to examine both fill-mask predictions as well as the neighborhood of CLS vector")
    st.write(" - To examine fill-mask predictions, enter the token [MASK] or <mask> in a sentence")
    st.write(" - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")
    st.write("Pretrained BERT models from three domains (biomedical,PHI [person,location,org, etc.], and legal) are listed below. Their performance on domain specific sentences reveal both their strength and weakness.")
    try:
        # The form groups all inputs so the model only runs on explicit submit.
        with st.form('my_form'):
            selected_sentence = init_selectbox()
            text_input = st.text_input("Type any sentence below", "", key='my_text')
            selected_model = st.selectbox(label='Select Model to Apply', options=['ajitrajasekharan/biomedical', 'bert-base-cased', 'bert-large-cased', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext', 'allenai/scibert_scivocab_cased', 'dmis-lab/biobert-v1.1', 'nlpaueb/legal-bert-base-uncased'], index=0, key="my_model1")
            custom_model_selection = st.text_input("Model not listed on above? Type the model name (**fill-mask BERT models only**)", "", key="my_model2")
            results_count = st.slider("Select count of predictions to display", 1, 50, 20, key='my_slider')  # sometimes fewer words survive decode()'s filtering
            submit_button = st.form_submit_button('Submit')
            # Placeholders so the status line and results can be overwritten in place.
            input_status_area = st.empty()
            display_area = st.empty()
            if submit_button:
                start = time.time()
                # Free-typed text wins; fall back to the pull-down example when empty.
                if (len(text_input) == 0):
                    text_input = selected_sentence
                st.session_state['top_k'] = results_count
                # A custom model name, when supplied, overrides the pull-down choice.
                if (len(custom_model_selection) != 0):
                    on_model_change(custom_model_selection)
                else:
                    on_model_change(selected_model)
                input_status_area.text("Input sentence: " + text_input)
                results = on_text_change(text_input, display_area)
                display_area.empty()
                with display_area.container():
                    st.text(f"prediction took {time.time() - start:.2f}s")
                    st.json(results)
    except Exception as e:
        # Top-level UI boundary: report and halt the Streamlit script run.
        st.error("Some error occurred during loading" + str(e))
        st.stop()
    st.markdown("""
<h3 style="font-size:16px; color: #7f7f7f; text-align: center">Link to post <a href='https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html' target='_blank'>describing this approach </a></h3>
""", unsafe_allow_html=True)
# Script entry point (Streamlit executes this module top-to-bottom on each rerun).
if __name__ == "__main__":
    main()
|