ajitrajasekharan committed on
Commit
1f4b5d0
·
1 Parent(s): 5ce6476

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -36,28 +36,30 @@ def decode(tokenizer, pred_idx, top_clean):
36
  def encode(tokenizer, text_sentence, add_special_tokens=True):
37
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
38
  # if <mask> is the last token, append a "." so that models dont predict punctuation.
39
- if tokenizer.mask_token == text_sentence.split()[-1]:
40
- text_sentence += ' .'
41
 
 
42
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
43
  if (tokenizer.mask_token in text_sentence.split()):
44
  mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
45
  else:
46
  mask_idx = 0
47
- return input_ids, mask_idx
48
 
49
  def get_all_predictions(text_sentence, model_name,top_clean=5):
50
  # ========================= BERT =================================
51
- input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
 
52
  with torch.no_grad():
53
  predict = bert_model(input_ids)[0]
54
  bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
55
  cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
56
 
57
  if ("[MASK]" in text_sentence or "<mask>" in text_sentence):
58
- return {'Input sentence':text_sentence,'Model':model_name,'Masked position': bert,'[CLS]':cls}
59
  else:
60
- return {'Input sentence':text_sentence,'Model':model_name,'[CLS]':cls}
61
 
62
  def get_bert_prediction(input_text,top_k,model_name):
63
  try:
 
36
  def encode(tokenizer, text_sentence, add_special_tokens=True):
37
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
38
  # if <mask> is the last token, append a "." so that models dont predict punctuation.
39
+ #if tokenizer.mask_token == text_sentence.split()[-1]:
40
+ # text_sentence += ' .'
41
 
42
+ tokenized_text = bert_tokenizer.tokenize(text_sentence)
43
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
44
  if (tokenizer.mask_token in text_sentence.split()):
45
  mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
46
  else:
47
  mask_idx = 0
48
+ return input_ids, mask_idx,tokenized_text
49
 
50
  def get_all_predictions(text_sentence, model_name,top_clean=5):
51
  # ========================= BERT =================================
52
+ input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
53
+
54
  with torch.no_grad():
55
  predict = bert_model(input_ids)[0]
56
  bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
57
  cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
58
 
59
  if ("[MASK]" in text_sentence or "<mask>" in text_sentence):
60
+ return {'Input sentence':text_sentence,'Tokenized text': tokenized_text,'Model':model_name,'Masked position': bert,'[CLS]':cls}
61
  else:
62
+ return {'Input sentence':text_sentence,'Tokenized text': tokenized_text,'Model':model_name,'[CLS]':cls}
63
 
64
  def get_bert_prediction(input_text,top_k,model_name):
65
  try: