taka-yamakoshi
commited on
Commit
·
ca1b654
1
Parent(s):
9839e32
add heads
Browse files
app.py
CHANGED
@@ -130,14 +130,14 @@ def show_instruction(sent,fontsize=20):
|
|
130 |
suffix = '</span></p>'
|
131 |
return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
|
132 |
|
133 |
-
def create_interventions(token_id,interv_types,num_heads,multihead=False):
|
134 |
interventions = {}
|
135 |
for rep in ['lay','qry','key','val']:
|
136 |
if rep in interv_types:
|
137 |
if multihead:
|
138 |
interventions[rep] = [(head_id,token_id,[0,1]) for head_id in range(num_heads)]
|
139 |
else:
|
140 |
-
interventions[rep] = [(head_id,token_id,[
|
141 |
else:
|
142 |
interventions[rep] = []
|
143 |
return interventions
|
@@ -176,6 +176,27 @@ def run_intervention(interventions,batch_size,skeleton_model,model,masked_ids_op
|
|
176 |
assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
|
177 |
return probs
|
178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
179 |
if __name__=='__main__':
|
180 |
wide_setup()
|
181 |
load_css('style.css')
|
@@ -217,7 +238,7 @@ if __name__=='__main__':
|
|
217 |
show_instruction('2. Select sites to mask out and click "Confirm"',fontsize=16)
|
218 |
#show_instruction('------------------------------',fontsize=32)
|
219 |
annotate_mask(1,sent_1)
|
220 |
-
show_instruction('------------------------------',fontsize=
|
221 |
annotate_mask(2,sent_2)
|
222 |
if st.button('Confirm',key='confirm_mask'):
|
223 |
st.session_state['page_status'] = 'annotate_options'
|
@@ -230,21 +251,34 @@ if __name__=='__main__':
|
|
230 |
show_instruction('3. Select options and click "Confirm"',fontsize=16)
|
231 |
#show_instruction('------------------------------',fontsize=32)
|
232 |
annotate_options(1,sent_1)
|
233 |
-
show_instruction('------------------------------',fontsize=
|
234 |
annotate_options(2,sent_2)
|
235 |
if st.button('Confirm',key='confirm_option'):
|
236 |
st.session_state['page_status'] = 'analysis'
|
237 |
st.experimental_rerun()
|
238 |
|
239 |
if st.session_state['page_status']=='analysis':
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
sent_1 = st.session_state['sent_1']
|
241 |
sent_2 = st.session_state['sent_2']
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
#show_annotated_sentence(st.session_state['decoded_sent_2'],
|
246 |
-
# option_locs=st.session_state['option_locs_2'],
|
247 |
-
# mask_locs=st.session_state['mask_locs_2'])
|
248 |
|
249 |
option_1_locs, option_2_locs = {}, {}
|
250 |
pron_locs = {}
|
@@ -263,12 +297,6 @@ if __name__=='__main__':
|
|
263 |
pron_locs[f'sent_{sent_id}'],
|
264 |
option_2_locs[f'sent_{sent_id}'],mask_id)
|
265 |
|
266 |
-
#st.write(option_1_locs)
|
267 |
-
#st.write(option_2_locs)
|
268 |
-
#st.write(pron_locs)
|
269 |
-
#for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
|
270 |
-
# st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
|
271 |
-
|
272 |
option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
|
273 |
option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
|
274 |
option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
|
@@ -293,45 +321,31 @@ if __name__=='__main__':
|
|
293 |
assert np.all(compare_1.astype(int)==compare_2.astype(int))
|
294 |
context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation
|
295 |
|
296 |
-
multihead = True
|
297 |
assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
|
298 |
assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
|
299 |
assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
|
300 |
token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs
|
301 |
-
#st.write(token_id_list)
|
302 |
|
303 |
effect_array = []
|
304 |
for token_id in token_id_list:
|
305 |
token_id += 1
|
306 |
effect_list = []
|
307 |
for layer_id in range(num_layers):
|
308 |
-
interventions = [create_interventions(token_id,
|
|
|
309 |
if multihead:
|
310 |
probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
311 |
else:
|
312 |
-
probs = run_intervention(interventions,
|
313 |
effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
|
314 |
effect_list.append(effect)
|
315 |
effect_array.append(effect_list)
|
316 |
effect_array = np.transpose(np.array(effect_array),(1,0,2))
|
317 |
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
fig,ax = plt.subplots()
|
326 |
-
ax.set_box_aspect(num_layers)
|
327 |
-
ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
|
328 |
-
vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
|
329 |
-
ax.set_xticks([])
|
330 |
-
ax.set_xticklabels([])
|
331 |
-
ax.set_yticks([])
|
332 |
-
ax.set_yticklabels([])
|
333 |
-
ax.spines['top'].set_visible(False)
|
334 |
-
ax.spines['bottom'].set_visible(False)
|
335 |
-
ax.spines['right'].set_visible(False)
|
336 |
-
ax.spines['left'].set_visible(False)
|
337 |
-
st.pyplot(fig)
|
|
|
130 |
suffix = '</span></p>'
|
131 |
return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
|
132 |
|
133 |
+
def create_interventions(token_id,interv_types,num_heads,multihead=False,heads=[]):
|
134 |
interventions = {}
|
135 |
for rep in ['lay','qry','key','val']:
|
136 |
if rep in interv_types:
|
137 |
if multihead:
|
138 |
interventions[rep] = [(head_id,token_id,[0,1]) for head_id in range(num_heads)]
|
139 |
else:
|
140 |
+
interventions[rep] = [(head_id,token_id,[i,i+len(heads)]) for i,head_id in enumerate(heads)]
|
141 |
else:
|
142 |
interventions[rep] = []
|
143 |
return interventions
|
|
|
176 |
assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
|
177 |
return probs
|
178 |
|
179 |
+
def show_results(effect_array,masked_sent,token_id_list,num_layers):
|
180 |
+
cols = st.columns(len(masked_sent)-2)
|
181 |
+
for col_id,col in enumerate(cols):
|
182 |
+
with col:
|
183 |
+
st.write(tokenizer.decode([masked_sent[col_id+1]]))
|
184 |
+
if col_id in token_id_list:
|
185 |
+
interv_id = token_id_list.index(col_id)
|
186 |
+
fig,ax = plt.subplots()
|
187 |
+
ax.set_box_aspect(num_layers)
|
188 |
+
ax.imshow(effect_array[:,interv_id:interv_id+1],cmap=sns.color_palette("light:r", as_cmap=True),
|
189 |
+
vmin=effect_array.min(),vmax=effect_array.max())
|
190 |
+
ax.set_xticks([])
|
191 |
+
ax.set_xticklabels([])
|
192 |
+
ax.set_yticks([])
|
193 |
+
ax.set_yticklabels([])
|
194 |
+
ax.spines['top'].set_visible(False)
|
195 |
+
ax.spines['bottom'].set_visible(False)
|
196 |
+
ax.spines['right'].set_visible(False)
|
197 |
+
ax.spines['left'].set_visible(False)
|
198 |
+
st.pyplot(fig)
|
199 |
+
|
200 |
if __name__=='__main__':
|
201 |
wide_setup()
|
202 |
load_css('style.css')
|
|
|
238 |
show_instruction('2. Select sites to mask out and click "Confirm"',fontsize=16)
|
239 |
#show_instruction('------------------------------',fontsize=32)
|
240 |
annotate_mask(1,sent_1)
|
241 |
+
show_instruction('------------------------------',fontsize=24)
|
242 |
annotate_mask(2,sent_2)
|
243 |
if st.button('Confirm',key='confirm_mask'):
|
244 |
st.session_state['page_status'] = 'annotate_options'
|
|
|
251 |
show_instruction('3. Select options and click "Confirm"',fontsize=16)
|
252 |
#show_instruction('------------------------------',fontsize=32)
|
253 |
annotate_options(1,sent_1)
|
254 |
+
show_instruction('------------------------------',fontsize=24)
|
255 |
annotate_options(2,sent_2)
|
256 |
if st.button('Confirm',key='confirm_option'):
|
257 |
st.session_state['page_status'] = 'analysis'
|
258 |
st.experimental_rerun()
|
259 |
|
260 |
if st.session_state['page_status']=='analysis':
|
261 |
+
interv_reps = st.multiselect('Select the types of representations to intervene.',['layer','query','key','value'])
|
262 |
+
rep_dict = {'layer':'lay','query':'qry','key':'key','value':'val'}
|
263 |
+
multihead = not st.checkbox('Perform individual head analysis (takes time)')
|
264 |
+
if not multihead:
|
265 |
+
heads = st.multiselect('Select heads to intervene.',list(np.arange(1,num_heads+1)))
|
266 |
+
else:
|
267 |
+
heads = []
|
268 |
+
|
269 |
+
if st.button('Run',key='run'):
|
270 |
+
st.session_state['reps'] = [rep_dict[rep] for rep in interv_reps]
|
271 |
+
st.session_state['multihead'] = multihead
|
272 |
+
st.session_state['heads'] = heads
|
273 |
+
st.session_state['page_status'] = 'results'
|
274 |
+
st.experimental_rerun()
|
275 |
+
|
276 |
+
if st.session_state['page_status']=='results':
|
277 |
sent_1 = st.session_state['sent_1']
|
278 |
sent_2 = st.session_state['sent_2']
|
279 |
+
multihead = st.session_state['multihead']
|
280 |
+
heads = st.session_state['heads']
|
281 |
+
reps = st.session_state['reps']
|
|
|
|
|
|
|
282 |
|
283 |
option_1_locs, option_2_locs = {}, {}
|
284 |
pron_locs = {}
|
|
|
297 |
pron_locs[f'sent_{sent_id}'],
|
298 |
option_2_locs[f'sent_{sent_id}'],mask_id)
|
299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
|
301 |
option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
|
302 |
option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
|
|
|
321 |
assert np.all(compare_1.astype(int)==compare_2.astype(int))
|
322 |
context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation
|
323 |
|
|
|
324 |
assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
|
325 |
assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
|
326 |
assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
|
327 |
token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs
|
|
|
328 |
|
329 |
effect_array = []
|
330 |
for token_id in token_id_list:
|
331 |
token_id += 1
|
332 |
effect_list = []
|
333 |
for layer_id in range(num_layers):
|
334 |
+
interventions = [create_interventions(token_id,reps,num_heads,multihead,[head_id-1 for head_id in heads])
|
335 |
+
if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
336 |
if multihead:
|
337 |
probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
338 |
else:
|
339 |
+
probs = run_intervention(interventions,len(heads),skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
340 |
effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
|
341 |
effect_list.append(effect)
|
342 |
effect_array.append(effect_list)
|
343 |
effect_array = np.transpose(np.array(effect_array),(1,0,2))
|
344 |
|
345 |
+
if multihead:
|
346 |
+
show_results(effect_array[:,:,0],masked_ids_option_1['sent_1'],token_id_list,num_layers)
|
347 |
+
else:
|
348 |
+
tabs = st.tabs(heads)
|
349 |
+
for i,tab in enumerate(tabs):
|
350 |
+
with tab:
|
351 |
+
show_results(effect_array[:,:,i],masked_ids_option_1['sent_1'],token_id_list,num_layers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|