leavoigt committed
Commit
cc0fb09
1 Parent(s): 47756f1

Update utils/target_classifier.py

Files changed (1)
  1. utils/target_classifier.py +99 -99
utils/target_classifier.py CHANGED
@@ -1,127 +1,127 @@
- # from typing import List, Tuple
- # from typing_extensions import Literal
- # import logging
- # import pandas as pd
- # from pandas import DataFrame, Series
- # from utils.config import getconfig
- # from utils.preprocessing import processingpipeline
- # import streamlit as st
- # from transformers import pipeline

- # ## Labels dictionary ###
- # _lab_dict = {
- # '0':'NO',
- # '1':'YES',
- # }

- # def get_target_labels(preds):

- # """
- # Function that takes the numerical predictions as an input and returns a list of the labels.

- # """

- # # Get label names
- # preds_list = preds.tolist()

- # predictions_names=[]

- # # loop through each prediction
- # for ele in preds_list:

- # # see if there is a value 1 and retrieve index
- # try:
- # index_of_one = ele.index(1)
- # except ValueError:
- # index_of_one = "NA"

- # # Retrieve the name of the label (if no prediction made = NA)
- # if index_of_one != "NA":
- # name = label_dict[index_of_one]
- # else:
- # name = "Other"

- # # Append name to list
- # predictions_names.append(name)

- # return predictions_names

- # @st.cache_resource
- # def load_targetClassifier(config_file:str = None, classifier_name:str = None):
- # """
- # loads the document classifier using haystack, where the name/path of model
- # in HF-hub as string is used to fetch the model object.Either configfile or
- # model should be passed.
- # 1. https://docs.haystack.deepset.ai/reference/document-classifier-api
- # 2. https://docs.haystack.deepset.ai/docs/document_classifier
- # Params
- # --------
- # config_file: config file path from which to read the model name
- # classifier_name: if modelname is passed, it takes a priority if not \
- # found then will look for configfile, else raise error.
- # Return: document classifier model
- # """
- # if not classifier_name:
- # if not config_file:
- # logging.warning("Pass either model name or config file")
- # return
- # else:
- # config = getconfig(config_file)
- # classifier_name = config.get('target','MODEL')

- # logging.info("Loading classifier")

- # doc_classifier = pipeline("text-classification",
- # model=classifier_name,
- # top_k =1)

- # return doc_classifier


- # @st.cache_data
- # def target_classification(haystack_doc:pd.DataFrame,
- # threshold:float = 0.5,
- # classifier_model:pipeline= None
- # )->Tuple[DataFrame,Series]:
- # """
- # Text-Classification on the list of texts provided. Classifier provides the
- # most appropriate label for each text. There labels indicate whether the paragraph
- # references a specific action, target or measure in the paragraph.
- # ---------
- # haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
- # contains the list of paragraphs in different format,here the list of
- # Haystack Documents is used.
- # threshold: threshold value for the model to keep the results from classifier
- # classifiermodel: you can pass the classifier model directly,which takes priority
- # however if not then looks for model in streamlit session.
- # In case of streamlit avoid passing the model directly.
- # Returns
- # ----------
- # df: Dataframe with two columns['SDG:int', 'text']
- # x: Series object with the unique SDG covered in the document uploaded and
- # the number of times it is covered/discussed/count_of_paragraphs.
- # """

- # logging.info("Working on target/action identification")

- # haystack_doc['Vulnerability Label'] = 'NA'

- # if not classifier_model:

- # classifier_model = st.session_state['target_classifier']

- # # Get predictions
- # predictions = classifier_model(list(haystack_doc.text))

- # # Get labels for predictions
- # pred_labels = getlabels(predictions)

- # # Save labels
- # haystack_doc['Target Label'] = pred_labels


- # # logging.info("Working on action/target extraction")
- # # if not classifier_model:
  # # classifier_model = st.session_state['target_classifier']

  # # results = classifier_model(list(haystack_doc.text))
@@ -137,4 +137,4 @@
  # # df.index += 1
  # # # df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])

- # return haystack_doc
 
+ from typing import List, Tuple
+ from typing_extensions import Literal
+ import logging
+ import pandas as pd
+ from pandas import DataFrame, Series
+ from utils.config import getconfig
+ from utils.preprocessing import processingpipeline
+ import streamlit as st
+ from transformers import pipeline

+ ## Labels dictionary ###
+ _lab_dict = {
+     '0':'NO',
+     '1':'YES',
+ }

+ def get_target_labels(preds):

+     """
+     Takes the numerical predictions as input and returns a list of the corresponding label names.

+     """

+     # Get label names
+     preds_list = preds.tolist()

+     predictions_names = []

+     # loop through each prediction
+     for ele in preds_list:

+         # see if there is a value 1 and retrieve its index
+         try:
+             index_of_one = ele.index(1)
+         except ValueError:
+             index_of_one = "NA"

+         # Retrieve the name of the label (if no prediction was made, use "Other")
+         if index_of_one != "NA":
+             name = _lab_dict[str(index_of_one)]
+         else:
+             name = "Other"

+         # Append name to list
+         predictions_names.append(name)

+     return predictions_names

+ @st.cache_resource
+ def load_targetClassifier(config_file: str = None, classifier_name: str = None):
+     """
+     Loads the document classifier, where the name/path of the model on the
+     HF hub (as a string) is used to fetch the model object. Either a config
+     file or a model name should be passed.
+     1. https://docs.haystack.deepset.ai/reference/document-classifier-api
+     2. https://docs.haystack.deepset.ai/docs/document_classifier
+     Params
+     --------
+     config_file: config file path from which to read the model name
+     classifier_name: if a model name is passed it takes priority; if not,
+         the config file is used, otherwise a warning is logged and nothing
+         is returned.
+     Return: document classifier model
+     """
+     if not classifier_name:
+         if not config_file:
+             logging.warning("Pass either model name or config file")
+             return
+         else:
+             config = getconfig(config_file)
+             classifier_name = config.get('target', 'MODEL')

+     logging.info("Loading classifier")

+     doc_classifier = pipeline("text-classification",
+                               model=classifier_name,
+                               top_k=1)

+     return doc_classifier


+ @st.cache_data
+ def target_classification(haystack_doc: pd.DataFrame,
+                           threshold: float = 0.5,
+                           classifier_model: pipeline = None
+                           ) -> DataFrame:
+     """
+     Text classification on the texts provided. The classifier assigns the most
+     appropriate label to each text; these labels indicate whether the paragraph
+     references a specific action, target or measure.
+     ---------
+     haystack_doc: DataFrame of paragraphs produced by the preprocessing
+         pipeline; its 'text' column is passed to the classifier.
+     threshold: threshold value for keeping a classifier result (currently
+         unused in this function)
+     classifier_model: the classifier can be passed directly, in which case it
+         takes priority; otherwise it is looked up in the Streamlit session
+         state. When running inside Streamlit, avoid passing the model directly.
+     Returns
+     ----------
+     haystack_doc: the input DataFrame with an added 'Target Label' column.
+     """

+     logging.info("Working on target/action identification")

+     haystack_doc['Target Label'] = 'NA'

+     if not classifier_model:

+         classifier_model = st.session_state['target_classifier']

+     # Get predictions
+     predictions = classifier_model(list(haystack_doc.text))

+     # Get labels for predictions
+     pred_labels = get_target_labels(predictions)

+     # Save labels
+     haystack_doc['Target Label'] = pred_labels


+ # logging.info("Working on action/target extraction")
+ # if not classifier_model:
  # # classifier_model = st.session_state['target_classifier']

  # # results = classifier_model(list(haystack_doc.text))

  # # df.index += 1
  # # # df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])

+     return haystack_doc
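
For readers reviewing the change, here is a minimal, self-contained sketch of the pipeline call that load_targetClassifier now builds and the shape of its top_k=1 output. The model id below is a placeholder (the app resolves the real name from the config's [target] MODEL entry), and the Streamlit caching and session-state wiring is omitted; this is an illustration, not part of the commit.

from transformers import pipeline

# Placeholder model id; the app reads the real one from its config file.
doc_classifier = pipeline("text-classification",
                          model="your-org/target-classifier",
                          top_k=1)

texts = ["We will reduce emissions by 40% by 2030.",
         "This chapter describes the national context."]

# With top_k=1, each prediction is a one-element list holding the best label and its score.
for text, pred in zip(texts, doc_classifier(texts)):
    best = pred[0]
    print(f"{best['label']:>8}  {best['score']:.2f}  {text}")

In the app itself, target_classification() runs the cached pipeline over the DataFrame's 'text' column and writes the mapped YES/NO labels into the 'Target Label' column.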