ppsingh commited on
Commit
1481282
1 Parent(s): 8024e2c

Create conditional_classifier.py

Browse files
Files changed (1) hide show
  1. utils/conditional_classifier.py +92 -0
utils/conditional_classifier.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ from typing_extensions import Literal
3
+ import logging
4
+ import pandas as pd
5
+ from pandas import DataFrame, Series
6
+ from utils.config import getconfig
7
+ from utils.preprocessing import processingpipeline
8
+ import streamlit as st
9
+ from transformers import pipeline
10
+
11
+
12
+ @st.cache_resource
13
+ def load_conditionalClassifier(config_file:str = None, classifier_name:str = None):
14
+ """
15
+ loads the document classifier using haystack, where the name/path of model
16
+ in HF-hub as string is used to fetch the model object.Either configfile or
17
+ model should be passed.
18
+ 1. https://docs.haystack.deepset.ai/reference/document-classifier-api
19
+ 2. https://docs.haystack.deepset.ai/docs/document_classifier
20
+ Params
21
+ --------
22
+ config_file: config file path from which to read the model name
23
+ classifier_name: if modelname is passed, it takes a priority if not \
24
+ found then will look for configfile, else raise error.
25
+ Return: document classifier model
26
+ """
27
+ if not classifier_name:
28
+ if not config_file:
29
+ logging.warning("Pass either model name or config file")
30
+ return
31
+ else:
32
+ config = getconfig(config_file)
33
+ classifier_name = config.get('conditional','MODEL')
34
+
35
+ logging.info("Loading conditional classifier")
36
+ doc_classifier = pipeline("text-classification",
37
+ model=classifier_name,
38
+ top_k =1)
39
+
40
+ return doc_classifier
41
+
42
+
43
+ @st.cache_data
44
+ def conditional_classification(haystack_doc:pd.DataFrame,
45
+ threshold:float = 0.8,
46
+ classifier_model:pipeline= None
47
+ )->Tuple[DataFrame,Series]:
48
+ """
49
+ Text-Classification on the list of texts provided. Classifier provides the
50
+ most appropriate label for each text. It informs if paragraph contains any
51
+ netzero information or not.
52
+ Params
53
+ ---------
54
+ haystack_doc: List of haystack Documents. The output of Preprocessing Pipeline
55
+ contains the list of paragraphs in different format,here the list of
56
+ Haystack Documents is used.
57
+ threshold: threshold value for the model to keep the results from classifier
58
+ classifiermodel: you can pass the classifier model directly,which takes priority
59
+ however if not then looks for model in streamlit session.
60
+ In case of streamlit avoid passing the model directly.
61
+ Returns
62
+ ----------
63
+ df: Dataframe
64
+ """
65
+ logging.info("Working on Conditionality Identification")
66
+ haystack_doc['Conditional Label'] = 'NA'
67
+ haystack_doc['Conditional Score'] = 0.0
68
+ haystack_doc['cond_check'] = False
69
+ haystack_doc['cond_check'] = haystack_doc.apply(lambda x: True if (
70
+ (x['Target Label'] == 'TARGET') | (x['Action Label'] == 'Action') |
71
+ (x['Policies_Plans Label'] == 'Policies and Plans')) else
72
+ False, axis=1)
73
+ # we apply Netzero to only paragraphs which are classified as 'Target' related
74
+ temp = haystack_doc[haystack_doc['cond_check'] == True]
75
+ temp = temp.reset_index(drop=True)
76
+ df = haystack_doc[haystack_doc['cond_check'] == False]
77
+ df = df.reset_index(drop=True)
78
+
79
+ if not classifier_model:
80
+ classifier_model = st.session_state['conditional_classifier']
81
+
82
+ results = classifier_model(list(temp.text))
83
+ labels_= [(l[0]['label'],l[0]['score']) for l in results]
84
+ temp['Conditional Label'],temp['Conditional Score'] = zip(*labels_)
85
+ # temp[' Label'] = temp['Netzero Label'].apply(lambda x: _lab_dict[x])
86
+ # merging Target with Non Target dataframe
87
+ df = pd.concat([df,temp])
88
+ df = df.drop(columns = ['cond_check'])
89
+ df = df.reset_index(drop =True)
90
+ df.index += 1
91
+
92
+ return df