Update utils/target_classifier.py
Browse files
utils/target_classifier.py
CHANGED
@@ -107,12 +107,12 @@ def target_classification(haystack_doc:pd.DataFrame,
|
|
107 |
|
108 |
haystack_doc['Target Label'] = 'NA'
|
109 |
|
110 |
-
if not
|
111 |
|
112 |
-
|
113 |
|
114 |
# Get predictions
|
115 |
-
predictions =
|
116 |
st.write('predictions')
|
117 |
st.write(predictions[:10])
|
118 |
|
|
|
107 |
|
108 |
haystack_doc['Target Label'] = 'NA'
|
109 |
|
110 |
+
if not target_classifier_model:
|
111 |
|
112 |
+
target_classifier_model = st.session_state['target_classifier']
|
113 |
|
114 |
# Get predictions
|
115 |
+
predictions = target_classifier_model(list(haystack_doc.text))
|
116 |
st.write('predictions')
|
117 |
st.write(predictions[:10])
|
118 |
|