File size: 5,828 Bytes
079c7c0
 
 
 
 
 
 
 
 
 
ef9edc5
079c7c0
 
 
 
 
 
 
40953ed
 
079c7c0
 
 
cbf120d
079c7c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d30cd5d
 
41aa9dd
d30cd5d
 
154ee8f
ef9edc5
079c7c0
 
ea878f4
 
ef9edc5
079c7c0
 
38ca272
bf7f98d
38ca272
4ecb551
 
 
 
d30cd5d
af776d5
079c7c0
 
d6606ef
 
 
 
f5a6f87
e692c88
6f6085f
 
 
d6606ef
 
 
 
 
 
 
 
 
 
 
 
 
793c140
d6606ef
793c140
 
d6606ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
897a5b5
4a76f03
4caba7c
4a76f03
897a5b5
4a76f03
897a5b5
 
6f6085f
897a5b5
b5de5f6
 
 
d6606ef
 
71d49f5
ea878f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# set path
import glob, os, sys; 
sys.path.append('../utils')

#import needed libraries
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from utils.vulnerability_classifier import load_vulnerabilityClassifier, vulnerability_classification
import logging
logger = logging.getLogger(__name__)
from utils.config import get_classifier_params
from utils.preprocessing import paraLengthCheck
from io import BytesIO
import xlsxwriter
import plotly.express as px
from utils.vulnerability_classifier import label_dict



# Declare all the necessary variables
# Identifier of this classifier: it selects the parameter set fetched below and
# is the prefix of the session-state key under which app() stores the model.
classifier_identifier = 'vulnerability'
# Classifier settings; app() reads the 'model_name' and 'threshold' entries.
params  = get_classifier_params(classifier_identifier)

@st.cache_data
def to_excel(df, sectorlist):
    """Serialise ``df`` to an in-memory .xlsx file with drop-down validation.

    The frame is written to a sheet named 'Sheet1' (header in row 1, no index
    column). List validations are then attached to the data rows:

    * column S: 'No' / 'Yes' / 'Discard'
    * columns T to X: every entry of ``sectorlist`` plus 'Blank'

    The column letters are fixed, so they assume the layout of the frame the
    caller passes in.

    Args:
        df: DataFrame to export.
        sectorlist: list of allowed values for the columns T to X.

    Returns:
        The bytes of the finished workbook.
    """
    len_df = len(df)
    output = BytesIO()

    # Leaving the ``with`` block finalises the workbook. ``ExcelWriter.save()``
    # was deprecated in pandas 1.5 and removed in 2.0.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']

        # An empty frame has no data rows, so there is nothing to validate
        # (and 'S2:S1' would be an inverted range).
        if len_df > 0:
            # Row 1 holds the header, so the data occupies rows 2..len_df + 1.
            last_row = len_df + 1
            sector_options = sectorlist + ['Blank']

            validations = [('S', ['No', 'Yes', 'Discard'])]
            validations += [(column, sector_options) for column in 'TUVWX']

            for column, options in validations:
                worksheet.data_validation(
                    '{0}2:{0}{1}'.format(column, last_row),
                    {'validate': 'list', 'source': options})

    return output.getvalue()

def app():
    """Run the vulnerability classifier on the processed document.

    Does nothing unless ``st.session_state`` contains ``key0``. Otherwise the
    classifier named in ``params['model_name']`` is loaded and kept in session
    state, ``vulnerability_classification`` is run on the ``key0`` data with
    ``params['threshold']``, and the rows whose 'Vulnerability Label' is
    non-empty and does not contain 'Other' are stored as
    ``st.session_state.key1``.
    """

    def _mentions_group(labels):
        # Same truth table as: len(labels) > 0 and 'Other' not in labels
        return not (len(labels) == 0 or 'Other' in labels)

    ### Main app code ###
    with st.container():

        # Nothing to do until a document has been processed
        if 'key0' not in st.session_state:
            return

        document = st.session_state.key0

        # Load the classifier and keep it available in the session
        model = load_vulnerabilityClassifier(classifier_name=params['model_name'])
        st.session_state[f'{classifier_identifier}_classifier'] = model

        # Get the predictions
        predictions = vulnerability_classification(haystack_doc=document,
                                                   threshold=params['threshold'])

        # Keep only the paragraphs that reference a group
        has_reference = predictions['Vulnerability Label'].apply(_mentions_group)
        df_filtered = predictions[has_reference]

        # NOTE(review): rename() returns a new frame and the result is not
        # assigned, so the stored frame keeps the 'Vulnerability Label'
        # column name. Confirm which column name consumers of key1 expect
        # before changing this.
        df_filtered.rename(columns={'Vulnerability Label': 'Group identified'})

        # Store df in session state with key1
        st.session_state.key1 = df_filtered


def vulnerability_display():
    """Render the summary text and bar chart for vulnerable-group references.

    Reads the dataframe stored under ``st.session_state['key0']`` and uses its
    'Vulnerability Label' column, which holds a collection of label names per
    paragraph. The left column shows how many paragraphs there are and how
    many of them reference a group; the right column shows a Plotly bar chart
    with the number of occurrences of every label in ``label_dict`` except
    'Other'.

    Raises:
        KeyError: if ``key0`` is not in the session state or the dataframe has
            no 'Vulnerability Label' column.
    """
    # Assign dataframe a name
    df_vul = st.session_state['key0']
    labels_per_paragraph = df_vul['Vulnerability Label']

    # Header
    st.subheader("Explore references to vulnerable groups:")

    col1, col2 = st.columns([1, 1])

    with col1:

        # Text
        num_paragraphs = len(labels_per_paragraph)
        # Same criterion app() uses to filter paragraphs: a paragraph with an
        # empty label list is not a reference.
        num_references = int(labels_per_paragraph.apply(
            lambda x: len(x) > 0 and 'Other' not in x).sum())

        st.markdown(f"""<div style="text-align: justify;"> The document contains a
                total of <span style="color: red;">{num_paragraphs}</span> paragraphs.
                We identified <span style="color: red;">{num_references}</span>
                references to groups in vulnerable situations.
                <br>
                In the chart on the right you can see how often each group has been referenced.
                For a more detailed view in the text, see the paragraphs and 
                their respective labels in the table below.</div>""", unsafe_allow_html=True)

    with col2:

        ### Bar chart

        # Create a df that stores all the labels
        # (presumably label_dict maps label ids to label names - confirm in
        # utils.vulnerability_classifier)
        df_labels = pd.DataFrame(list(label_dict.items()), columns=['Label ID', 'Label'])

        # Count how often each label appears in the "Vulnerability Label" column
        group_counts = {}
        for labels in labels_per_paragraph:
            for label in labels:
                group_counts[label] = group_counts.get(label, 0) + 1

        # Attach a count to every known label; labels that were never
        # predicted get 0 instead of a missing value.
        df_label_count = df_labels.assign(
            Count=[group_counts.get(label, 0) for label in df_labels['Label']])

        # Exclude the "Other" group
        df_bar_chart = df_label_count[df_label_count['Label'] != 'Other']

        # Bar chart
        fig = px.bar(df_bar_chart,
                     x='Label',
                     y='Count',
                     title='How many references have been found for each group?',
                     labels={'Count': 'Frequency'})

        # Show plot
        st.plotly_chart(fig, use_container_width=True)

    # ### Table 
    #st.write(df_vul[df_vul['Vulnerability Label'].apply(lambda x: 'Other' not in x)])