ajitrajasekharan commited on
Commit
a03c359
1 Parent(s): 07c19db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -33,6 +33,9 @@ def decode(tokenizer, pred_idx, top_clean):
33
  return '\n'.join(tokens[:top_clean])
34
 
35
  def encode(tokenizer, text_sentence, add_special_tokens=True):
 
 
 
36
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
37
  # if <mask> is the last token, append a "." so that models dont predict punctuation.
38
  #if tokenizer.mask_token == text_sentence.split()[-1]:
@@ -47,6 +50,9 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
47
  return input_ids, mask_idx,tokenized_text
48
 
49
  def get_all_predictions(text_sentence, model_name,top_clean=5):
 
 
 
50
  # ========================= BERT =================================
51
  input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
52
 
@@ -95,7 +101,7 @@ def on_text_change():
95
  def on_option_change():
96
 
97
  text = st.session_state.my_choice
98
- st.info("Preselected text chosen")
99
  run_test(text,st.session_state['top_k'],st.session_state['model_name'])
100
 
101
  def on_results_count_change():
 
33
  return '\n'.join(tokens[:top_clean])
34
 
35
  def encode(tokenizer, text_sentence, add_special_tokens=True):
36
+ bert_tokenizer = st.session_state['bert_tokenizer']
37
+ bert_model = st.session_state['bert_model']
38
+
39
  text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
40
  # if <mask> is the last token, append a "." so that models dont predict punctuation.
41
  #if tokenizer.mask_token == text_sentence.split()[-1]:
 
50
  return input_ids, mask_idx,tokenized_text
51
 
52
  def get_all_predictions(text_sentence, model_name,top_clean=5):
53
+ bert_tokenizer = st.session_state['bert_tokenizer']
54
+ bert_model = st.session_state['bert_model']
55
+
56
  # ========================= BERT =================================
57
  input_ids, mask_idx,tokenized_text = encode(bert_tokenizer, text_sentence)
58
 
 
101
  def on_option_change():
102
 
103
  text = st.session_state.my_choice
104
+ st.info("Preselected text chosen:" + text)
105
  run_test(text,st.session_state['top_k'],st.session_state['model_name'])
106
 
107
  def on_results_count_change():