Spaces:
Runtime error
Runtime error
gchhablani
commited on
Commit
•
b5bd188
1
Parent(s):
69e32d1
Fix prediction function
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ def load_model(ckpt):
|
|
21 |
return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
|
22 |
|
23 |
@st.cache(persist=True)
|
24 |
-
def predict(
|
25 |
return np.array(model(pixel_values = transformed_image, **question_inputs)[0][0])
|
26 |
|
27 |
def softmax(logits):
|
@@ -125,7 +125,7 @@ state.answer_lang_id = col2.selectbox('Answer Language', index=options.index(sta
|
|
125 |
with st.spinner('Loading model...'):
|
126 |
model = load_model(checkpoints[0])
|
127 |
with st.spinner('Predicting...'):
|
128 |
-
logits = predict(
|
129 |
logits = softmax(logits)
|
130 |
labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
|
131 |
translated_labels = translate_labels(labels, state.answer_lang_id)
|
|
|
21 |
return FlaxCLIPVisionBertForSequenceClassification.from_pretrained(ckpt)
|
22 |
|
23 |
@st.cache(persist=True)
|
24 |
+
def predict(transformed_image, question_inputs):
|
25 |
return np.array(model(pixel_values = transformed_image, **question_inputs)[0][0])
|
26 |
|
27 |
def softmax(logits):
|
|
|
125 |
with st.spinner('Loading model...'):
|
126 |
model = load_model(checkpoints[0])
|
127 |
with st.spinner('Predicting...'):
|
128 |
+
logits = predict(transformed_image, dict(question_inputs))
|
129 |
logits = softmax(logits)
|
130 |
labels, values = get_top_5_predictions(logits, answer_reverse_mapping)
|
131 |
translated_labels = translate_labels(labels, state.answer_lang_id)
|