taka-yamakoshi commited on
Commit
9096322
1 Parent(s): 7397208
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -211,10 +211,19 @@ if __name__=='__main__':
211
 
212
  st.write(option_1_locs)
213
  st.write(option_2_locs)
 
214
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
215
  st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
216
 
217
  if st.session_state['page_status'] == 'finish_debug':
 
 
 
 
 
 
 
 
218
  for layer_id in range(num_layers):
219
  interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
220
  for masked_ids in [masked_ids_option_1, masked_ids_option_2]:
 
211
 
212
  st.write(option_1_locs)
213
  st.write(option_2_locs)
214
+ st.write(pron_locs)
215
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
216
  st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
217
 
218
  if st.session_state['page_status'] == 'finish_debug':
219
+ option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])]
220
+ option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])]
221
+ option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])]
222
+ option_2_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_2_locs['sent_2'])]
223
+ assert np.all(option_1_tokens_1==option_1_tokens_2) and np.all(option_2_tokens_1==option_2_tokens_2)
224
+ option_1_tokens = option_1_tokens_1
225
+ option_2_tokens = option_2_tokens_1
226
+
227
  for layer_id in range(num_layers):
228
  interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
229
  for masked_ids in [masked_ids_option_1, masked_ids_option_2]: