# NOTE(review): removed a scraped GitHub blame/line-number gutter (commit
# hashes and a bare line-number column) that was captured along with the
# source and is not part of the program.
import time
import streamlit as st
import torch
import string

from transformers import BertTokenizer, BertForMaskedLM

# Page chrome for the Streamlit app; fixed typo in the browser-tab title
# ("eveluation" -> "evaluation").
st.set_page_config(page_title='Qualitative pretrained model evaluation', page_icon=None, layout='centered', initial_sidebar_state='auto')

@st.cache()
def load_bert_model(model_name):
  """Load a BERT tokenizer and masked-LM model by hub name (cached by Streamlit).

  Args:
    model_name: Hugging Face model identifier (e.g. 'bert-base-cased').

  Returns:
    (tokenizer, model) tuple; the model is put in eval mode.

  Raises:
    Propagates any download/load error. Previously a bare
    `except Exception: pass` returned None, which surfaced only as an
    opaque "cannot unpack non-iterable NoneType" at the call site; the
    caller's try/except now receives the real error and reports it.
  """
  bert_tokenizer = BertTokenizer.from_pretrained(model_name)
  bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
  return bert_tokenizer, bert_model



  
def decode(tokenizer, pred_idx, top_clean):
  """Convert predicted token ids into a newline-separated string of tokens.

  Args:
    tokenizer: object with a `decode(token_id)` method (e.g. a BertTokenizer).
    pred_idx: iterable of predicted token ids.
    top_clean: maximum number of tokens to keep.

  Returns:
    Up to `top_clean` decoded tokens joined by newlines.
  """
  # Collapse any whitespace inside a decoded piece so each prediction is a
  # single word. (An earlier punctuation/[PAD] filter was disabled upstream
  # and its dead code plus the unused `ignore_tokens` local are removed here.)
  tokens = [''.join(tokenizer.decode(idx).split()) for idx in pred_idx]
  return '\n'.join(tokens[:top_clean])

def encode(tokenizer, text_sentence, add_special_tokens=True):
  """Tokenize a sentence containing a '<mask>' placeholder.

  Args:
    tokenizer: tokenizer exposing `mask_token`, `mask_token_id` and `encode`.
    text_sentence: input text; occurrences of '<mask>' are replaced by the
      tokenizer's own mask token.
    add_special_tokens: forwarded to `tokenizer.encode`.

  Returns:
    (input_ids, mask_idx): a (1, seq_len) LongTensor and the position of the
    first mask token within it.
  """
  sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
  # When the mask ends the sentence, append a period so the model is less
  # inclined to predict punctuation for the blank.
  if sentence.split()[-1] == tokenizer.mask_token:
    sentence = sentence + ' .'

  encoded = tokenizer.encode(sentence, add_special_tokens=add_special_tokens)
  input_ids = torch.tensor([encoded])
  mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
  return input_ids, mask_positions[0]

def get_all_predictions(text_sentence, top_clean=5):
  """Run masked-LM inference and decode top predictions.

  NOTE(review): relies on module-level globals `bert_tokenizer`, `bert_model`
  and `top_k` (the number of candidates taken from the logits); `top_clean`
  only limits how many of those candidates are kept after decoding.

  Returns:
    dict with 'bert' (predictions at the mask position) and '[CLS]'
    (predictions at the [CLS] position), each a newline-joined string.
  """
  ids, mask_pos = encode(bert_tokenizer, text_sentence)
  with torch.no_grad():
    logits = bert_model(ids)[0]
  mask_candidates = logits[0, mask_pos, :].topk(top_k).indices.tolist()
  cls_candidates = logits[0, 0, :].topk(top_k).indices.tolist()
  return {
      'bert': decode(bert_tokenizer, mask_candidates, top_clean),
      '[CLS]': decode(bert_tokenizer, cls_candidates, top_clean),
  }

def get_bert_prediction(input_text, top_k):
  """Return the top-k masked-LM predictions for `input_text`.

  Args:
    input_text: sentence containing a mask placeholder.
    top_k: number of predictions to keep (coerced to int).

  Returns:
    The dict produced by `get_all_predictions`.

  Raises:
    Propagates any error from prediction. The previous
    `except Exception: pass` silently returned None, so the UI rendered
    `null` with no diagnostic; the caller already wraps this call in
    try/except and reports errors via st.error.
  """
  return get_all_predictions(input_text, top_clean=int(top_k))

# Page header and introductory copy.
st.title("Qualitative evaluation of Pretrained BERT models")
st.markdown("""
        <a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><small style="font-size:18px; color: #8f8f8f">This app is used to qualitatively examine the performance of pretrained models to do NER , <b>with no fine tuning</b></small></a>
        """, unsafe_allow_html=True)
st.write("Incomplete. Work in progress...")
st.write("CLS vectors as well as the model prediction for a blank position are examined")
# Number of candidate tokens taken from the logits for each probed position
# (read as a global by get_all_predictions). Leftover debug print removed.
top_k = 10

  
# Timing guard: the elapsed-time line below renders only after a prediction
# has actually run on this script execution.
start = None
  #if st.button("Submit"):
    
  #  with st.spinner("Computing"):
try:
            
      # Model picker in the sidebar; the custom biomedical model is the default.
      model_name = st.sidebar.selectbox(label='Select Model to Apply',  options=['ajitrajasekharan/biomedical', 'bert-base-cased','bert-large-cased','microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext','allenai/scibert_scivocab_cased'], index=0,  key = "model_name")
      # Cached load — only re-downloads/reloads when model_name changes.
      bert_tokenizer, bert_model  = load_bert_model(model_name)
      default_text = "Imatinib is used to [MASK] nsclc"
      input_text = st.text_area(
                  label="Original text",
                  value=default_text,
                )
      if st.button("Submit"):
        with st.spinner("Computing"):
          # Start the clock so the prediction duration can be reported below.
          start = time.time()
          try:
            res = get_bert_prediction(input_text,top_k)
            st.header("JSON:")
            st.json(res)
            
          except Exception as e:
            st.error("Some error occurred during prediction" + str(e))
            st.stop()
 

	
      # Shown only after a successful (or attempted) prediction this run.
      if start is not None:
          st.text(f"prediction took {time.time() - start:.2f}s")

except Exception as e:
  # Top-level guard: covers model loading and widget construction failures.
  st.error("Some error occurred during loading" + str(e))
  st.stop()  
	
st.write("---")