gchhablani commited on
Commit
b5bd188
1 Parent(s): 69e32d1

Fix prediction function

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -21,7 +21,7 @@ def load_model(ckpt):
21
  return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
22
 
23
  @st.cache(persist=True)
24
- def predict(model, transformed_image, question_inputs):
25
  return np.array(model(pixel_values = transformed_image, **question_inputs)[0][0])
26
 
27
  def softmax(logits):
@@ -125,7 +125,7 @@ state.answer_lang_id = col2.selectbox('Answer Language', index=options.index(sta
125
  with st.spinner('Loading model...'):
126
  model = load_model(checkpoints[0])
127
  with st.spinner('Predicting...'):
128
- logits = predict(model, transformed_image, dict(question_inputs))
129
  logits = softmax(logits)
130
  labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
131
  translated_labels = translate_labels(labels, state.answer_lang_id)
 
21
  return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
22
 
23
  @st.cache(persist=True)
24
+ def predict(transformed_image, question_inputs):
25
  return np.array(model(pixel_values = transformed_image, **question_inputs)[0][0])
26
 
27
  def softmax(logits):
 
125
  with st.spinner('Loading model...'):
126
  model = load_model(checkpoints[0])
127
  with st.spinner('Predicting...'):
128
+ logits = predict(transformed_image, dict(question_inputs))
129
  logits = softmax(logits)
130
  labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
131
  translated_labels = translate_labels(labels, state.answer_lang_id)