Spaces:
Sleeping
Sleeping
import logging
from typing import Callable, List, Mapping, Optional, Tuple

import pandas as pd
import streamlit as st
from pandas import DataFrame, Series
from setfit import SetFitModel
from transformers import pipeline
from typing_extensions import Literal

from utils.config import getconfig
from utils.preprocessing import processingpipeline
# Column index -> name of the vulnerable group it stands for. The index is the
# position of the group in each prediction row produced by the classifier
# (see get_vulnerability_labels, which looks names up by enumerate position).
label_dict = dict(enumerate([
    'Agricultural communities',
    'Children',
    'Coastal communities',
    'Ethnic, racial or other minorities',
    'Fishery communities',
    'Informal sector workers',
    'Members of indigenous and local communities',
    'Migrants and displaced persons',
    'Older persons',
    'Other',
    'Persons living in poverty',
    'Persons with disabilities',
    'Persons with pre-existing health conditions',
    'Residents of drought-prone regions',
    'Rural populations',
    'Sexual minorities (LGBTQI+)',
    'Urban populations',
    'Women and other genders',
]))
def get_vulnerability_labels(preds, labels: Optional[Mapping[int, str]] = None) -> List[List[str]]:
    """
    Convert multi-hot prediction rows into lists of vulnerability label names.

    Params
    --------
    preds: 2-D collection of predictions, one row per text and one column per
        label index. Objects exposing ``tolist()`` (e.g. a torch tensor or a
        numpy array) are converted to nested lists first; plain nested
        sequences are used as-is.
    labels: mapping from column index to label name. Defaults to the
        module-level ``label_dict``.

    Return: one list per row with the names of every column whose value
        equals 1 (1, 1.0 and True all match); a row without hits gives [].
    """
    if labels is None:
        labels = label_dict
    # Tensors / arrays -> nested Python lists so values compare as scalars.
    rows = preds.tolist() if hasattr(preds, "tolist") else preds
    return [
        [labels[idx] for idx, value in enumerate(row) if value == 1]
        for row in rows
    ]
def load_vulnerabilityClassifier(config_file: Optional[str] = None,
                                 classifier_name: Optional[str] = None):
    """
    Load the multi-label vulnerability classifier as a SetFit model from the
    Hugging Face Hub. Either ``classifier_name`` or ``config_file`` must be
    supplied.

    Params
    --------
    config_file: path of the config file; the model name is read from the
        ``MODEL`` key of its ``vulnerability`` section. Only consulted when
        ``classifier_name`` is not given.
    classifier_name: name/path of the model on the HF Hub. Takes priority
        over ``config_file``.

    Return: the loaded ``SetFitModel``, or ``None`` (after logging a warning)
        when neither argument is provided.
    """
    if not classifier_name:
        if not config_file:
            logging.warning("Pass either model name or config file")
            return None
        config = getconfig(config_file)
        classifier_name = config.get('vulnerability', 'MODEL')

    logging.info("Loading vulnerability classifier")
    # SetFit is used (instead of Haystack's DocumentClassifier, which does not
    # support multi-label classification). Load the model resolved above from
    # the argument or config, not a fixed hub id, so callers get what they
    # asked for.
    doc_classifier = SetFitModel.from_pretrained(classifier_name)
    return doc_classifier
def vulnerability_classification(haystack_doc: pd.DataFrame,
                                 threshold: float = 0.5,
                                 classifier_model: Optional[Callable] = None
                                 ) -> DataFrame:
    """
    Tag each paragraph with the vulnerable groups it refers to.

    The classifier is multi-label, so a paragraph may receive several labels
    or none. ``haystack_doc`` is modified in place and also returned.

    Params
    --------
    haystack_doc: DataFrame with a ``text`` column holding one paragraph per
        row (the output of the preprocessing pipeline).
    threshold: currently unused; labels are taken wherever the model's
        prediction equals 1 (see get_vulnerability_labels). Kept so existing
        callers keep working.
    classifier_model: the classifier to apply, called with the list of texts.
        Takes priority when given; otherwise the model stored in
        ``st.session_state['vulnerability_classifier']`` is used. In a
        Streamlit app prefer leaving this as None.

    Return: ``haystack_doc`` with a ``Vulnerability Label`` column whose cells
        are lists of label names.
    """
    logging.info("Working on vulnerability Identification")
    haystack_doc['Vulnerability Label'] = 'NA'

    # Nothing to classify: don't hand the model an empty batch, just return
    # the frame with its (empty) label column.
    if haystack_doc.empty:
        return haystack_doc

    # Explicit None check: model objects should not be judged by truthiness.
    if classifier_model is None:
        classifier_model = st.session_state['vulnerability_classifier']

    predictions = classifier_model(list(haystack_doc.text))
    haystack_doc['Vulnerability Label'] = get_vulnerability_labels(predictions)
    return haystack_doc