taka-yamakoshi commited on
Commit
28525ba
1 Parent(s): 6b8bbf9

change name of func

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -142,7 +142,7 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
142
  # note annotations are shifted by 1 because special tokens were omitted
143
  return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
144
 
145
- def run(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
146
  probs = []
147
  for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]]):
148
  input_ids = torch.tensor([
@@ -247,9 +247,9 @@ if __name__=='__main__':
247
  option_2_tokens = option_2_tokens_1
248
 
249
  interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
250
- probs_original = run(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
251
  st.write(probs_original)
252
 
253
  for layer_id in range(num_layers):
254
  interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
255
- probs = run(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
 
142
  # note annotations are shifted by 1 because special tokens were omitted
143
  return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
144
 
145
+ def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
146
  probs = []
147
  for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]]):
148
  input_ids = torch.tensor([
 
247
  option_2_tokens = option_2_tokens_1
248
 
249
  interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
250
+ probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
251
  st.write(probs_original)
252
 
253
  for layer_id in range(num_layers):
254
  interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
255
+ probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)