ajitrajasekharan commited on
Commit
f8dc81b
1 Parent(s): d503dfd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -46,7 +46,7 @@ def encode(tokenizer, text_sentence, add_special_tokens=True):
46
  mask_idx = 0
47
  return input_ids, mask_idx
48
 
49
- def get_all_predictions(text_sentence, top_clean=5):
50
  # ========================= BERT =================================
51
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
52
  with torch.no_grad():
@@ -55,20 +55,20 @@ def get_all_predictions(text_sentence, top_clean=5):
55
  cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
56
 
57
  if ("[MASK]" in text_sentence or "<mask>" in text_sentence):
58
- return {'Input sentence':text_sentence,'Masked position': bert,'[CLS]':cls}
59
  else:
60
- return {'Input sentence':text_sentence,'[CLS]':cls}
61
 
62
- def get_bert_prediction(input_text,top_k):
63
  try:
64
  #input_text += ' <mask>'
65
- res = get_all_predictions(input_text, top_clean=int(top_k))
66
  return res
67
  except Exception as error:
68
  pass
69
 
70
 
71
- def run_test(sent,top_k):
72
  start = None
73
  global bert_tokenizer
74
  global bert_model
@@ -77,7 +77,7 @@ def run_test(sent,top_k):
77
  with st.spinner("Computing"):
78
  start = time.time()
79
  try:
80
- res = get_bert_prediction(sent,top_k)
81
  st.caption("Results in JSON")
82
  st.json(res)
83
 
@@ -115,13 +115,13 @@ try:
115
  custom_model_name = st.text_input("Model not listed on left? Type the model name (fill-mask models only)", "")
116
  if (len(custom_model_name) > 0):
117
  model_name = custom_model_name
118
- st.info("Custom model selected:" + model_name)
119
  bert_tokenizer, bert_model = load_bert_model(model_name)
120
  if len(input_text) > 0:
121
- run_test(input_text,top_k)
122
  else:
123
  if len(option) > 0:
124
- run_test(option,top_k)
125
  if (bert_tokenizer is None):
126
  bert_tokenizer, bert_model = load_bert_model(model_name)
127
 
 
46
  mask_idx = 0
47
  return input_ids, mask_idx
48
 
49
+ def get_all_predictions(text_sentence, model_name,top_clean=5):
50
  # ========================= BERT =================================
51
  input_ids, mask_idx = encode(bert_tokenizer, text_sentence)
52
  with torch.no_grad():
 
55
  cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*5).indices.tolist(), top_clean)
56
 
57
  if ("[MASK]" in text_sentence or "<mask>" in text_sentence):
58
+ return {'Input sentence':text_sentence,'Model':model_name,'Masked position': bert,'[CLS]':cls}
59
  else:
60
+ return {'Input sentence':text_sentence,'Model':model_name,'[CLS]':cls}
61
 
62
+ def get_bert_prediction(input_text,top_k,model_name):
63
  try:
64
  #input_text += ' <mask>'
65
+ res = get_all_predictions(input_text,model_name, top_clean=int(top_k))
66
  return res
67
  except Exception as error:
68
  pass
69
 
70
 
71
+ def run_test(sent,top_k,model_name):
72
  start = None
73
  global bert_tokenizer
74
  global bert_model
 
77
  with st.spinner("Computing"):
78
  start = time.time()
79
  try:
80
+ res = get_bert_prediction(sent,top_k,model_name)
81
  st.caption("Results in JSON")
82
  st.json(res)
83
 
 
115
  custom_model_name = st.text_input("Model not listed on left? Type the model name (fill-mask models only)", "")
116
  if (len(custom_model_name) > 0):
117
  model_name = custom_model_name
118
+ st.info("Custom model selected: " + model_name)
119
  bert_tokenizer, bert_model = load_bert_model(model_name)
120
  if len(input_text) > 0:
121
+ run_test(input_text,top_k,model_name)
122
  else:
123
  if len(option) > 0:
124
+ run_test(option,top_k,model_name)
125
  if (bert_tokenizer is None):
126
  bert_tokenizer, bert_model = load_bert_model(model_name)
127