File size: 8,897 Bytes
b7137b1
 
 
 
afd47cf
b72b8d9
167f69d
b7137b1
 
68f2100
3b3fa96
b7137b1
 
 
6f4ba26
 
b7137b1
 
 
 
 
 
 
 
 
b9fa3c7
b7137b1
 
 
08b9f95
ecce248
 
b7137b1
 
 
a03c359
b7137b1
 
424d29e
1c53eb1
c6d5fcb
 
 
 
1f4b5d0
b7137b1
f8dc81b
a03c359
 
424d29e
a03c359
b7137b1
1f4b5d0
 
b7137b1
 
0d25a6d
 
293e817
 
2406613
b9fa3c7
2406613
b7137b1
f8dc81b
b7137b1
1c53eb1
f8dc81b
b7137b1
 
 
b9f419a
 
eee0c36
d1b63cc
eee0c36
d1b63cc
9d7bba2
b5aa429
 
eee0c36
b5aa429
eee0c36
b9f419a
eee0c36
 
 
 
3f2b07b
eee0c36
 
3f2b07b
6b14017
caf6c21
24b81ce
 
 
 
 
b9f419a
77d733c
eee0c36
672678f
eee0c36
6f5d2d2
d1b63cc
 
 
 
 
 
 
 
 
b7137b1
6f5d2d2
d1b63cc
 
eb9cce7
a340b6b
68f2100
6f5d2d2
220d4ac
b7137b1
8a3b8f4
f0a51d8
cae21f4
6f5d2d2
1e5d7d3
24b81ce
d1b63cc
167f69d
6f5d2d2
a7a3b55
fde541a
eee0c36
 
 
24b81ce
 
 
eee0c36
6244fd0
eee0c36
 
 
 
 
 
24b81ce
eee0c36
24b81ce
 
 
 
eee0c36
 
 
 
 
 
 
b9f419a
b7137b1
6f5d2d2
 
 
b7137b1
6421c87
1e5d7d3
eb960a9
6421c87
6f5d2d2
 
316bafd
 
6f5d2d2
b7137b1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import time
import streamlit as st
import torch
import string



from transformers import BertTokenizer, BertForMaskedLM

# Page-level Streamlit configuration, applied at import time before any widget is drawn.
st.set_page_config(
  page_title='Compare pretrained BERT models qualitatively',
  page_icon=None,
  layout='centered',
  initial_sidebar_state='auto',
)

@st.cache()
def load_bert_model(model_name):
  """Load the tokenizer and masked-LM model identified by ``model_name``.

  Args:
    model_name: Identifier handed to ``from_pretrained`` for both the
      tokenizer and the model.

  Returns:
    A ``(tokenizer, model)`` tuple; ``.eval()`` has been called on the model.

  Raises:
    Any exception raised while loading is propagated to the caller. Callers
    in this file unpack the result into two names, so returning ``None`` on
    failure would only replace the real error with an unrelated unpacking
    ``TypeError``.
  """
  # do_lower_case=False keeps the input casing as typed.
  bert_tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
  bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
  return bert_tokenizer, bert_model



  
def decode(tokenizer, pred_idx, top_clean):
  """Convert predicted ids into a newline-separated string of usable tokens.

  Args:
    tokenizer: Object whose ``decode(id)`` returns the text for one id.
    pred_idx: Iterable of predicted ids, in ranking order.
    top_clean: Maximum number of surviving tokens to include.

  Returns:
    The first ``top_clean`` tokens that pass the filter, joined by newlines.
    A token is dropped when it is one character or shorter, when it occurs
    as a substring of ``string.punctuation``, or when it starts with ``.``
    or ``[``.
  """
  punctuation = string.punctuation
  # Whitespace inside a decoded piece is removed before filtering.
  candidates = (''.join(tokenizer.decode(idx).split()) for idx in pred_idx)
  kept = [
    piece
    for piece in candidates
    if len(piece) > 1
    and piece not in punctuation  # substring test against the punctuation string
    and not piece.startswith(('.', '['))
  ]
  return '\n'.join(kept[:top_clean])

def encode(tokenizer, text_sentence, add_special_tokens=True):
  """Encode a sentence and locate its first mask token.

  Args:
    tokenizer: Tokenizer exposing ``mask_token``, ``mask_token_id``,
      ``tokenize`` and ``encode``.
    text_sentence: Input text. Every ``<mask>`` is rewritten to the
      tokenizer's own mask token before encoding.
    add_special_tokens: Forwarded to ``tokenizer.encode``.

  Returns:
    ``(input_ids, mask_idx, tokenized_text)`` where ``input_ids`` is a tensor
    holding one row of ids, ``mask_idx`` is the position of the first mask
    id in that row (0 when the row contains no mask id), and
    ``tokenized_text`` is the result of ``tokenizer.tokenize``.
  """
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)

  tokenized_text = tokenizer.tokenize(text_sentence)
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])

  # Look for the mask in the encoded ids rather than in the whitespace-split
  # text, so a mask that touches punctuation (e.g. "[MASK].") is still found.
  mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
  mask_idx = mask_positions[0] if mask_positions else 0
  return input_ids, mask_idx, tokenized_text

