leavoigt commited on
Commit
7240967
·
1 Parent(s): a3014e0

Update utils/target_classifier.py

Browse files
Files changed (1) hide show
  1. utils/target_classifier.py +4 -27
utils/target_classifier.py CHANGED
@@ -37,7 +37,7 @@ def get_target_labels(preds):
37
  index_of_one = ele.index(1)
38
  except ValueError:
39
  index_of_one = "NA"
40
- st.write(index_of_one)
41
  # Retrieve the name of the label (if no prediction made = NA)
42
  if index_of_one != "NA":
43
  name = label_dict[index_of_one]
@@ -107,42 +107,19 @@ def target_classification(haystack_doc:pd.DataFrame,
107
  logging.info("Working on target/action identification")
108
 
109
  haystack_doc['Target Label'] = 'NA'
110
- st.write("haystack_doc")
111
- st.write(haystack_doc)
112
-
113
  if not classifier_model:
114
 
115
- st.write("No classifier_model")
116
-
117
  classifier_model = st.session_state['target_classifier']
118
- st.write("classifier model defined")
119
 
120
  # Get predictions
121
  predictions = classifier_model(list(haystack_doc.text))
122
- st.write("predictions made")
123
- st.write(predictions)
124
  # Get labels for predictions
125
  pred_labels = get_target_labels(predictions)
126
- st.write("pred_labels")
127
- st.write(pred_labels)
128
  # Save labels
129
  haystack_doc['Target Label'] = pred_labels
130
 
131
  return haystack_doc
132
- # logging.info("Working on action/target extraction")
133
- # if not classifier_model:
134
- # # classifier_model = st.session_state['target_classifier']
135
-
136
- # # results = classifier_model(list(haystack_doc.text))
137
- # # labels_= [(l[0]['label'],
138
- # # l[0]['score']) for l in results]
139
-
140
-
141
- # # df1 = DataFrame(labels_, columns=["Target Label","Target Score"])
142
- # # df = pd.concat([haystack_doc,df1],axis=1)
143
-
144
- # # df = df.sort_values(by="Target Score", ascending=False).reset_index(drop=True)
145
- # # df['Target Score'] = df['Target Score'].round(2)
146
- # # df.index += 1
147
- # # # df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])
148
 
 
37
  index_of_one = ele.index(1)
38
  except ValueError:
39
  index_of_one = "NA"
40
+
41
  # Retrieve the name of the label (if no prediction made = NA)
42
  if index_of_one != "NA":
43
  name = label_dict[index_of_one]
 
107
  logging.info("Working on target/action identification")
108
 
109
  haystack_doc['Target Label'] = 'NA'
110
+
 
 
111
  if not classifier_model:
112
 
 
 
113
  classifier_model = st.session_state['target_classifier']
 
114
 
115
  # Get predictions
116
  predictions = classifier_model(list(haystack_doc.text))
117
+
 
118
  # Get labels for predictions
119
  pred_labels = get_target_labels(predictions)
120
+
 
121
  # Save labels
122
  haystack_doc['Target Label'] = pred_labels
123
 
124
  return haystack_doc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125