leavoigt commited on
Commit
b90fe6b
1 Parent(s): 2312d16

Create group_classifier.py

Browse files
Files changed (1) hide show
  1. utils/group_classifier.py +93 -0
utils/group_classifier.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ from typing_extensions import Literal
3
+ import logging
4
+ import pandas as pd
5
+ from pandas import DataFrame, Series
6
+ from utils.config import getconfig
7
+ from utils.preprocessing import processingpipeline
8
+ import streamlit as st
9
+ from transformers import pipeline
10
+
11
+ ## Labels dictionary ###
12
+ _lab_dict = {
13
+ 0: 'Children and Youth',
14
+ 1: 'Informal sector workers',
15
+ 2: 'Other',
16
+ 3: 'Rural populations',
17
+ 4: 'Sexual minorities (LGBTQI+)',
18
+ 5: 'Urban populations',
19
+ 6: 'Women'}
20
+
21
+ @st.cache_resource
22
+ def load_targetClassifier(config_file:str = None, classifier_name:str = None):
23
+ """
24
+ loads the document classifier using haystack, where the name/path of model
25
+ in HF-hub as string is used to fetch the model object.Either configfile or
26
+ model should be passed.
27
+ 1. https://docs.haystack.deepset.ai/reference/document-classifier-api
28
+ 2. https://docs.haystack.deepset.ai/docs/document_classifier
29
+ Params
30
+ --------
31
+ config_file: config file path from which to read the model name
32
+ classifier_name: if modelname is passed, it takes a priority if not \
33
+ found then will look for configfile, else raise error.
34
+ Return: document classifier model
35
+ """
36
+ if not classifier_name:
37
+ if not config_file:
38
+ logging.warning("Pass either model name or config file")
39
+ return
40
+ else:
41
+ config = getconfig(config_file)
42
+ classifier_name = config.get('target','MODEL')
43
+
44
+ logging.info("Loading classifier")
45
+
46
+ doc_classifier = pipeline("text-classification",
47
+ model=classifier_name,
48
+ top_k =1)
49
+
50
+ return doc_classifier
51
+
52
+
53
+ @st.cache_data
54
+ def target_classification(haystack_doc:pd.DataFrame,
55
+ threshold:float = 0.5,
56
+ classifier_model:pipeline= None
57
+ )->Tuple[DataFrame,Series]:
58
+ """
59
+ Text-Classification on the list of texts provided. Classifier provides the
60
+ most appropriate label for each text. these labels are in terms of if text
61
+ belongs to which particular Sustainable Devleopment Goal (SDG).
62
+ Params
63
+ ---------
64
+ haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
65
+ contains the list of paragraphs in different format,here the list of
66
+ Haystack Documents is used.
67
+ threshold: threshold value for the model to keep the results from classifier
68
+ classifiermodel: you can pass the classifier model directly,which takes priority
69
+ however if not then looks for model in streamlit session.
70
+ In case of streamlit avoid passing the model directly.
71
+ Returns
72
+ ----------
73
+ df: Dataframe with two columns['SDG:int', 'text']
74
+ x: Series object with the unique SDG covered in the document uploaded and
75
+ the number of times it is covered/discussed/count_of_paragraphs.
76
+ """
77
+ logging.info("Working on Target Extraction")
78
+ if not classifier_model:
79
+ classifier_model = st.session_state['target_classifier']
80
+
81
+ results = classifier_model(list(haystack_doc.text))
82
+ labels_= [(l[0]['label'],
83
+ l[0]['score']) for l in results]
84
+
85
+
86
+ df1 = DataFrame(labels_, columns=["Target Label","Relevancy"])
87
+ df = pd.concat([haystack_doc,df1],axis=1)
88
+
89
+ df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
90
+ df.index += 1
91
+ df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])
92
+
93
+ return df