ajitrajasekharan committed
Commit: f8dc81b
Parent(s): d503dfd
Update app.py
app.py CHANGED
@@ -46,7 +46,7 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
     mask_idx = 0
     return input_ids, mask_idx
 
-def get_all_predictions(text_sentence, top_clean=5):
+def get_all_predictions(text_sentence, model_name, top_clean=5):
     # ========================= BERT =================================
     input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
     with torch.no_grad():
@@ -55,20 +55,20 @@ def get_all_predictions(text_sentence, top_clean=5):
     cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
 
     if ("[MASK]" in text_sentence or "<mask>" in text_sentence):
-        return {'Input sentence':text_sentence,'Masked position': bert,'[CLS]':cls}
+        return {'Input sentence':text_sentence,'Model':model_name,'Masked position': bert,'[CLS]':cls}
     else:
-        return {'Input sentence':text_sentence,'[CLS]':cls}
+        return {'Input sentence':text_sentence,'Model':model_name,'[CLS]':cls}
 
-def get_bert_prediction(input_text,top_k):
+def get_bert_prediction(input_text,top_k,model_name):
     try:
         #input_text += ' <mask>'
-        res = get_all_predictions(input_text, top_clean=int(top_k))
+        res = get_all_predictions(input_text,model_name, top_clean=int(top_k))
         return res
     except Exception as error:
         pass
 
 
-def run_test(sent,top_k):
+def run_test(sent,top_k,model_name):
     start = None
     global bert_tokenizer
     global bert_model
@@ -77,7 +77,7 @@ def run_test(sent,top_k):
     with st.spinner("Computing"):
         start = time.time()
         try:
-            res = get_bert_prediction(sent,top_k)
+            res = get_bert_prediction(sent,top_k,model_name)
             st.caption("Results in JSON")
             st.json(res)
 
@@ -115,13 +115,13 @@ try:
     custom_model_name = st.text_input("Model not listed on left? Type the model name (fill-mask models only)", "")
     if (len(custom_model_name) > 0):
         model_name = custom_model_name
-        st.info("Custom model selected:" + model_name)
+        st.info("Custom model selected: " + model_name)
         bert_tokenizer, bert_model = load_bert_model(model_name)
         if len(input_text) > 0:
-            run_test(input_text,top_k)
+            run_test(input_text,top_k,model_name)
     else:
         if len(option) > 0:
-            run_test(option,top_k)
+            run_test(option,top_k,model_name)
             if (bert_tokenizer is None):
                 bert_tokenizer, bert_model = load_bert_model(model_name)
 
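The commit threads the selected model_name from the Streamlit UI through run_test and get_bert_prediction into get_all_predictions, so the JSON rendered by st.json now reports which fill-mask model produced the predictions (the new 'Model' key). Below is a minimal standalone sketch of that idea using the transformers fill-mask pipeline; it is not part of app.py, and the function name, example sentence, and model id are illustrative only.

# Standalone sketch (assumptions noted below), not the app.py implementation.
from transformers import pipeline

def fill_mask_with_model_tag(text_sentence, model_name, top_clean=5):
    # Build a fill-mask pipeline for the requested model; app.py instead reuses
    # a tokenizer/model pair returned by load_bert_model.
    fill = pipeline("fill-mask", model=model_name)
    preds = fill(text_sentence, top_k=top_clean)
    return {
        "Input sentence": text_sentence,
        "Model": model_name,  # the key this commit adds to the returned dict
        "Masked position": [p["token_str"] for p in preds],
    }

if __name__ == "__main__":
    # Illustrative model id and sentence, not taken from the commit.
    print(fill_mask_with_model_tag("Paris is the capital of [MASK].", "bert-base-cased"))

In app.py itself the tokenizer and model come from load_bert_model and are reused across calls, so constructing a pipeline per request as above is only a stand-in for illustration.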