def get_all_predictions(text_sentence, model_name,top_clean=5):
  """Run the model held in session state on a sentence and summarize the output.

  Args:
    text_sentence: Sentence to score; may contain ``[MASK]`` or ``<mask>``.
    model_name: Model label copied into the result under ``'Model'``.
    top_clean: Maximum number of tokens reported per position.

  Returns:
    A dict with the input sentence, its tokenization, the session ``top_k``
    value, the model label and the ``'[CLS]'`` tokens (position 0). A
    ``'Masked position'`` entry is added before ``'[CLS]'`` when the
    sentence contains ``[MASK]`` or ``<mask>``.
  """
  state = st.session_state
  tokenizer = state['bert_tokenizer']
  model = state['bert_model']
  results_count = state['top_k']

  input_ids, mask_idx, tokenized_text = encode(tokenizer, text_sentence)

  with torch.no_grad():
    # First element of the model output, indexed below as [0, position, :]
    # (presumably the per-position vocabulary scores).
    scores = model(input_ids)[0]

  def top_tokens(position):
    # Ask for twice as many candidates as requested, since decode() filters some out.
    candidates = scores[0, position, :].topk(results_count * 2).indices.tolist()
    return decode(tokenizer, candidates, top_clean)

  masked_tokens = top_tokens(mask_idx)
  cls_tokens = top_tokens(0)

  summary = {
    'Input sentence': text_sentence,
    'Tokenized text': tokenized_text,
    'results_count': results_count,
    'Model': model_name,
  }
  if "[MASK]" in text_sentence or "<mask>" in text_sentence:
    summary['Masked position'] = masked_tokens
  summary['[CLS]'] = cls_tokens
  return summary

def get_bert_prediction(input_text,top_k,model_name):
  """Return the prediction summary for ``input_text``.

  Args:
    input_text: Sentence to run through the model.
    top_k: Number of tokens to report per position; converted with ``int``.
    model_name: Model label forwarded to ``get_all_predictions``.

  Returns:
    The dict produced by ``get_all_predictions``.

  Raises:
    Any exception raised during prediction is propagated. Swallowing it here
    would hand ``None`` back to the caller, so the failure would never be
    reported.
  """
  return get_all_predictions(input_text, model_name, top_clean=int(top_k))
    
 
def run_test(sent,top_k,model_name,display_area):
  """Load the model if needed, then compute predictions for ``sent``.

  Args:
    sent: Sentence to run through the model.
    top_k: Not read; the value in ``st.session_state['top_k']`` is used.
    model_name: Not read; the value in ``st.session_state['model_name']`` is used.
    display_area: Placeholder whose ``text`` method receives progress messages.

  Returns:
    Whatever ``get_bert_prediction`` returns. If prediction raises, the
    error is shown with ``st.error`` and ``st.stop()`` is called.
  """
  state = st.session_state
  if state['bert_tokenizer'] is None:
    # No model has been loaded in this session yet.
    display_area.text("Loading model:" + state['model_name'])
    state['bert_tokenizer'], state['bert_model'] = load_bert_model(state['model_name'])
    display_area.text("Model " + str(state['model_name'])  + " load complete")
  try:
    display_area.text("Computing fill-mask prediction...")
    prediction = get_bert_prediction(sent, state['top_k'], state['model_name'])
    display_area.text("Fill-mask prediction complete")
  except Exception as e:
    st.error("Some error occurred during prediction" + str(e))
    st.stop()
  else:
    return prediction
  return {}
    
def on_text_change(text,display_area):
  """Run a prediction for ``text`` using the session's ``top_k`` and model name.

  Args:
    text: Sentence to run through the model.
    display_area: Placeholder passed through to ``run_test`` for progress messages.

  Returns:
    The value returned by ``run_test``.
  """
  state = st.session_state
  top_k = state['top_k']
  model_name = state['model_name']
  return run_test(text, top_k, model_name, display_area)

 


def on_model_change(model_name):
  """Switch the session to ``model_name`` if it differs from the current model.

  Args:
    model_name: Name of the model the user selected or typed.

  The new tokenizer and model are loaded before any session state is
  updated. If loading raises, the session keeps its previous name, tokenizer
  and model together, so a later submit with the same name retries the load
  instead of silently reusing the old model under the new name.
  """
  if model_name == st.session_state['model_name']:
    return
  tokenizer, model = load_bert_model(model_name)
  st.session_state['model_name'] = model_name
  st.session_state['bert_tokenizer'] = tokenizer
  st.session_state['bert_model'] = model
  
