leavoigt commited on
Commit
c8b94cd
1 Parent(s): a29c372

Update utils/sdg_classifier.py

Browse files
Files changed (1) hide show
  1. utils/sdg_classifier.py +31 -31
utils/sdg_classifier.py CHANGED
@@ -14,27 +14,27 @@ except ImportError:
14
  logging.info("Streamlit not installed")
15
 
16
  ## Labels dictionary ###
17
- _lab_dict = {0: 'no_cat',
18
- 1:'SDG 1 - No poverty',
19
- 2:'SDG 2 - Zero hunger',
20
- 3:'SDG 3 - Good health and well-being',
21
- 4:'SDG 4 - Quality education',
22
- 5:'SDG 5 - Gender equality',
23
- 6:'SDG 6 - Clean water and sanitation',
24
- 7:'SDG 7 - Affordable and clean energy',
25
- 8:'SDG 8 - Decent work and economic growth',
26
- 9:'SDG 9 - Industry, Innovation and Infrastructure',
27
- 10:'SDG 10 - Reduced inequality',
28
- 11:'SDG 11 - Sustainable cities and communities',
29
- 12:'SDG 12 - Responsible consumption and production',
30
- 13:'SDG 13 - Climate action',
31
- 14:'SDG 14 - Life below water',
32
- 15:'SDG 15 - Life on land',
33
- 16:'SDG 16 - Peace, justice and strong institutions',
34
- 17:'SDG 17 - Partnership for the goals',}
35
 
36
  @st.cache(allow_output_mutation=True)
37
- def load_sdgClassifier(config_file:str = None, classifier_name:str = None):
38
  """
39
  loads the document classifier using haystack, where the name/path of model
40
  in HF-hub as string is used to fetch the model object.Either configfile or
@@ -57,7 +57,7 @@ def load_sdgClassifier(config_file:str = None, classifier_name:str = None):
57
  return
58
  else:
59
  config = getconfig(config_file)
60
- classifier_name = config.get('sdg','MODEL')
61
 
62
  logging.info("Loading classifier")
63
  doc_classifier = TransformersDocumentClassifier(
@@ -68,7 +68,7 @@ def load_sdgClassifier(config_file:str = None, classifier_name:str = None):
68
 
69
 
70
  @st.cache(allow_output_mutation=True)
71
- def sdg_classification(haystack_doc:List[Document],
72
  threshold:float = 0.8,
73
  classifier_model:TransformersDocumentClassifier= None
74
  )->Tuple[DataFrame,Series]:
@@ -95,10 +95,10 @@ def sdg_classification(haystack_doc:List[Document],
95
  the number of times it is covered/discussed/count_of_paragraphs.
96
 
97
  """
98
- logging.info("Working on SDG Classification")
99
  if not classifier_model:
100
  if check_streamlit():
101
- classifier_model = st.session_state['sdg_classifier']
102
  else:
103
  logging.warning("No streamlit envinornment found, Pass the classifier")
104
  return
@@ -109,23 +109,23 @@ def sdg_classification(haystack_doc:List[Document],
109
  labels_= [(l.meta['classification']['label'],
110
  l.meta['classification']['score'],l.content,) for l in results]
111
 
112
- df = DataFrame(labels_, columns=["SDG","Relevancy","text"])
113
 
114
  df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
115
  df.index += 1
116
  df =df[df['Relevancy']>threshold]
117
 
118
  # creating the dataframe for value counts of SDG, along with 'title' of SDGs
119
- x = df['SDG'].value_counts()
120
  x = x.rename('count')
121
- x = x.rename_axis('SDG').reset_index()
122
- x["SDG"] = pd.to_numeric(x["SDG"])
123
  x = x.sort_values(by=['count'], ascending=False)
124
- x['SDG_name'] = x['SDG'].apply(lambda x: _lab_dict[x])
125
- x['SDG_Num'] = x['SDG'].apply(lambda x: "SDG "+str(x))
126
 
127
- df['SDG'] = pd.to_numeric(df['SDG'])
128
- df = df.sort_values('SDG')
129
 
130
  return df, x
131
 
 
14
  logging.info("Streamlit not installed")
15
 
16
  ## Labels dictionary ###
17
+ _lab_dict = {0: 'Agricultural communities',
18
+ 1: 'Children',
19
+ 2: 'Coastal communities',
20
+ 3: 'Ethnic, racial or other minorities',
21
+ 4: 'Fishery communities',
22
+ 5: 'Informal sector workers',
23
+ 6: 'Members of indigenous and local communities',
24
+ 7: 'Migrants and displaced persons',
25
+ 8: 'Older persons',
26
+ 9: 'Other',
27
+ 10: 'Persons living in poverty',
28
+ 11: 'Persons with disabilities',
29
+ 12: 'Persons with pre-existing health conditions',
30
+ 13: 'Residents of drought-prone regions',
31
+ 14: 'Rural populations',
32
+ 15: 'Sexual minorities (LGBTQI+)',
33
+ 16: 'Urban populations',
34
+ 17: 'Women and other genders'}
35
 
36
  @st.cache(allow_output_mutation=True)
37
+ def load_Classifier(config_file:str = None, classifier_name:str = None):
38
  """
39
  loads the document classifier using haystack, where the name/path of model
40
  in HF-hub as string is used to fetch the model object.Either configfile or
 
57
  return
58
  else:
59
  config = getconfig(config_file)
60
+ classifier_name = config.get('vulnerability','MODEL')
61
 
62
  logging.info("Loading classifier")
63
  doc_classifier = TransformersDocumentClassifier(
 
68
 
69
 
70
  @st.cache(allow_output_mutation=True)
71
+ def classification(haystack_doc:List[Document],
72
  threshold:float = 0.8,
73
  classifier_model:TransformersDocumentClassifier= None
74
  )->Tuple[DataFrame,Series]:
 
95
  the number of times it is covered/discussed/count_of_paragraphs.
96
 
97
  """
98
+ logging.info("Working on Vulnerability Classification")
99
  if not classifier_model:
100
  if check_streamlit():
101
+ classifier_model = st.session_state['vulnerability_classifier']
102
  else:
103
  logging.warning("No streamlit envinornment found, Pass the classifier")
104
  return
 
109
  labels_= [(l.meta['classification']['label'],
110
  l.meta['classification']['score'],l.content,) for l in results]
111
 
112
+ df = DataFrame(labels_, columns=["Vulnerability","Relevancy","text"])
113
 
114
  df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
115
  df.index += 1
116
  df =df[df['Relevancy']>threshold]
117
 
118
  # creating the dataframe for value counts of SDG, along with 'title' of SDGs
119
+ x = df['Vulnerability'].value_counts()
120
  x = x.rename('count')
121
+ x = x.rename_axis('Vulnerability').reset_index()
122
+ x["Vulnerability"] = pd.to_numeric(x["Vulnerability"])
123
  x = x.sort_values(by=['count'], ascending=False)
124
+ x['SDG_name'] = x['Vulnerability'].apply(lambda x: _lab_dict[x])
125
+ x['SDG_Num'] = x['Vulnerability'].apply(lambda x: "Vulnerability "+str(x))
126
 
127
+ df['Vulnerability'] = pd.to_numeric(df['Vulnerability'])
128
+ df = df.sort_values('Vulnerability')
129
 
130
  return df, x
131