leavoigt commited on
Commit
5064543
1 Parent(s): 1fc7d6a

Update utils/group_classifier.py

Browse files
Files changed (1) hide show
  1. utils/group_classifier.py +6 -6
utils/group_classifier.py CHANGED
@@ -19,7 +19,7 @@ _lab_dict = {
19
  6: 'Women'}
20
 
21
  @st.cache_resource
22
- def load_targetClassifier(config_file:str = None, classifier_name:str = None):
23
  """
24
  loads the document classifier using haystack, where the name/path of model
25
  in HF-hub as string is used to fetch the model object.Either configfile or
@@ -51,7 +51,7 @@ def load_targetClassifier(config_file:str = None, classifier_name:str = None):
51
 
52
 
53
  @st.cache_data
54
- def target_classification(haystack_doc:pd.DataFrame,
55
  threshold:float = 0.5,
56
  classifier_model:pipeline= None
57
  )->Tuple[DataFrame,Series]:
@@ -74,20 +74,20 @@ def target_classification(haystack_doc:pd.DataFrame,
74
  x: Series object with the unique SDG covered in the document uploaded and
75
  the number of times it is covered/discussed/count_of_paragraphs.
76
  """
77
- logging.info("Working on Target Extraction")
78
  if not classifier_model:
79
- classifier_model = st.session_state['target_classifier']
80
 
81
  results = classifier_model(list(haystack_doc.text))
82
  labels_= [(l[0]['label'],
83
  l[0]['score']) for l in results]
84
 
85
 
86
- df1 = DataFrame(labels_, columns=["Target Label","Relevancy"])
87
  df = pd.concat([haystack_doc,df1],axis=1)
88
 
89
  df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
90
  df.index += 1
91
- df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])
92
 
93
  return df
 
19
  6: 'Women'}
20
 
21
  @st.cache_resource
22
+ def load_groupClassifier(config_file:str = None, classifier_name:str = None):
23
  """
24
  loads the document classifier using haystack, where the name/path of model
25
  in HF-hub as string is used to fetch the model object.Either configfile or
 
51
 
52
 
53
  @st.cache_data
54
+ def group_classification(haystack_doc:pd.DataFrame,
55
  threshold:float = 0.5,
56
  classifier_model:pipeline= None
57
  )->Tuple[DataFrame,Series]:
 
74
  x: Series object with the unique SDG covered in the document uploaded and
75
  the number of times it is covered/discussed/count_of_paragraphs.
76
  """
77
+ logging.info("Working on Group Extraction")
78
  if not classifier_model:
79
+ classifier_model = st.session_state['group_classifier']
80
 
81
  results = classifier_model(list(haystack_doc.text))
82
  labels_= [(l[0]['label'],
83
  l[0]['score']) for l in results]
84
 
85
 
86
+ df1 = DataFrame(labels_, columns=["Group Label","Relevancy"])
87
  df = pd.concat([haystack_doc,df1],axis=1)
88
 
89
  df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
90
  df.index += 1
91
+ df['Label_def'] = df['Group Label'].apply(lambda i: _lab_dict[i])
92
 
93
  return df