Jekaterina commited on
Commit
7f7ae34
·
verified ·
1 Parent(s): b701fe8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from sklearn.metrics import (
4
+ accuracy_score,
5
+ precision_score,
6
+ recall_score,
7
+ f1_score)
8
+ from imblearn.metrics import specificity_score
9
+ import difflib as dl
10
+ import os
11
+
12
+
13
+ # Title and description
14
+ st.title("Robustness and Sensitivity of BERT Models Predicting Alzheimer's Disease from Text")
15
+ st.markdown("Supplemantary material accompanying the following paper: Jekaterina Novikova (2021).[Robustness and Sensitivity of BERT Models Predicting Alzheimer's Disease from Text](https://arxiv.org/abs/2109.11888). \
16
+ *In: The 7th Workshop on Noisy User-generated Text at EMNLP*, 2021.", unsafe_allow_html=True)
17
+ st.image('img/poster2.png')
18
+ st.write("[Link](https://arxiv.org/abs/2109.11888) to the high-res version of the poster.")
19
+
20
+ # Loading data
21
+ my_data = "data/df_test_all.csv"
22
+ @st.cache(persist = True)
23
+ def load_data(dataset):
24
+ df = pd.read_csv(os.path.join(dataset))
25
+ return df
26
+
27
+ df = load_data(my_data)
28
+
29
+ # Sidebar to select type and level of perturbation selection menu
30
+ st.sidebar.title("Selection Menu")
31
+ st.sidebar.markdown("Please select the type and the level of text perturbation below. <hr>", unsafe_allow_html=True)
32
+
33
+ type = st.sidebar.selectbox('Type of perturbations', ["Original / No perturbations", "Delete filled pauses", "Delete info units", "Back-translation", "Substitute with WordNet synonyms"])
34
+ level = None
35
+ iu_type = None
36
+
37
+ if type in ["Substitute with word2vec", "Substitute with WordNet synonyms"]:
38
+ level = st.sidebar.slider('Level of perturbations:', min_value = 0.1, max_value = 0.90, step = 0.10)
39
+ elif type == "Delete info units":
40
+ iu_type = st.sidebar.radio('Type of info units:', ["Action only", "Location only", "Object only", "Subject only"])
41
+
42
+
43
+ # select column names based on subtype of perturbations:
44
+ def select_pred_column(type, level = None, iu_type = None):
45
+ if type == "Original / No perturbations":
46
+ prediction = "pred_original"
47
+ elif type == "Delete filled pauses":
48
+ prediction = "pred_no_filled_pause"
49
+ elif type == "Delete info units":
50
+ if iu_type == "Action only":
51
+ prediction = "pred_no_iu_action"
52
+ elif iu_type == "Location only":
53
+ prediction = "pred_no_iu_loc"
54
+ elif iu_type == "Object only":
55
+ prediction = "pred_no_iu_obj"
56
+ elif iu_type == "Subject only":
57
+ prediction = "pred_no_iu_subj"
58
+ elif type == "Back-translation":
59
+ prediction = "pred_back_transl"
60
+ elif type == "Substitute with word2vec":
61
+ lvl_str = str(level * 100)[:2]
62
+ prediction = "pred_w2v_"+lvl_str
63
+ elif type == "Substitute with WordNet synonyms":
64
+ lvl_str = str(level * 100)[:2]
65
+ prediction = "pred_wnet_"+lvl_str
66
+ return prediction
67
+
68
+ def select_aug_column(type, level = None, iu_type = None):
69
+ if type == "Original / No perturbations":
70
+ augmentation = "utterances"
71
+ elif type == "Delete filled pauses":
72
+ augmentation = "aug_no_filled_pause"
73
+ elif type == "Delete info units":
74
+ if iu_type == "Action only":
75
+ augmentation = "aug_no_iu_action"
76
+ elif iu_type == "Location only":
77
+ augmentation = "aug_no_iu_loc"
78
+ elif iu_type == "Object only":
79
+ augmentation = "aug_no_iu_obj"
80
+ elif iu_type == "Subject only":
81
+ augmentation = "aug_no_iu_subj"
82
+ elif type == "Back-translation":
83
+ augmentation = "aug_back_transl"
84
+ elif type == "Substitute with word2vec":
85
+ lvl_str = str(level * 100)[:2]
86
+ augmentation = "aug_w2v_"+lvl_str
87
+ elif type == "Substitute with WordNet synonyms":
88
+ lvl_str = str(level * 100)[:2]
89
+ augmentation = "aug_wnet_"+lvl_str
90
+ return augmentation
91
+
92
+ #part I
93
+ st.header("1. Classification Performance")
94
+
95
+ st.write("The performance of the fine-tuned BERT model tested on the samples of text with applied perturbations, as selected in the Selection Menu.")
96
+
97
+ if st.button("Calculate performance"):
98
+ acc = accuracy_score(df.label.values, df[select_pred_column(type, level, iu_type)].values)
99
+ f1 = f1_score(df.label.values, df[select_pred_column(type, level, iu_type)].values)
100
+ prec = precision_score(df.label.values, df[select_pred_column(type, level, iu_type)].values)
101
+ rec = recall_score(df.label.values, df[select_pred_column(type, level, iu_type)].values)
102
+ spec = specificity_score(df.label.values, df[select_pred_column(type, level, iu_type)].values)
103
+
104
+ df_perf = pd.DataFrame([acc, f1, prec, rec, spec])
105
+ df_perf.index = ["Accuracy", "F1-score", "Precision", "Recall/Sensitivity", "Specificity"]
106
+ df_perf.columns = ["Performance"]
107
+ st.table( df_perf.T)
108
+
109
+ #part II
110
+ st.header("2. Examples of Text Perturbations")
111
+
112
+
113
+ def text_to_code(text):
114
+ if text == "Healthy Control (label 0)":
115
+ code = [0]
116
+ elif text == "Alzheimer's Disease (label 1)":
117
+ code = [1]
118
+ else:
119
+ code = [0,1]
120
+ return code
121
+
122
+ dx = st.radio('Real disease:', ["Alzheimer's Disease (label 1)", "Healthy Control (label 0)", "both"])
123
+ pred1 = st.radio('Original prediction (before text perturbation):', ["Alzheimer's Disease (label 1)", "Healthy Control (label 0)", "Don't care"])
124
+ pred2 = st.radio('Prediction after text perturbation:', ["Alzheimer's Disease (label 1)", "Healthy Control (label 0)", "Don't care"])
125
+
126
+ subject_ids = df[(df["label"].isin(text_to_code(dx))) & \
127
+ (df["pred_original"].isin(text_to_code(pred1))) &\
128
+ (df[select_pred_column(type, level, iu_type)].isin(text_to_code(pred2)))]["subject_id"]
129
+
130
+ st.write('There are', subject_ids.shape[0], 'text sample(s) that correspond to such a selection.')
131
+
132
+ if subject_ids.shape[0] > 0:
133
+ subj_choice = st.selectbox("Select a text sample:", subject_ids)
134
+
135
+ df_select = df[df.subject_id == subj_choice][["subject_id", "sex", "age", "label", "pred_original", select_pred_column(type, level, iu_type)]]
136
+ df_select.age = df_select.age.astype(int)
137
+ df_select.columns = ["SubjectID", "Sex", "Age", "Real disease label", "Original prediction", "Prediction after perturbation"]
138
+ st.table(df_select)
139
+
140
+ text_orig = df[df.subject_id == subj_choice]["utterances"].values[0]
141
+ text_aug = df[df.subject_id == subj_choice][select_aug_column(type, level, iu_type)].values[0]
142
+
143
+ words_aug = set(text_aug.replace("'"," ' ").split())
144
+ words_orig = set(text_orig.replace("'"," ' ").split())
145
+ s1 = text_orig.replace("'"," ' ").split()
146
+ s2 = text_aug.replace("'"," ' ").split()
147
+
148
+
149
+ seqmatcher = dl.SequenceMatcher(None, s1, s2, autojunk=False)
150
+ res_orig, res_aug = [], []
151
+ for tag, a0, a1, b0, b1 in seqmatcher.get_opcodes():
152
+ if tag == "equal":
153
+ res_orig += s1[a0:a1]
154
+ res_aug += s2[b0:b1]
155
+ else:
156
+ res_orig += ["<span style='color:blue'> <em><b>"+" ".join(s1[a0:a1])+"</b></em></span>"]
157
+ res_aug += ["<span style='color:red'> <em><b>"+" ".join(s2[b0:b1])+"</b></em></span> "]
158
+
159
+ st.write("**<span style='font-size:larger'>The original text</span>**<br>(words are coloured in blue if they were selected for perturbation):", unsafe_allow_html=True)
160
+ st.write('<p style="padding: 1em">'+' '.join(res_orig)+'</p>', unsafe_allow_html=True)
161
+
162
+
163
+ st.write("**<span style='font-size:larger'>The perturbed text</span>**<br>(words are coloured in red if they appeared after perturbation):", unsafe_allow_html=True)
164
+ st.write('<p style="padding: 1em">'+' '.join(res_aug)+'</p>', unsafe_allow_html=True)