ajitrajasekharan commited on
Commit
c6d5fcb
1 Parent(s): ecce248

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -36,7 +36,10 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
36
  text_sentence += ' .'
37
 
38
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
39
- mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
 
 
 
40
  return input_ids, mask_idx
41
 
42
  def get_all_predictions(text_sentence, top_clean=5):
 
36
  text_sentence += ' .'
37
 
38
  input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
39
+ if (tokenizer.mask_token in text_sentence.split()):
40
+ mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
41
+ else:
42
+ mask_idx = 0
43
  return input_ids, mask_idx
44
 
45
  def get_all_predictions(text_sentence, top_clean=5):