ajitrajasekharan
committed on
Commit
•
a03c359
1
Parent(s):
07c19db
Update app.py
Browse files
app.py
CHANGED
@@ -33,6 +33,9 @@ def decode(tokenizer, pred_idx, top_clean):
|
|
33 |
return '\n'.join(tokens[:top_clean])
|
34 |
|
35 |
def encode(tokenizer, text_sentence, add_special_tokens=True):
|
|
|
|
|
|
|
36 |
text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
|
37 |
# if <mask> is the last token, append a "." so that models dont predict punctuation.
|
38 |
#if tokenizer.mask_token == text_sentence.split()[-1]:
|
@@ -47,6 +50,9 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
|
|
47 |
return input_ids, mask_idx,tokenized_text
|
48 |
|
49 |
def get_all_predictions(text_sentence, model_name,top_clean=5):
|
|
|
|
|
|
|
50 |
# ========================= BERT =================================
|
51 |
input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
|
52 |
|
@@ -95,7 +101,7 @@ def on_text_change():
|
|
95 |
def on_option_change():
|
96 |
|
97 |
text = st.session_state.my_choice
|
98 |
-
st.info("Preselected text chosen")
|
99 |
run_test(text,st.session_state['top_k'],st.session_state['model_name'])
|
100 |
|
101 |
def on_results_count_change():
|
|
|
33 |
return '\n'.join(tokens[:top_clean])
|
34 |
|
35 |
def encode(tokenizer, text_sentence, add_special_tokens=True):
|
36 |
+
bert_tokenizer = st.session_state['bert_tokenizer']
|
37 |
+
bert_model = st.session_state['bert_model']
|
38 |
+
|
39 |
text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
|
40 |
# if <mask> is the last token, append a "." so that models dont predict punctuation.
|
41 |
#if tokenizer.mask_token == text_sentence.split()[-1]:
|
|
|
50 |
return input_ids, mask_idx,tokenized_text
|
51 |
|
52 |
def get_all_predictions(text_sentence, model_name,top_clean=5):
|
53 |
+
bert_tokenizer = st.session_state['bert_tokenizer']
|
54 |
+
bert_model = st.session_state['bert_model']
|
55 |
+
|
56 |
# ========================= BERT =================================
|
57 |
input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
|
58 |
|
|
|
101 |
def on_option_change():
|
102 |
|
103 |
text = st.session_state.my_choice
|
104 |
+
st.info("Preselected text chosen:" + text)
|
105 |
run_test(text,st.session_state['top_k'],st.session_state['model_name'])
|
106 |
|
107 |
def on_results_count_change():
|