Commit · 1f4b5d0
1 Parent(s): 5ce6476
Update app.py
app.py CHANGED
@@ -36,28 +36,30 @@ def decode(tokenizer, pred_idx, top_clean):
 def encode(tokenizer, text_sentence, add_special_tokens=True):
     text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
     # if <mask> is the last token, append a "." so that models dont predict punctuation.
-    if tokenizer.mask_token == text_sentence.split()[-1]:
-        text_sentence += ' .'
+    #if tokenizer.mask_token == text_sentence.split()[-1]:
+    #    text_sentence += ' .'
 
+    tokenized_text = bert_tokenizer.tokenize(text_sentence)
     input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
     if (tokenizer.mask_token in text_sentence.split()):
         mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
     else:
         mask_idx = 0
-    return input_ids, mask_idx
+    return input_ids, mask_idx,tokenized_text
 
 def get_all_predictions(text_sentence, model_name,top_clean=5):
     # ========================= BERT =================================
-    input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
+    input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
+
     with torch.no_grad():
         predict = bert_model(input_ids)[0]
     bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*5).indices.tolist(), top_clean)
     cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
 
     if ("[MASK]" in text_sentence or "<mask>" in text_sentence):
-        return {'Input sentence':text_sentence,'Model':model_name,'Masked position': bert,'[CLS]':cls}
+        return {'Input sentence':text_sentence,'Tokenized text': tokenized_text,'Model':model_name,'Masked position': bert,'[CLS]':cls}
     else:
-        return {'Input sentence':text_sentence,'Model':model_name,'[CLS]':cls}
+        return {'Input sentence':text_sentence,'Tokenized text': tokenized_text,'Model':model_name,'[CLS]':cls}
 
 def get_bert_prediction(input_text,top_k,model_name):
     try:
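For context, below is a minimal, self-contained sketch of the post-commit behaviour: encode() now also returns the word-piece tokenization, which the prediction path surfaces as a 'Tokenized text' field. The checkpoint name ('bert-base-uncased'), the example sentence, and the top_k value are illustrative assumptions; in the actual app, bert_tokenizer, bert_model, top_k and decode() are defined elsewhere in app.py and are not part of this hunk.

# Hedged sketch of the updated encode() and the masked-LM lookup it feeds.
# Assumes bert-base-uncased; the real app loads its model/tokenizer elsewhere.
import torch
from transformers import BertTokenizer, BertForMaskedLM

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')       # assumed checkpoint
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased').eval()  # assumed checkpoint
top_k = 5  # assumed value

def encode(tokenizer, text_sentence, add_special_tokens=True):
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    # New in this commit: keep the plain word-piece tokenization so the UI can display it.
    tokenized_text = bert_tokenizer.tokenize(text_sentence)
    input_ids = torch.tensor([tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    if tokenizer.mask_token in text_sentence.split():
        mask_idx = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()[0]
    else:
        mask_idx = 0
    return input_ids, mask_idx, tokenized_text

input_ids, mask_idx, tokenized_text = encode(bert_tokenizer, 'Paris is the <mask> of France.')
with torch.no_grad():
    logits = bert_model(input_ids)[0]              # shape: [1, seq_len, vocab_size]
top_ids = logits[0, mask_idx, :].topk(top_k).indices.tolist()
print(tokenized_text)                              # word pieces shown as 'Tokenized text'
print([bert_tokenizer.decode([i]).strip() for i in top_ids])  # candidate fills for the mask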