ajitrajasekharan commited on
Commit
b9fa3c7
1 Parent(s): d1cc326

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -22,7 +22,7 @@ def load_bert_model(model_name):
22
 
23
 
24
  def decode(tokenizer, pred_idx, top_clean):
25
- ignore_tokens = string.punctuation + '[PAD]'
26
  tokens = []
27
  for w in pred_idx:
28
  token = ''.join(tokenizer.decode(w).split())
@@ -51,7 +51,10 @@ def get_all_predictions(text_sentence, top_clean=5):
51
  predict = bert_model(input_ids)[0]
52
  bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
53
  cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
54
- return {'Input sentence':text_sentence,'Masked position': bert,'[CLS]':cls}
 
 
 
55
 
56
  def get_bert_prediction(input_text,top_k):
57
  try:
 
22
 
23
 
24
  def decode(tokenizer, pred_idx, top_clean):
25
+ ignore_tokens = string.punctuation
26
  tokens = []
27
  for w in pred_idx:
28
  token = ''.join(tokenizer.decode(w).split())
 
51
  predict = bert_model(input_ids)[0]
52
  bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
53
  cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
54
+ if ("[MASK]" in text_sentence or "<mask>" in text_sentecence):
55
+ return {'Input sentence':text_sentence,'Masked position': bert,'[CLS]':cls}
56
+ else:
57
+ return {'Input sentence':text_sentence,'[CLS]':cls}
58
 
59
  def get_bert_prediction(input_text,top_k):
60
  try: