taka-yamakoshi
commited on
Commit
•
0b05f1f
1
Parent(s):
271977c
debug
Browse files
app.py
CHANGED
@@ -144,7 +144,7 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
|
|
144 |
|
145 |
def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
|
146 |
probs = []
|
147 |
-
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]
|
148 |
input_ids = torch.tensor([
|
149 |
*[masked_ids['sent_1'] for _ in range(batch_size)],
|
150 |
*[masked_ids['sent_2'] for _ in range(batch_size)]
|
@@ -253,4 +253,3 @@ if __name__=='__main__':
|
|
253 |
for layer_id in range(num_layers):
|
254 |
interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
255 |
probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
256 |
-
|
|
|
144 |
|
145 |
def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
|
146 |
probs = []
|
147 |
+
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
|
148 |
input_ids = torch.tensor([
|
149 |
*[masked_ids['sent_1'] for _ in range(batch_size)],
|
150 |
*[masked_ids['sent_2'] for _ in range(batch_size)]
|
|
|
253 |
for layer_id in range(num_layers):
|
254 |
interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
255 |
probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
|