ajitrajasekharan
commited on
Commit
•
b9fa3c7
1
Parent(s):
d1cc326
Update app.py
Browse files
app.py
CHANGED
@@ -22,7 +22,7 @@ def load_bert_model(model_name):
|
|
22 |
|
23 |
|
24 |
def decode(tokenizer, pred_idx, top_clean):
|
25 |
-
ignore_tokens = string.punctuation
|
26 |
tokens = []
|
27 |
for w in pred_idx:
|
28 |
token = ''.join(tokenizer.decode(w).split())
|
@@ -51,7 +51,10 @@ def get_all_predictions(text_sentence, top_clean=5):
|
|
51 |
predict = bert_model(input_ids)[0]
|
52 |
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
|
53 |
cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
|
54 |
-
|
|
|
|
|
|
|
55 |
|
56 |
def get_bert_prediction(input_text,top_k):
|
57 |
try:
|
|
|
22 |
|
23 |
|
24 |
def decode(tokenizer, pred_idx, top_clean):
|
25 |
+
ignore_tokens = string.punctuation
|
26 |
tokens = []
|
27 |
for w in pred_idx:
|
28 |
token = ''.join(tokenizer.decode(w).split())
|
|
|
51 |
predict = bert_model(input_ids)[0]
|
52 |
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
|
53 |
cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
|
54 |
+
if ("[MASK]" in text_sentence or "<mask>" in text_sentecence):
|
55 |
+
return {'Input sentence':text_sentence,'Masked position': bert,'[CLS]':cls}
|
56 |
+
else:
|
57 |
+
return {'Input sentence':text_sentence,'[CLS]':cls}
|
58 |
|
59 |
def get_bert_prediction(input_text,top_k):
|
60 |
try:
|