taka-yamakoshi commited on
Commit
6800334
1 Parent(s): 7c56f41

bring back run

Browse files
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -142,7 +142,7 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
142
  # note annotations are shifted by 1 because special tokens were omitted
143
  return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
144
 
145
- '''
146
  def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
147
  probs = []
148
  for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
@@ -159,7 +159,6 @@ def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_i
159
  probs = np.array(probs)
160
  assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
161
  return probs
162
- '''
163
 
164
  if __name__=='__main__':
165
  wide_setup()
 
142
  # note annotations are shifted by 1 because special tokens were omitted
143
  return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
144
 
145
+
146
  def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
147
  probs = []
148
  for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
 
159
  probs = np.array(probs)
160
  assert probs.shape[0]==2 and probs.shape[1]==2 and probs.shape[2]==batch_size
161
  return probs
 
162
 
163
  if __name__=='__main__':
164
  wide_setup()