山越貴耀 commited on
Commit
4c1fd66
1 Parent(s): 22a211b
Files changed (2) hide show
  1. app.py +253 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit as st
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from sklearn.decomposition import PCA
9
+ from sklearn.manifold import TSNE
10
+ from sentence_transformers import SentenceTransformer
11
+ from transformers import BertTokenizer,BertForMaskedLM
12
+ import cv2
13
+
14
+ def load_sentence_model():
15
+ sentence_model = SentenceTransformer('paraphrase-distilroberta-base-v1')
16
+ return sentence_model
17
+
18
+ @st.cache(show_spinner=False)
19
+ def load_model(model_name):
20
+ if model_name.startswith('bert'):
21
+ tokenizer = BertTokenizer.from_pretrained(model_name)
22
+ model = BertForMaskedLM.from_pretrained(model_name)
23
+ model.eval()
24
+ return tokenizer,model
25
+
26
+ @st.cache
27
+ def load_data(sentence_num):
28
+ df = pd.read_csv('tsne_out.csv')
29
+ df = df.loc[lambda d: (d['sentence_num']==sentence_num)&(d['iter_num']<1000)]
30
+ return df
31
+
32
+ @st.cache
33
+ def mask_prob(model,mask_id,sentences,position,temp=1):
34
+ masked_sentences = sentences.clone()
35
+ masked_sentences[:, position] = mask_id
36
+ with torch.no_grad():
37
+ logits = model(masked_sentences)[0]
38
+ return F.log_softmax(logits[:, position] / temp, dim = -1)
39
+
40
+ @st.cache
41
+ def sample_words(probs,pos,sentences):
42
+ candidates = [[tokenizer.decode([candidate]),torch.exp(probs)[0,candidate].item()]
43
+ for candidate in torch.argsort(probs[0],descending=True)[:10]]
44
+ df = pd.DataFrame(data=candidates,columns=['word','prob'])
45
+ chosen_words = torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1)
46
+ new_sentences = sentences.clone()
47
+ new_sentences[:, pos] = chosen_words
48
+ return new_sentences, df
49
+
50
+ def run_chains(tokenizer,model,mask_id,input_text,num_steps):
51
+ init_sent = tokenizer(input_text,return_tensors='pt')['input_ids']
52
+ seq_len = init_sent.shape[1]
53
+ sentence = init_sent.clone()
54
+ data_list = []
55
+ st.sidebar.write('Generating samples...')
56
+ st.sidebar.write('This takes ~30 seconds for 1000 steps with ~10 token sentences')
57
+ chain_progress = st.sidebar.progress(0)
58
+ for step_id in range(num_steps):
59
+ chain_progress.progress((step_id+1)/num_steps)
60
+ pos = torch.randint(seq_len-2,size=(1,)).item()+1
61
+ data_list.append([step_id,' '.join([tokenizer.decode([token]) for token in sentence[0]]),pos])
62
+ probs = mask_prob(model,mask_id,sentence,pos)
63
+ sentence,_ = sample_words(probs,pos,sentence)
64
+ return pd.DataFrame(data=data_list,columns=['step','sentence','next_sample_loc'])
65
+
66
+ @st.cache(suppress_st_warning=True,show_spinner=False)
67
+ def run_tsne(chain):
68
+ st.sidebar.write('Running t-SNE...')
69
+ chain = chain.assign(cleaned_sentence=chain.sentence.str.replace(r'\[CLS\] ', '',regex=True).str.replace(r' \[SEP\]', '',regex=True))
70
+ sentence_model = load_sentence_model()
71
+ sentence_embeddings = sentence_model.encode(chain.cleaned_sentence.to_list(), show_progress_bar=False)
72
+
73
+ tsne = TSNE(n_components = 2, n_iter=2000)
74
+ big_pca = PCA(n_components = 50)
75
+ tsne_vals = tsne.fit_transform(big_pca.fit_transform(sentence_embeddings))
76
+ tsne = pd.concat([chain, pd.DataFrame(tsne_vals, columns = ['x_tsne', 'y_tsne'],index=chain.index)], axis = 1)
77
+ return tsne
78
+
79
+ def clear_df():
80
+ del st.session_state['df']
81
+
82
+ @st.cache(show_spinner=False)
83
+ def plot_fig(df,sent_id,xlims,ylims,color_list):
84
+ x_tsne, y_tsne = df.x_tsne, df.y_tsne
85
+ fig = plt.figure(figsize=(5,5),dpi=200)
86
+ ax = fig.add_subplot(1,1,1)
87
+ ax.plot(x_tsne[:sent_id+1],y_tsne[:sent_id+1],linewidth=0.2,color='gray',zorder=1)
88
+ ax.scatter(x_tsne[:sent_id+1],y_tsne[:sent_id+1],s=5,color=color_list[:sent_id+1],zorder=2)
89
+ ax.scatter(x_tsne[sent_id:sent_id+1],y_tsne[sent_id:sent_id+1],s=50,marker='*',color='blue',zorder=3)
90
+ ax.set_xlim(*xlims)
91
+ ax.set_ylim(*ylims)
92
+ ax.axis('off')
93
+ ax.set_title(df.cleaned_sentence.to_list()[sent_id])
94
+ fig.savefig(f'figures/{sent_id}.png')
95
+ plt.clf()
96
+ plt.close()
97
+
98
+ def pre_render_images(df,input_sent_id):
99
+ sent_id_options = [min(len(df)-1,max(0,input_sent_id+increment)) for increment in [-500,-100,-10,-1,0,1,10,100,500]]
100
+ x_tsne, y_tsne = df.x_tsne, df.y_tsne
101
+ xmax,xmin = (max(x_tsne)//30+1)*30,(min(x_tsne)//30-1)*30
102
+ ymax,ymin = (max(y_tsne)//30+1)*30,(min(y_tsne)//30-1)*30
103
+ color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
104
+ sent_list = []
105
+ fig_production = st.progress(0)
106
+ for fig_id,sent_id in enumerate(sent_id_options):
107
+ fig_production.progress(fig_id+1)
108
+ plot_fig(fig_id,x_tsne,y_tsne,sent_id,[xmin,xmax],[ymin,ymax],color_list)
109
+ sent_list.append(df.cleaned_sentence.to_list()[sent_id])
110
+ return sent_list
111
+
112
+
113
+ if __name__=='__main__':
114
+ # Config
115
+ max_width = 1500
116
+ padding_top = 2
117
+ padding_right = 5
118
+ padding_bottom = 0
119
+ padding_left = 5
120
+
121
+ define_margins = f"""
122
+ <style>
123
+ .appview-container .main .block-container{{
124
+ max-width: {max_width}px;
125
+ padding-top: {padding_top}rem;
126
+ padding-right: {padding_right}rem;
127
+ padding-left: {padding_left}rem;
128
+ padding-bottom: {padding_bottom}rem;
129
+ }}
130
+ </style>
131
+ """
132
+ hide_table_row_index = """
133
+ <style>
134
+ tbody th {display:none}
135
+ .blank {display:none}
136
+ </style>
137
+ """
138
+ st.markdown(define_margins, unsafe_allow_html=True)
139
+ st.markdown(hide_table_row_index, unsafe_allow_html=True)
140
+
141
+ # Title
142
+ st.header("Demo: Probing BERT's priors with serial reproduction chains")
143
+
144
+ # Load BERT
145
+ tokenizer,model = load_model('bert-base-uncased')
146
+ mask_id = tokenizer.encode("[MASK]")[1:-1][0]
147
+
148
+ # First step: load the dataframe containing sentences
149
+ input_type = st.sidebar.radio(label='1. Choose the input type',options=('Use one of our example sentences','Use your own initial sentence'))
150
+
151
+ if input_type=='Use one of our example sentences':
152
+ sentence = st.sidebar.selectbox("Select the inital sentence",
153
+ ('About 170 campers attend the camps each week.',
154
+ 'She grew up with three brothers and ten sisters.'))
155
+ if sentence=='About 170 campers attend the camps each week.':
156
+ sentence_num = 6
157
+ else:
158
+ sentence_num = 8
159
+
160
+ st.session_state.df = load_data(sentence_num)
161
+
162
+ else:
163
+ sentence = st.sidebar.text_input('Type down your own sentence here',on_change=clear_df)
164
+ num_steps = st.sidebar.number_input(label='How many steps do you want to run?',value=1000)
165
+ if st.sidebar.button('Run chains'):
166
+ chain = run_chains(tokenizer,model,mask_id,sentence,num_steps=num_steps)
167
+ st.session_state.df = run_tsne(chain)
168
+ st.session_state.finished_sampling = True
169
+
170
+ if 'df' in st.session_state:
171
+ df = st.session_state.df
172
+ sent_id = st.sidebar.slider(label='2. Select the position in a chain to start exploring',
173
+ min_value=0,max_value=len(df)-1,value=0)
174
+
175
+ explore_type = st.sidebar.radio('3. Choose the way to explore',options=['In fixed increments','Click through each step','Autoplay'])
176
+ if explore_type=='Autoplay':
177
+ if st.button('Create the video (this may take a few minutes)'):
178
+ st.write('Creating the video...')
179
+ x_tsne, y_tsne = df.x_tsne, df.y_tsne
180
+ xmax,xmin = (max(x_tsne)//30+1)*30,(min(x_tsne)//30-1)*30
181
+ ymax,ymin = (max(y_tsne)//30+1)*30,(min(y_tsne)//30-1)*30
182
+ color_list = sns.color_palette('flare',n_colors=1200)
183
+ fig_production = st.progress(0)
184
+
185
+ plot_fig(df,0,[xmin,xmax],[ymin,ymax],color_list)
186
+ img = cv2.imread('figures/0.png')
187
+ height, width, layers = img.shape
188
+ size = (width,height)
189
+ out = cv2.VideoWriter('sampling_video.mp4',cv2.VideoWriter_fourcc(*'H264'), 3, size)
190
+ for sent_id in range(1000):
191
+ fig_production.progress((sent_id+1)/1000)
192
+ plot_fig(df,sent_id,[xmin,xmax],[ymin,ymax],color_list)
193
+ img = cv2.imread(f'figures/{sent_id}.png')
194
+ out.write(img)
195
+ out.release()
196
+
197
+ cols = st.columns([1,2,1])
198
+ with cols[1]:
199
+ with open('sampling_video.mp4', 'rb') as f:
200
+ st.video(f)
201
+ else:
202
+ if explore_type=='In fixed increments':
203
+ button_labels = ['-500','-100','-10','-1','0','+1','+10','+100','+500']
204
+ increment = st.sidebar.radio(label='select increment',options=button_labels,index=4)
205
+ sent_id += int(increment.replace('+',''))
206
+ sent_id = min(len(df)-1,max(0,sent_id))
207
+ elif explore_type=='Click through each step':
208
+ sent_id = st.sidebar.number_input(label='step number',value=sent_id)
209
+
210
+ x_tsne, y_tsne = df.x_tsne, df.y_tsne
211
+ xlims = [(min(x_tsne)//30-1)*30,(max(x_tsne)//30+1)*30]
212
+ ylims = [(min(y_tsne)//30-1)*30,(max(y_tsne)//30+1)*30]
213
+ color_list = sns.color_palette('flare',n_colors=int(len(df)*1.2))
214
+
215
+ fig = plt.figure(figsize=(5,5),dpi=200)
216
+ ax = fig.add_subplot(1,1,1)
217
+ ax.plot(x_tsne[:sent_id+1],y_tsne[:sent_id+1],linewidth=0.2,color='gray',zorder=1)
218
+ ax.scatter(x_tsne[:sent_id+1],y_tsne[:sent_id+1],s=5,color=color_list[:sent_id+1],zorder=2)
219
+ ax.scatter(x_tsne[sent_id:sent_id+1],y_tsne[sent_id:sent_id+1],s=50,marker='*',color='blue',zorder=3)
220
+ ax.set_xlim(*xlims)
221
+ ax.set_ylim(*ylims)
222
+ ax.axis('off')
223
+
224
+ sentence = df.cleaned_sentence.to_list()[sent_id]
225
+ input_sent = tokenizer(sentence,return_tensors='pt')['input_ids']
226
+ decoded_sent = [tokenizer.decode([token]) for token in input_sent[0]]
227
+ show_candidates = st.checkbox('Show candidates')
228
+ if show_candidates:
229
+ st.write('Click any word to see each candidate with its probability')
230
+ cols = st.columns(len(decoded_sent))
231
+ with cols[0]:
232
+ st.write(decoded_sent[0])
233
+ with cols[-1]:
234
+ st.write(decoded_sent[-1])
235
+ for word_id,(col,word) in enumerate(zip(cols[1:-1],decoded_sent[1:-1])):
236
+ with col:
237
+ if st.button(word):
238
+ probs = mask_prob(model,mask_id,input_sent,word_id+1)
239
+ _,candidates_df = sample_words(probs, word_id+1, input_sent)
240
+ st.table(candidates_df)
241
+ else:
242
+ disp_style = '"font-family:san serif; color:Black; font-size: 25px; font-weight:bold"'
243
+ if explore_type=='Click through each step' and input_type=='Use your own initial sentence' and sent_id>0 and 'finished_sampling' in st.session_state:
244
+ sampled_loc = df.next_sample_loc.to_list()[sent_id-1]
245
+ disp_sent_before = f'<p style={disp_style}>'+' '.join(decoded_sent[1:sampled_loc])
246
+ new_word = f'<span style="color:Red">{decoded_sent[sampled_loc]}</span>'
247
+ disp_sent_after = ' '.join(decoded_sent[sampled_loc+1:-1])+'</p>'
248
+ st.markdown(disp_sent_before+' '+new_word+' '+disp_sent_after,unsafe_allow_html=True)
249
+ else:
250
+ st.markdown(f'<p style={disp_style}>{sentence}</p>',unsafe_allow_html=True)
251
+ cols = st.columns([1,2,1])
252
+ with cols[1]:
253
+ st.pyplot(fig)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ sentence_transformers
4
+ cv2
5
+ seaborn
6
+ sklearn