nazneen commited on
Commit
3e601d6
1 Parent(s): 32daca9

streamlit app

Browse files
Files changed (1) hide show
  1. app.py +182 -0
app.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### LIBRARIES ###
2
+ # # Data
3
+ import numpy as np
4
+ import pandas as pd
5
+ import json
6
+ from math import floor
7
+
8
+ # Robustness Gym and Analysis
9
+ import robustnessgym as rg
10
+ from gensim.models.doc2vec import Doc2Vec
11
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
12
+ import nltk
13
+ nltk.download('punkt') #make sure that punkt is downloaded
14
+
15
+ # App & Visualization
16
+ import streamlit as st
17
+ import altair as alt
18
+
19
+ # utils
20
+ from interactive_model_cards import utils as ut
21
+ from interactive_model_cards import app_layout as al
22
+ from random import sample
23
+ from PIL import Image
24
+
25
+
26
+
27
+ ### LOADING DATA ###
28
+ # model card data
29
+ @st.experimental_memo
30
+ def load_model_card():
31
+ with open("./assets/data/text_explainer/model_card.json") as f:
32
+ mc_text = json.load(f)
33
+ return mc_text
34
+
35
+
36
+ # pre-computed robusntess gym dev bench
37
+ # @st.experimental_singleton
38
+ @st.cache(allow_output_mutation=True)
39
+ def load_data():
40
+ # load dev bench
41
+ devBench = rg.DevBench.load("./assets/data/rg/sst_db.devbench")
42
+ return devBench
43
+
44
+
45
+ # load model
46
+ @st.experimental_singleton
47
+ def load_model():
48
+ model = rg.HuggingfaceModel(
49
+ "distilbert-base-uncased-finetuned-sst-2-english", is_classifier=True
50
+ )
51
+ return model
52
+
53
+ #load pre-computed embedding
54
+ def load_embedding():
55
+ embedding = pd.read_pickle("./assets/models/sst_vectors.pkl")
56
+ return embedding
57
+
58
+ #load doc2vec model
59
+ @st.experimental_singleton
60
+ def load_doc2vec():
61
+ doc2vec = Doc2Vec.load("./assets/models/sst_train.doc2vec")
62
+ return(doc2vec)
63
+
64
+
65
+ # @st.experimental_memo
66
+ def load_examples():
67
+ with open("./assets/data/user_data/example_sentence.json") as f:
68
+ examples = json.load(f)
69
+ return examples
70
+
71
+
72
+ # loading the dataset
73
+ def load_basic():
74
+ # load data
75
+ devBench = load_data()
76
+ # load model
77
+ model = load_model()
78
+ #protected_classes
79
+ protected_classes = json.load(open("./assets/data/protected_terms.json"))
80
+
81
+ return devBench, model, protected_classes
82
+
83
+ @st.experimental_singleton
84
+ def load_title():
85
+ img = Image.open("./assets/img/title.png")
86
+ return(img)
87
+
88
+
89
+ if __name__ == "__main__":
90
+
91
+ ### STREAMLIT APP CONGFIG ###
92
+ st.set_page_config(layout="wide", page_title="Interactive Model Card")
93
+
94
+ # import custom styling
95
+ ut.init_style()
96
+
97
+ ### LOAD DATA AND SESSION VARIABLES ###
98
+
99
+ # ******* loading the mode and the data
100
+ with st.spinner():
101
+ sst_db, model,protected_classes = load_basic()
102
+ embedding = load_embedding()
103
+ doc2vec = load_doc2vec()
104
+
105
+ # load example sentences
106
+ sentence_examples = load_examples()
107
+
108
+ # ******* session state variables
109
+ if "user_data" not in st.session_state:
110
+ st.session_state["user_data"] = pd.DataFrame()
111
+ if "example_sent" not in st.session_state:
112
+ st.session_state["example_sent"] = "I like you. I love you"
113
+ if "quant_ex" not in st.session_state:
114
+ st.session_state["quant_ex"] = {"Overall Performance": sst_db.metrics["model"]}
115
+ if "selected_slice" not in st.session_state:
116
+ st.session_state["selected_slice"] = None
117
+ if "slice_terms" not in st.session_state:
118
+ st.session_state["slice_terms"] = {}
119
+ if "embedding" not in st.session_state:
120
+ st.session_state["embedding"] = embedding
121
+ if "protected_class" not in st.session_state:
122
+ st.session_state["protected_class"] = protected_classes
123
+
124
+
125
+ ### STREAMLIT APP LAYOUT###
126
+
127
+ # ******* MODEL CARD PANEL *******
128
+ #st.sidebar.title("Interactive Model Card")
129
+ img = load_title()
130
+ st.sidebar.image(img,width=400)
131
+ st.sidebar.warning("Data is not permanently collected or stored from your interactions, but is temporarily cached during usage.")
132
+
133
+ # load model card data
134
+ errors = st.sidebar.checkbox("Show Warnings", value=True)
135
+ model_card = load_model_card()
136
+ al.model_card_panel(model_card,errors)
137
+
138
+ lcol, rcol = st.columns([4, 8])
139
+
140
+ # ******* USER EXAMPLE DATA PANEL *******
141
+ st.markdown("---")
142
+ with lcol:
143
+
144
+ # Choose waht to show for the qunatiative analysis.
145
+ st.write("""<h1 style="font-size:20px;padding-top:0px;"> Quantitative Analysis</h1>""",
146
+ unsafe_allow_html=True)
147
+
148
+ st.markdown("View the model's performance or visually explore the model's training and testing dataset")
149
+
150
+ data_view = st.selectbox("Show:",
151
+ ["Model Performance Metrics","Data Subpopulation Comparison Visualization"])
152
+
153
+ st.markdown("Any groups you define via the *analysis actions* will be automatically added to the view")
154
+ st.markdown("---")
155
+
156
+
157
+ # Additional Analysis Actions
158
+ st.write(
159
+ """<h1 style="font-size:18px;padding-top:5px;"> Analysis Actions</h1>""",
160
+ unsafe_allow_html=True,
161
+ )
162
+ al.example_panel(sentence_examples, model, sst_db,doc2vec)
163
+
164
+ # ****** GUIDANCE PANEL *****
165
+ with st.expander("Guidance"):
166
+ st.markdown(
167
+ "Need help understanding what you're seeing in this model card?"
168
+ )
169
+
170
+ st.markdown(
171
+ " * **[Understanding Metrics](https://stanford.edu/~shervine/teaching/cs-229/cheatsheet-machine-learning-tips-and-tricks)**: A cheatsheet of model metrics"
172
+ )
173
+ st.markdown(
174
+ " * **[Understanding Sentiment Models](https://www.semanticscholar.org/topic/Sentiment-analysis/6011)**: An overview of sentiment analysis"
175
+ )
176
+ st.markdown(
177
+ "* **[Next Steps](https://docs.google.com/document/d/1r9J1NQ7eTibpXkCpcucDEPhASGbOQAMhRTBvosGu4Pk/edit?usp=sharin)**: Suggestions for follow-on actions"
178
+ )
179
+ st.markdown("Feel free to submit feedback via our [online form](https://sfdc.co/imc_feedback)")
180
+
181
+ # ******* QUANTITATIVE DATA PANEL *******
182
+ al.quant_panel(sst_db, st.session_state["embedding"], rcol,data_view)