taka-yamakoshi committed on
Commit
77d2a77
1 Parent(s): 3052d18
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -117,15 +117,13 @@ def show_instruction(sent,fontsize=20):
117
  suffix = '</span></p>'
118
  return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
119
 
120
def create_interventions(token_id,interv_type,num_layers,num_heads):
    """Build a per-layer attention-intervention spec.

    For every layer id in ``range(num_layers)`` produce a dict mapping each
    selected representation name to one entry per head of the form
    ``(head_id, token_id, [head_id, head_id + num_heads])``.

    Args:
        token_id: token position the intervention targets.
        interv_type: 'all' to cover every representation
            ('lay', 'qry', 'key', 'val'), or a single representation name.
        num_layers: number of layers to generate specs for.
        num_heads: number of attention heads per layer.

    Returns:
        dict keyed by layer id; each value maps representation name ->
        list of per-head intervention tuples.
    """
    # Decide once which representations are covered, then reuse per layer.
    reps = ['lay','qry','key','val'] if interv_type == 'all' else [interv_type]
    interventions = {}
    for layer_id in range(num_layers):
        interventions[layer_id] = {
            rep: [(head_id, token_id, [head_id, head_id + num_heads])
                  for head_id in range(num_heads)]
            for rep in reps
        }
    return interventions
130
 
131
  def separate_options(option_locs):
@@ -195,7 +193,7 @@ if __name__=='__main__':
195
  mask_locs=st.session_state['mask_locs_2'])
196
 
197
  option_1_locs, option_2_locs = {}, {}
198
- pron_id = {}
199
  input_ids_dict = {}
200
  masked_ids_option_1 = {}
201
  masked_ids_option_2 = {}
@@ -215,14 +213,15 @@ if __name__=='__main__':
215
  st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))
216
 
217
  if st.session_state['page_status'] == 'finish_debug':
218
- interventions = create_interventions(16,'all',num_layers=num_layers,num_heads=num_heads)
219
- for masked_ids in [masked_ids_option_1, masked_ids_option_2]:
220
- input_ids = torch.tensor([
221
- *[masked_ids['sent_1'] for _ in range(num_heads)],
222
- *[masked_ids['sent_2'] for _ in range(num_heads)]
223
- ])
224
- outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
225
- logprobs = F.log_softmax(outputs['logits'], dim = -1)
 
226
 
227
 
228
  preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
 
117
  suffix = '</span></p>'
118
  return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
119
 
120
def create_interventions(token_id,interv_types,num_heads):
    """Build a single-layer attention-intervention spec.

    Every representation in ('lay', 'qry', 'key', 'val') appears as a key.
    Representations listed in ``interv_types`` get one entry per head of the
    form ``(head_id, token_id, [head_id, head_id + num_heads])``; the others
    map to an empty list (i.e. no intervention for that representation).

    Args:
        token_id: token position the intervention targets.
        interv_types: collection of representation names to intervene on
            (subset of 'lay', 'qry', 'key', 'val').
        num_heads: number of attention heads in the layer.

    Returns:
        dict mapping each representation name to its (possibly empty)
        list of per-head intervention tuples.
    """
    # Dict comprehension replaces the manual loop-and-assign; the conditional
    # expression keeps the explicit empty-list placeholder for unselected reps.
    return {
        rep: ([(head_id, token_id, [head_id, head_id + num_heads])
               for head_id in range(num_heads)]
              if rep in interv_types else [])
        for rep in ['lay', 'qry', 'key', 'val']
    }
128
 
129
  def separate_options(option_locs):
 
193
  mask_locs=st.session_state['mask_locs_2'])
194
 
195
  option_1_locs, option_2_locs = {}, {}
196
+ pron_locs = {}
197
  input_ids_dict = {}
198
  masked_ids_option_1 = {}
199
  masked_ids_option_2 = {}
 
213
  st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))
214
 
215
  if st.session_state['page_status'] == 'finish_debug':
216
+ for layer_id in range(num_layers):
217
+ interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
218
+ for masked_ids in [masked_ids_option_1, masked_ids_option_2]:
219
+ input_ids = torch.tensor([
220
+ *[masked_ids['sent_1'] for _ in range(num_heads)],
221
+ *[masked_ids['sent_2'] for _ in range(num_heads)]
222
+ ])
223
+ outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
224
+ logprobs = F.log_softmax(outputs['logits'], dim = -1)
225
 
226
 
227
  preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]