def init_selectbox():
  """Render the sample-sentence pull-down and return the selected sentence.

  The widget is registered under the key ``'my_choice'``.
  """
  # Sentence text is passed to the widget exactly as written, spacing included.
  sample_sentences = (
    "[MASK] who lives in New York and works for XCorp suffers from Parkinson's",
    "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's",
    "Lou Gehrig who lives in New York and works       for [MASK] suffers from Parkinson's",
    "Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]",
    "[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's",
    "Parkinson who      lives in [MASK] and works for XCorp suffers from Lou Gehrig's",
    "Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's",
    "Parkinson who lives in New York and works for XCorp suffers      from [MASK]",
    "Lou Gehrig",
    "Parkinson",
    "Lou Gehrigh's is a [MASK]",
    "Parkinson is a [MASK]",
    "New York is a [MASK]",
    "New York",
    "XCorp",
    "XCorp is a [MASK]",
    "acute lymphoblastic leukemia",
    "acute lymphoblastic       leukemia is a [MASK]",
    "eGFR is a [MASK]",
    "EGFR is a [MASK]",
    "Trileptal is a [MASK]",
    "no bond or se curity of any kind will be required of any [MASK] of this will",
    "habeas corpus is a [MASK]",
    "modus             operandi is a [MASK]",
    "the volunteers were instructed to buy specific systems using our usual [MASK] —anonymously and with cash",
  )
  return st.selectbox(
    'Choose any of the sentences in pull-down below',
    sample_sentences,
    key='my_choice',
  )
  
def init_session_states():
  """Seed ``st.session_state`` with defaults for any key not already present.

  Existing values are left untouched.
  """
  defaults = {
    'top_k': 20,
    'bert_tokenizer': None,
    'bert_model': None,
    'model_name': "ajitrajasekharan/biomedical",
  }
  for key, value in defaults.items():
    if key not in st.session_state:
      st.session_state[key] = value

def main():
  """Render the page and, on form submission, run and display a prediction.

  Draws the heading and introductory text, then a form with a sample-sentence
  pull-down, a free-text sentence box, a model pull-down, a custom-model box
  and a result-count slider. On submit, a typed sentence takes precedence
  over the pull-down choice and a typed model name over the listed one; the
  result is shown as JSON together with the elapsed time.
  """
  init_session_states()

  st.markdown("<h3 style='text-align: center;'>Compare pretrained BERT models qualitatively</h3>", unsafe_allow_html=True)
  st.markdown("""
        <small style="font-size:20px; color: #2f2f2f"><br/>Why compare pretrained models <b>before fine-tuning</b>?</small><br/><small style="font-size:16px; color: #7f7f7f">Pretrained BERT models can be used as is, <a href="https://huggingface.co/spaces/ajitrajasekharan/self-supervised-ner-biomedical" target='_blank'><b>with no fine tuning to perform tasks like NER.</b><br/></a>This can be done ideally by using both fill-mask and CLS predictions, or just using fill-mask predictions if CLS predictions are poor</small>
        """, unsafe_allow_html=True)

  usage_notes = (
    "This app can be used to examine both fill-mask predictions as well as the neighborhood of CLS vector",
    "   - To examine fill-mask predictions, enter the token [MASK] or <mask> in a sentence",
    "   - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer",
    "Pretrained BERT models from three domains (biomedical,PHI [person,location,org, etc.], and legal) are listed below. Their performance on domain specific sentences reveal both their strength and weakness.",
  )
  for note in usage_notes:
    st.write(note)

  model_choices = [
    'ajitrajasekharan/biomedical',
    'bert-base-cased',
    'bert-large-cased',
    'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',
    'allenai/scibert_scivocab_cased',
    'dmis-lab/biobert-v1.1',
    'nlpaueb/legal-bert-base-uncased',
  ]

  try:
    with st.form('my_form'):
      preset_sentence = init_selectbox()
      typed_sentence = st.text_input("Type any sentence below", "", key='my_text')
      listed_model = st.selectbox(label='Select Model to Apply', options=model_choices, index=0, key="my_model1")
      custom_model = st.text_input("Model not listed on above? Type the model name (**fill-mask BERT models only**)", "", key="my_model2")
      # Fewer words than requested may be displayed for some inputs.
      results_count = st.slider("Select count of predictions to display", 1, 50, 20, key='my_slider')
      submitted = st.form_submit_button('Submit')

      input_status_area = st.empty()
      display_area = st.empty()
      if submitted:
        start = time.time()
        # A typed sentence / model name wins over the pull-down selection.
        sentence = typed_sentence or preset_sentence
        st.session_state['top_k'] = results_count
        on_model_change(custom_model or listed_model)

        input_status_area.text("Input sentence:  " + sentence)
        results = on_text_change(sentence, display_area)
        display_area.empty()
        with display_area.container():
          st.text(f"prediction took {time.time() - start:.2f}s")
          st.json(results)
  except Exception as e:
    st.error("Some error occurred during loading" + str(e))
    st.stop()

  st.markdown("""
    <h3 style="font-size:16px; color: #7f7f7f; text-align: center">Link to post <a href='https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html' target='_blank'>describing this approach </a></h3>
  """, unsafe_allow_html=True)
  
 

# Entry point: build the page only when this file is run as the main script.
if __name__ == "__main__":
   main()