import time
import streamlit as st
import torch
import string
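# Streamlit app for qualitative evaluation of pretrained fill-mask (BERT-style) models:
# it shows the model's top predictions for a [MASK]/<mask> position and for the
# neighborhood of the [CLS] vector, with no fine-tuning.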

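# Module-level handles for the currently selected tokenizer and model;
# populated lazily the first time a prediction is run.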
bert_tokenizer = None
bert_model = None

from transformers import BertTokenizer, BertForMaskedLM

st.set_page_config(page_title='Qualitative pretrained model evaluation', page_icon=None, layout='centered', initial_sidebar_state='auto')

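# Fetch and cache the tokenizer and masked-LM model for the given checkpoint,
# so reruns of the script do not reload them from disk or the Hugging Face hub.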
@st.cache()
def load_bert_model(model_name):
  try:
    bert_tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
    bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
    return bert_tokenizer, bert_model
  except Exception as e:
    # report the failure instead of silently returning None,
    # which would break tuple unpacking in the callers
    st.error("Error loading model " + model_name + ": " + str(e))
    return None, None



  
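# Convert predicted vocabulary ids into a newline-separated list of candidate tokens,
# filtering out punctuation, single-character tokens, and tokens starting with '.' or '['
# (special tokens such as [SEP] or [UNK]).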
def decode(tokenizer, pred_idx, top_clean):
  ignore_tokens = string.punctuation
  tokens = []
  for w in pred_idx:
    token = ''.join(tokenizer.decode(w).split())
    if token not in ignore_tokens and len(token) > 1 and not token.startswith('.') and not token.startswith('['):
      #tokens.append(token.replace('##', ''))
      tokens.append(token)
  return '\n'.join(tokens[:top_clean])

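# Tokenize the sentence and return the input ids together with the index of the mask
# token (0, i.e. the [CLS] position, when no mask is present).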
def encode(tokenizer, text_sentence, add_special_tokens=True):
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
  # if <mask> is the last token, append a "." so that models don't predict punctuation.
  if tokenizer.mask_token == text_sentence.split()[-1]:
    text_sentence += ' .'

  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
  if tokenizer.mask_token in text_sentence.split():
    mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
  else:
    mask_idx = 0
  return input_ids, mask_idx

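# Run one forward pass and decode the top candidates both at the masked position
# and at the [CLS] position.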
def get_all_predictions(text_sentence, top_clean=5):
  # ========================= BERT =================================
  # Note: top_k is the module-level slider value defined in the UI code below.
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
  with torch.no_grad():
    predict = bert_model(input_ids)[0]
  bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
  cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)

  if "[MASK]" in text_sentence or "<mask>" in text_sentence:
    return {'Input sentence': text_sentence, 'Masked position': bert, '[CLS]': cls}
  else:
    return {'Input sentence': text_sentence, '[CLS]': cls}

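# Thin wrapper that converts top_k and delegates to get_all_predictions.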
def get_bert_prediction(input_text, top_k):
  try:
    #input_text += ' <mask>'
    res = get_all_predictions(input_text, top_clean=int(top_k))
    return res
  except Exception as error:
    # re-raise so run_test can report the error instead of silently returning None
    raise
    
 
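# Load the model if needed, run the prediction for one sentence, and render the
# result as JSON along with the elapsed time.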
def run_test(sent, top_k):
  start = None
  global bert_tokenizer
  global bert_model
  if bert_tokenizer is None:
    bert_tokenizer, bert_model = load_bert_model(model_name)
  with st.spinner("Computing"):
    start = time.time()
    try:
      res = get_bert_prediction(sent, top_k)
      st.caption("Results in JSON")
      st.json(res)
    except Exception as e:
      st.error("Some error occurred during prediction: " + str(e))
      st.stop()
  if start is not None:
    st.text(f"prediction took {time.time() - start:.2f}s")
  

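# ---- Streamlit UI: header, sidebar model selector, example sentences and text input ----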
st.markdown("<h3 style='text-align: center;'>Qualitative evaluation of Pretrained BERT models</h3>", unsafe_allow_html=True)
st.markdown("""
        <small style="font-size:18px; color: #8f8f8f">This app is used to qualitatively examine the performance of pretrained models to do NER, <a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><b>with no fine tuning</b></a></small>
        """, unsafe_allow_html=True)
  #st.write("https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html")
st.write("Model prediction for a masked  position as well as the neighborhood of CLS vector for input text can be examined")
st.write("   - To examine model prediction for a position, enter the token [MASK] or <mask>")
st.write("   - To examine just the [CLS] vector, enter a word/phrase or sentence. Example: eGFR or EGFR or non small cell lung cancer")
top_k = st.sidebar.slider("Select how many predictions do you need", 1 , 50, 20) #some times it is possible to have less words


  

  #if st.button("Submit"):
    
  #  with st.spinner("Computing"):
try:
  model_name = st.sidebar.selectbox(label='Select Model to Apply', options=['ajitrajasekharan/biomedical', 'bert-base-cased', 'bert-large-cased', 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext', 'allenai/scibert_scivocab_cased'], index=0, key="model_name")
  option = st.selectbox(
      'Choose any of these sentences or type any text below',
      ('',
       "[MASK] who lives in New York and works for XCorp suffers from Parkinson's",
       "Lou Gehrig who lives in [MASK] and works for XCorp suffers from Parkinson's",
       "Lou Gehrig who lives in New York and works for [MASK] suffers from Parkinson's",
       "Lou Gehrig who lives in New York and works for XCorp suffers from [MASK]",
       "[MASK] who lives in New York and works for XCorp suffers from Lou Gehrig's",
       "Parkinson who lives in [MASK] and works for XCorp suffers from Lou Gehrig's",
       "Parkinson who lives in New York and works for [MASK] suffers from Lou Gehrig's",
       "Parkinson who lives in New York and works for XCorp suffers from [MASK]",
       "Lou Gehrig",
       "Parkinson",
       "Lou Gehrig's is a [MASK]",
       "Parkinson is a [MASK]",
       "New York is a [MASK]",
       "New York",
       "XCorp",
       "XCorp is a [MASK]",
       "acute lymphoblastic leukemia",
       "acute lymphoblastic leukemia is a [MASK]"))
  input_text = st.text_input("Enter text below", "")
  custom_model_name = st.text_input("Model not listed on left? Type the model name (fill-mask models only)", "")
  if len(custom_model_name) > 0:
    model_name = custom_model_name
    st.write("Custom model selected: " + model_name)
    bert_tokenizer, bert_model = load_bert_model(model_name)
  if len(input_text) > 0:
    run_test(input_text, top_k)
  elif len(option) > 0:
    run_test(option, top_k)
  if bert_tokenizer is None:
    bert_tokenizer, bert_model = load_bert_model(model_name)

except Exception as e:
  st.error("Some error occurred during loading: " + str(e))
  st.stop()
st.write("---")