ajitrajasekharan
commited on
Commit
•
c6d5fcb
1
Parent(s):
ecce248
Update app.py
Browse files
app.py
CHANGED
@@ -36,7 +36,10 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
|
|
36 |
text_sentence += ' .'
|
37 |
|
38 |
input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
|
39 |
-
|
|
|
|
|
|
|
40 |
return input_ids, mask_idx
|
41 |
|
42 |
def get_all_predictions(text_sentence, top_clean=5):
|
|
|
36 |
text_sentence += ' .'
|
37 |
|
38 |
input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
|
39 |
+
if (tokenizer.mask_token in text_sentence.split()):
|
40 |
+
mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
|
41 |
+
else:
|
42 |
+
mask_idx = 0
|
43 |
return input_ids, mask_idx
|
44 |
|
45 |
def get_all_predictions(text_sentence, top_clean=5):
|