taka-yamakoshi committed on
Commit 28c20d6
1 Parent(s): a0471c4
Files changed (1)
  1. app.py +98 -99
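
In brief, this commit appears to do one thing: remove the `with main_area.container():` wrapper around the analysis page and dedent its body by one level (99 lines removed, 98 added; the statements themselves are unchanged). A sketch of the two rendering patterns follows the diff.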
app.py CHANGED
@@ -222,102 +222,101 @@ if __name__=='__main__':
             st.experimental_rerun()
 
     if st.session_state['page_status']=='analysis':
-        with main_area.container():
-            sent_1 = st.session_state['sent_1']
-            sent_2 = st.session_state['sent_2']
-            #show_annotated_sentence(st.session_state['decoded_sent_1'],
-            #                        option_locs=st.session_state['option_locs_1'],
-            #                        mask_locs=st.session_state['mask_locs_1'])
-            #show_annotated_sentence(st.session_state['decoded_sent_2'],
-            #                        option_locs=st.session_state['option_locs_2'],
-            #                        mask_locs=st.session_state['mask_locs_2'])
-
-            option_1_locs, option_2_locs = {}, {}
-            pron_locs = {}
-            input_ids_dict = {}
-            masked_ids_option_1 = {}
-            masked_ids_option_2 = {}
-            for sent_id in [1,2]:
-                option_1_locs[f'sent_{sent_id}'], option_2_locs[f'sent_{sent_id}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
-                pron_locs[f'sent_{sent_id}'] = st.session_state[f'mask_locs_{sent_id}']
-                input_ids_dict[f'sent_{sent_id}'] = tokenizer(st.session_state[f'sent_{sent_id}']).input_ids
-
-                masked_ids_option_1[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
-                                                                  pron_locs[f'sent_{sent_id}'],
-                                                                  option_1_locs[f'sent_{sent_id}'],mask_id)
-                masked_ids_option_2[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
-                                                                  pron_locs[f'sent_{sent_id}'],
-                                                                  option_2_locs[f'sent_{sent_id}'],mask_id)
-
-            #st.write(option_1_locs)
-            #st.write(option_2_locs)
-            #st.write(pron_locs)
-            #for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
-            #    st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
-
-            option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
-            option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
-            option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
-            option_2_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_2_locs['sent_2'])+1]
-            assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
-            option_1_tokens = option_1_tokens_1
-            option_2_tokens = option_2_tokens_1
-
-            interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
-            probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
-            df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
-                                    [probs_original[0,1][0],probs_original[1,1][0]]],
-                              columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
-                              index=['Sentence 1','Sentence 2'])
-            cols = st.columns(3)
-            with cols[1]:
-                show_instruction('Probability of predicting each option in each sentence',fontsize=12)
-                st.dataframe(df.style.highlight_max(axis=1),use_container_width=True)
-
-            compare_1 = np.array(masked_ids_option_1['sent_1'])!=np.array(masked_ids_option_1['sent_2'])
-            compare_2 = np.array(masked_ids_option_2['sent_1'])!=np.array(masked_ids_option_2['sent_2'])
-            assert np.all(compare_1.astype(int)==compare_2.astype(int))
-            context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation
-
-            multihead = True
-            assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
-            assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
-            assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
-            token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs
-            #st.write(token_id_list)
-
-            effect_array = []
-            for token_id in token_id_list:
-                token_id += 1
-                effect_list = []
-                for layer_id in range(num_layers):
-                    interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
-                    if multihead:
-                        probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
-                    else:
-                        probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
-                    effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
-                    effect_list.append(effect)
-                effect_array.append(effect_list)
-            effect_array = np.transpose(np.array(effect_array),(1,0,2))
-
-            cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
-            token_id = 0
-            for col_id,col in enumerate(cols):
-                with col:
-                    st.write(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
-                    if col_id in token_id_list:
-                        interv_id = token_id_list.index(col_id)
-                        fig,ax = plt.subplots()
-                        ax.set_box_aspect(num_layers)
-                        ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
-                                  vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
-                        ax.set_xticks([])
-                        ax.set_xticklabels([])
-                        ax.set_yticks([])
-                        ax.set_yticklabels([])
-                        ax.spines['top'].set_visible(False)
-                        ax.spines['bottom'].set_visible(False)
-                        ax.spines['right'].set_visible(False)
-                        ax.spines['left'].set_visible(False)
-                        st.pyplot(fig)
+        sent_1 = st.session_state['sent_1']
+        sent_2 = st.session_state['sent_2']
+        #show_annotated_sentence(st.session_state['decoded_sent_1'],
+        #                        option_locs=st.session_state['option_locs_1'],
+        #                        mask_locs=st.session_state['mask_locs_1'])
+        #show_annotated_sentence(st.session_state['decoded_sent_2'],
+        #                        option_locs=st.session_state['option_locs_2'],
+        #                        mask_locs=st.session_state['mask_locs_2'])
+
+        option_1_locs, option_2_locs = {}, {}
+        pron_locs = {}
+        input_ids_dict = {}
+        masked_ids_option_1 = {}
+        masked_ids_option_2 = {}
+        for sent_id in [1,2]:
+            option_1_locs[f'sent_{sent_id}'], option_2_locs[f'sent_{sent_id}'] = separate_options(st.session_state[f'option_locs_{sent_id}'])
+            pron_locs[f'sent_{sent_id}'] = st.session_state[f'mask_locs_{sent_id}']
+            input_ids_dict[f'sent_{sent_id}'] = tokenizer(st.session_state[f'sent_{sent_id}']).input_ids
+
+            masked_ids_option_1[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
+                                                              pron_locs[f'sent_{sent_id}'],
+                                                              option_1_locs[f'sent_{sent_id}'],mask_id)
+            masked_ids_option_2[f'sent_{sent_id}'] = mask_out(input_ids_dict[f'sent_{sent_id}'],
+                                                              pron_locs[f'sent_{sent_id}'],
+                                                              option_2_locs[f'sent_{sent_id}'],mask_id)
+
+        #st.write(option_1_locs)
+        #st.write(option_2_locs)
+        #st.write(pron_locs)
+        #for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
+        #    st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
+
+        option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
+        option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
+        option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
+        option_2_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_2_locs['sent_2'])+1]
+        assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
+        option_1_tokens = option_1_tokens_1
+        option_2_tokens = option_2_tokens_1
+
+        interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
+        probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+        df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
+                                [probs_original[0,1][0],probs_original[1,1][0]]],
+                          columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
+                          index=['Sentence 1','Sentence 2'])
+        cols = st.columns(3)
+        with cols[1]:
+            show_instruction('Probability of predicting each option in each sentence',fontsize=12)
+            st.dataframe(df.style.highlight_max(axis=1),use_container_width=True)
+
+        compare_1 = np.array(masked_ids_option_1['sent_1'])!=np.array(masked_ids_option_1['sent_2'])
+        compare_2 = np.array(masked_ids_option_2['sent_1'])!=np.array(masked_ids_option_2['sent_2'])
+        assert np.all(compare_1.astype(int)==compare_2.astype(int))
+        context_locs = list(np.arange(len(masked_ids_option_1['sent_1']))[compare_1]-1) # match the indexing for annotation
+
+        multihead = True
+        assert np.all(np.array(pron_locs['sent_1'])==np.array(pron_locs['sent_2']))
+        assert np.all(np.array(option_1_locs['sent_1'])==np.array(option_1_locs['sent_2']))
+        assert np.all(np.array(option_2_locs['sent_1'])==np.array(option_2_locs['sent_2']))
+        token_id_list = pron_locs['sent_1'] + option_1_locs['sent_1'] + option_2_locs['sent_1'] + context_locs
+        #st.write(token_id_list)
+
+        effect_array = []
+        for token_id in token_id_list:
+            token_id += 1
+            effect_list = []
+            for layer_id in range(num_layers):
+                interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
+                if multihead:
+                    probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+                else:
+                    probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
+                effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
+                effect_list.append(effect)
+            effect_array.append(effect_list)
+        effect_array = np.transpose(np.array(effect_array),(1,0,2))
+
+        cols = st.columns(len(masked_ids_option_1['sent_1'])-2)
+        token_id = 0
+        for col_id,col in enumerate(cols):
+            with col:
+                st.write(tokenizer.decode([masked_ids_option_1['sent_1'][col_id+1]]))
+                if col_id in token_id_list:
+                    interv_id = token_id_list.index(col_id)
+                    fig,ax = plt.subplots()
+                    ax.set_box_aspect(num_layers)
+                    ax.imshow(effect_array[:,interv_id:interv_id+1,0],cmap=sns.color_palette("light:r", as_cmap=True),
+                              vmin=effect_array[:,:,0].min(),vmax=effect_array[:,:,0].max())
+                    ax.set_xticks([])
+                    ax.set_xticklabels([])
+                    ax.set_yticks([])
+                    ax.set_yticklabels([])
+                    ax.spines['top'].set_visible(False)
+                    ax.spines['bottom'].set_visible(False)
+                    ax.spines['right'].set_visible(False)
+                    ax.spines['left'].set_visible(False)
+                    st.pyplot(fig)
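
For readers skimming the diff: the only structural change is dropping the `with main_area.container():` wrapper, so the analysis widgets now render directly in the page flow rather than inside a placeholder's container. A minimal sketch of the two patterns, assuming `main_area` is an `st.empty()` placeholder created elsewhere in app.py (that line is outside this hunk, so this is an illustration, not the app's actual setup code):

    import streamlit as st

    # Assumption: app.py creates a placeholder along these lines.
    main_area = st.empty()

    # Old pattern (removed in this commit): widgets render inside the
    # placeholder's container, so a later call such as main_area.empty()
    # clears them all as a group when the page state changes.
    with main_area.container():
        st.write('rendered inside the placeholder')

    # New pattern: widgets render directly in the main page flow, one
    # indentation level shallower and independent of the placeholder.
    st.write('rendered in the main page flow')

The trade-off is that widgets written directly to the page are no longer tied to the placeholder's lifecycle; here the app relies on `st.experimental_rerun()` to redraw the page on state changes instead of clearing the container.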