ajitrajasekharan
committed on
Commit
•
ecce248
1
Parent(s):
aed5912
Update app.py
Browse files
app.py
CHANGED
@@ -24,9 +24,9 @@ def decode(tokenizer, pred_idx, top_clean):
|
|
24 |
tokens = []
|
25 |
for w in pred_idx:
|
26 |
token = ''.join(tokenizer.decode(w).split())
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
return '\n'.join(tokens[:top_clean])
|
31 |
|
32 |
def encode(tokenizer, text_sentence, add_special_tokens=True):
|
@@ -44,8 +44,8 @@ def get_all_predictions(text_sentence, top_clean=5):
|
|
44 |
input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
|
45 |
with torch.no_grad():
|
46 |
predict = bert_model(input_ids)[0]
|
47 |
-
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k).indices.tolist(), top_clean)
|
48 |
-
cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k).indices.tolist(), top_clean)
|
49 |
return {'bert': bert,'[CLS]':cls}
|
50 |
|
51 |
def get_bert_prediction(input_text,top_k):
|
|
|
24 |
tokens = []
|
25 |
for w in pred_idx:
|
26 |
token = ''.join(tokenizer.decode(w).split())
|
27 |
+
if token not in ignore_tokens:
|
28 |
+
#tokens.append(token.replace('##', ''))
|
29 |
+
tokens.append(token)
|
30 |
return '\n'.join(tokens[:top_clean])
|
31 |
|
32 |
def encode(tokenizer, text_sentence, add_special_tokens=True):
|
|
|
44 |
input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
|
45 |
with torch.no_grad():
|
46 |
predict = bert_model(input_ids)[0]
|
47 |
+
bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*2).indices.tolist(), top_clean)
|
48 |
+
cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*2).indices.tolist(), top_clean)
|
49 |
return {'bert': bert,'[CLS]':cls}
|
50 |
|
51 |
def get_bert_prediction(input_text,top_k):
|