taka-yamakoshi commited on
Commit
6cb1c36
1 Parent(s): a4dd7f0
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -237,6 +237,7 @@ if __name__=='__main__':
237
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
238
  st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
239
 
 
240
  option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
241
  option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
242
  option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
@@ -245,7 +246,6 @@ if __name__=='__main__':
245
  option_1_tokens = option_1_tokens_1
246
  option_2_tokens = option_2_tokens_1
247
 
248
- if st.session_state['page_status'] == 'finish_debug':
249
  interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
250
  probs_original = run(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
251
  st.write(probs_original)
 
237
  for token_ids in [masked_ids_option_1['sent_1'],masked_ids_option_1['sent_2'],masked_ids_option_2['sent_1'],masked_ids_option_2['sent_2']]:
238
  st.write(' '.join([tokenizer.decode([token]) for token in token_ids]))
239
 
240
+ if st.session_state['page_status'] == 'finish_debug':
241
  option_1_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_1_locs['sent_1'])+1]
242
  option_1_tokens_2 = np.array(input_ids_dict['sent_2'])[np.array(option_1_locs['sent_2'])+1]
243
  option_2_tokens_1 = np.array(input_ids_dict['sent_1'])[np.array(option_2_locs['sent_1'])+1]
 
246
  option_1_tokens = option_1_tokens_1
247
  option_2_tokens = option_2_tokens_1
248
 
 
249
  interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
250
  probs_original = run(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
251
  st.write(probs_original)