leavoigt commited on
Commit
3efd370
1 Parent(s): 27d7463

Update utils/target_classifier.py

Browse files
Files changed (1) hide show
  1. utils/target_classifier.py +109 -109
utils/target_classifier.py CHANGED
@@ -1,140 +1,140 @@
1
- from typing import List, Tuple
2
- from typing_extensions import Literal
3
- import logging
4
- import pandas as pd
5
- from pandas import DataFrame, Series
6
- from utils.config import getconfig
7
- from utils.preprocessing import processingpipeline
8
- import streamlit as st
9
- from transformers import pipeline
10
 
11
- ## Labels dictionary ###
12
- _lab_dict = {
13
- '0':'NO',
14
- '1':'YES',
15
- }
16
 
17
- def get_target_labels(preds):
18
 
19
- """
20
- Function that takes the numerical predictions as an input and returns a list of the labels.
21
 
22
- """
23
 
24
- # Get label names
25
- preds_list = preds.tolist()
26
 
27
- predictions_names=[]
28
 
29
- # loop through each prediction
30
- for ele in preds_list:
31
 
32
- # see if there is a value 1 and retrieve index
33
- try:
34
- index_of_one = ele.index(1)
35
- except ValueError:
36
- index_of_one = "NA"
37
 
38
- # Retrieve the name of the label (if no prediction made = NA)
39
- if index_of_one != "NA":
40
- name = label_dict[index_of_one]
41
- else:
42
- name = "Other"
43
 
44
- # Append name to list
45
- predictions_names.append(name)
46
 
47
- return predictions_names
48
 
49
- @st.cache_resource
50
- def load_targetClassifier(config_file:str = None, classifier_name:str = None):
51
- """
52
- loads the document classifier using haystack, where the name/path of model
53
- in HF-hub as string is used to fetch the model object.Either configfile or
54
- model should be passed.
55
- 1. https://docs.haystack.deepset.ai/reference/document-classifier-api
56
- 2. https://docs.haystack.deepset.ai/docs/document_classifier
57
- Params
58
- --------
59
- config_file: config file path from which to read the model name
60
- classifier_name: if modelname is passed, it takes a priority if not \
61
- found then will look for configfile, else raise error.
62
- Return: document classifier model
63
- """
64
- if not classifier_name:
65
- if not config_file:
66
- logging.warning("Pass either model name or config file")
67
- return
68
- else:
69
- config = getconfig(config_file)
70
- classifier_name = config.get('target','MODEL')
71
 
72
- logging.info("Loading classifier")
73
 
74
- doc_classifier = pipeline("text-classification",
75
- model=classifier_name,
76
- top_k =1)
77
 
78
- return doc_classifier
79
 
80
 
81
- @st.cache_data
82
- def target_classification(haystack_doc:pd.DataFrame,
83
- threshold:float = 0.5,
84
- classifier_model:pipeline= None
85
- )->Tuple[DataFrame,Series]:
86
- """
87
- Text-Classification on the list of texts provided. Classifier provides the
88
- most appropriate label for each text. There labels indicate whether the paragraph
89
- references a specific action, target or measure in the paragraph.
90
- ---------
91
- haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
92
- contains the list of paragraphs in different format,here the list of
93
- Haystack Documents is used.
94
- threshold: threshold value for the model to keep the results from classifier
95
- classifiermodel: you can pass the classifier model directly,which takes priority
96
- however if not then looks for model in streamlit session.
97
- In case of streamlit avoid passing the model directly.
98
- Returns
99
- ----------
100
- df: Dataframe with two columns['SDG:int', 'text']
101
- x: Series object with the unique SDG covered in the document uploaded and
102
- the number of times it is covered/discussed/count_of_paragraphs.
103
- """
104
 
105
- logging.info("Working on target/action identification")
106
 
107
- haystack_doc['Vulnerability Label'] = 'NA'
108
 
109
- if not classifier_model:
110
 
111
- classifier_model = st.session_state['target_classifier']
112
 
113
- # Get predictions
114
- predictions = classifier_model(list(haystack_doc.text))
115
 
116
- # Get labels for predictions
117
- pred_labels = getlabels(predictions)
118
 
119
- # Save labels
120
- haystack_doc['Target Label'] = pred_labels
121
 
122
 
123
- # logging.info("Working on action/target extraction")
124
- # if not classifier_model:
125
- # classifier_model = st.session_state['target_classifier']
126
 
127
- # results = classifier_model(list(haystack_doc.text))
128
- # labels_= [(l[0]['label'],
129
- # l[0]['score']) for l in results]
130
 
131
 
132
- # df1 = DataFrame(labels_, columns=["Target Label","Target Score"])
133
- # df = pd.concat([haystack_doc,df1],axis=1)
134
 
135
- # df = df.sort_values(by="Target Score", ascending=False).reset_index(drop=True)
136
- # df['Target Score'] = df['Target Score'].round(2)
137
- # df.index += 1
138
- # # df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])
139
 
140
- return haystack_doc
 
1
+ # from typing import List, Tuple
2
+ # from typing_extensions import Literal
3
+ # import logging
4
+ # import pandas as pd
5
+ # from pandas import DataFrame, Series
6
+ # from utils.config import getconfig
7
+ # from utils.preprocessing import processingpipeline
8
+ # import streamlit as st
9
+ # from transformers import pipeline
10
 
11
+ # ## Labels dictionary ###
12
+ # _lab_dict = {
13
+ # '0':'NO',
14
+ # '1':'YES',
15
+ # }
16
 
17
+ # def get_target_labels(preds):
18
 
19
+ # """
20
+ # Function that takes the numerical predictions as an input and returns a list of the labels.
21
 
22
+ # """
23
 
24
+ # # Get label names
25
+ # preds_list = preds.tolist()
26
 
27
+ # predictions_names=[]
28
 
29
+ # # loop through each prediction
30
+ # for ele in preds_list:
31
 
32
+ # # see if there is a value 1 and retrieve index
33
+ # try:
34
+ # index_of_one = ele.index(1)
35
+ # except ValueError:
36
+ # index_of_one = "NA"
37
 
38
+ # # Retrieve the name of the label (if no prediction made = NA)
39
+ # if index_of_one != "NA":
40
+ # name = label_dict[index_of_one]
41
+ # else:
42
+ # name = "Other"
43
 
44
+ # # Append name to list
45
+ # predictions_names.append(name)
46
 
47
+ # return predictions_names
48
 
49
+ # @st.cache_resource
50
+ # def load_targetClassifier(config_file:str = None, classifier_name:str = None):
51
+ # """
52
+ # loads the document classifier using haystack, where the name/path of model
53
+ # in HF-hub as string is used to fetch the model object.Either configfile or
54
+ # model should be passed.
55
+ # 1. https://docs.haystack.deepset.ai/reference/document-classifier-api
56
+ # 2. https://docs.haystack.deepset.ai/docs/document_classifier
57
+ # Params
58
+ # --------
59
+ # config_file: config file path from which to read the model name
60
+ # classifier_name: if modelname is passed, it takes a priority if not \
61
+ # found then will look for configfile, else raise error.
62
+ # Return: document classifier model
63
+ # """
64
+ # if not classifier_name:
65
+ # if not config_file:
66
+ # logging.warning("Pass either model name or config file")
67
+ # return
68
+ # else:
69
+ # config = getconfig(config_file)
70
+ # classifier_name = config.get('target','MODEL')
71
 
72
+ # logging.info("Loading classifier")
73
 
74
+ # doc_classifier = pipeline("text-classification",
75
+ # model=classifier_name,
76
+ # top_k =1)
77
 
78
+ # return doc_classifier
79
 
80
 
81
+ # @st.cache_data
82
+ # def target_classification(haystack_doc:pd.DataFrame,
83
+ # threshold:float = 0.5,
84
+ # classifier_model:pipeline= None
85
+ # )->Tuple[DataFrame,Series]:
86
+ # """
87
+ # Text-Classification on the list of texts provided. Classifier provides the
88
+ # most appropriate label for each text. There labels indicate whether the paragraph
89
+ # references a specific action, target or measure in the paragraph.
90
+ # ---------
91
+ # haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
92
+ # contains the list of paragraphs in different format,here the list of
93
+ # Haystack Documents is used.
94
+ # threshold: threshold value for the model to keep the results from classifier
95
+ # classifiermodel: you can pass the classifier model directly,which takes priority
96
+ # however if not then looks for model in streamlit session.
97
+ # In case of streamlit avoid passing the model directly.
98
+ # Returns
99
+ # ----------
100
+ # df: Dataframe with two columns['SDG:int', 'text']
101
+ # x: Series object with the unique SDG covered in the document uploaded and
102
+ # the number of times it is covered/discussed/count_of_paragraphs.
103
+ # """
104
 
105
+ # logging.info("Working on target/action identification")
106
 
107
+ # haystack_doc['Vulnerability Label'] = 'NA'
108
 
109
+ # if not classifier_model:
110
 
111
+ # classifier_model = st.session_state['target_classifier']
112
 
113
+ # # Get predictions
114
+ # predictions = classifier_model(list(haystack_doc.text))
115
 
116
+ # # Get labels for predictions
117
+ # pred_labels = getlabels(predictions)
118
 
119
+ # # Save labels
120
+ # haystack_doc['Target Label'] = pred_labels
121
 
122
 
123
+ # # logging.info("Working on action/target extraction")
124
+ # # if not classifier_model:
125
+ # # classifier_model = st.session_state['target_classifier']
126
 
127
+ # # results = classifier_model(list(haystack_doc.text))
128
+ # # labels_= [(l[0]['label'],
129
+ # # l[0]['score']) for l in results]
130
 
131
 
132
+ # # df1 = DataFrame(labels_, columns=["Target Label","Target Score"])
133
+ # # df = pd.concat([haystack_doc,df1],axis=1)
134
 
135
+ # # df = df.sort_values(by="Target Score", ascending=False).reset_index(drop=True)
136
+ # # df['Target Score'] = df['Target Score'].round(2)
137
+ # # df.index += 1
138
+ # # # df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])
139
 
140
+ # return haystack_doc