taka-yamakoshi commited on
Commit
0b05f1f
1 Parent(s): 271977c
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -144,7 +144,7 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
144
 
145
  def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
146
  probs = []
147
- for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]]):
148
  input_ids = torch.tensor([
149
  *[masked_ids['sent_1'] for _ in range(batch_size)],
150
  *[masked_ids['sent_2'] for _ in range(batch_size)]
@@ -253,4 +253,3 @@ if __name__=='__main__':
253
  for layer_id in range(num_layers):
254
  interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
255
  probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
256
-
 
144
 
145
  def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
146
  probs = []
147
+ for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
148
  input_ids = torch.tensor([
149
  *[masked_ids['sent_1'] for _ in range(batch_size)],
150
  *[masked_ids['sent_2'] for _ in range(batch_size)]
 
253
  for layer_id in range(num_layers):
254
  interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
255
  probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)