ajitrajasekharan committed on
Commit
b7137b1
1 Parent(s): 1370c40

Create app.py

Files changed (1)
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
+ import time
+ import streamlit as st
+ import torch
+ import string
+ 
+ from transformers import BertTokenizer, BertForMaskedLM
+ 
+ @st.cache()
+ def load_bert_model(model_name):
+     # load the tokenizer and masked-LM model once and cache them across Streamlit reruns
+     try:
+         bert_tokenizer = BertTokenizer.from_pretrained(model_name)
+         bert_model = BertForMaskedLM.from_pretrained(model_name).eval()
+         return bert_tokenizer, bert_model
+     except Exception as e:
+         # return a sentinel pair on failure so the caller's unpacking does not crash
+         return None, None
+ 
+ 
+ def decode(tokenizer, pred_idx, top_clean):
+     # turn predicted token ids into plain words, skipping punctuation and [PAD]
+     ignore_tokens = string.punctuation + '[PAD]'
+     tokens = []
+     for w in pred_idx:
+         token = ''.join(tokenizer.decode(w).split())
+         if token not in ignore_tokens:
+             tokens.append(token.replace('##', ''))
+     return '\n'.join(tokens[:top_clean])
+ 
+ 
+ def encode(tokenizer, text_sentence, add_special_tokens=True):
+     text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
+     # if <mask> is the last token, append a "." so that models don't predict punctuation.
+     if tokenizer.mask_token == text_sentence.split()[-1]:
+         text_sentence += ' .'
+ 
+     input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
+     mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
+     return input_ids, mask_idx
+ 
+ 
+ def get_all_predictions(text_sentence, top_clean=5):
+     # ========================= BERT =================================
+     input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
+     with torch.no_grad():
+         predict = bert_model(input_ids)[0]
+     # top_k is the module-level setting defined below
+     bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
+     return {'bert': bert}
+ 
+ 
+ def get_bert_prediction(input_text, top_k):
+     try:
+         input_text += ' <mask>'
+         res = get_all_predictions(input_text, top_clean=int(top_k))
+         return res
+     except Exception as error:
+         st.error("Prediction failed: " + str(error))
+         return None
+ 
+ 
+ try:
+     st.title("Qualitative evaluation of Pretrained BERT models")
+     st.markdown("""
+     <a href="https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html"><small style="font-size:18px; color: #8f8f8f">This app is used to qualitatively examine the performance of pretrained models to do NER, <b>with no fine tuning</b></small></a>
+     """, unsafe_allow_html=True)
+     st.write("Incomplete. Work in progress...")
+     #st.write("https://ajitrajasekharan.github.io/2021/01/02/my-first-post.html")
+     st.write("CLS vectors as well as the model prediction for a blank position are examined")
+ 
+     top_k = 10
+     print(top_k)  # debug output to the server log
+ 
+     bert_tokenizer, bert_model = load_bert_model('ajitrajasekharan/biomedical')
+     default_text = "Imatinib is used to treat"
+ 
+     input_text = st.text_area(
+         label="Original text",
+         value=default_text,
+     )
+ 
+     start = None
+     if st.button("Submit"):
+         start = time.time()
+         with st.spinner("Computing"):
+             try:
+                 # predict on the text entered by the user (the original passed default_text here)
+                 res = get_bert_prediction(input_text, top_k)
+ 
+                 st.header("JSON:")
+                 st.json(res)
+             except Exception as e:
+                 st.error("Some error occurred! " + str(e))
+                 st.stop()
+ 
+             st.write("---")
+ 
+             if start is not None:
+                 st.text(f"prediction took {time.time() - start:.2f}s")
+ 
+ except Exception as e:
+     print("SOME PROBLEM OCCURRED: " + str(e))
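For a quick sanity check of the same masked-token flow outside Streamlit, the snippet below is a minimal sketch using only transformers and torch; it assumes the ajitrajasekharan/biomedical checkpoint (the model name app.py loads) can be downloaded from the Hugging Face Hub.

import torch
from transformers import BertTokenizer, BertForMaskedLM

model_name = "ajitrajasekharan/biomedical"  # same checkpoint app.py loads
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name).eval()

# append the mask token plus a "." so the model does not just predict punctuation
text = "Imatinib is used to treat " + tokenizer.mask_token + " ."
input_ids = torch.tensor([tokenizer.encode(text, add_special_tokens=True)])
mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]

with torch.no_grad():
    logits = model(input_ids)[0]

# print the top-10 candidate fillers for the masked position
top_ids = logits[0, mask_idx, :].topk(10).indices.tolist()
print([tokenizer.decode([i]).strip() for i in top_ids])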