taka-yamakoshi committed
Commit c1ed878
1 Parent(s): 9149baa

impl interv

Files changed (1)
  1. custom_modeling_albert_flax.py +22 -0
custom_modeling_albert_flax.py CHANGED
@@ -1,4 +1,5 @@
 from typing import Callable, Optional, Tuple
+from copy import deepcopy
 
 import numpy as np
 
@@ -88,6 +89,27 @@ class CustomFlaxAlbertSelfAttention(nn.Module):
             hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
         )
 
+        reps = {
+            'lay': hidden_states,
+            'qry': query_states,
+            'key': key_states,
+            'val': value_states,
+            }
+        if layer_id in interv_dict:
+            interv = interv_dict[layer_id]
+            for rep_name in ['lay','qry','key','val']:
+                if rep_name in interv:
+                    new_state = deepcopy(reps[rep_name])
+                    for head_id, pos, swap_ids in interv[rep_name]:
+                        new_state[swap_ids[0],pos,head_id] = reps[rep_name][swap_ids[1],pos,head_id]
+                        new_state[swap_ids[1],pos,head_id] = reps[rep_name][swap_ids[0],pos,head_id]
+                    reps[rep_name] = deepcopy(new_state)
+
+        hidden_states = deepcopy(reps['lay'])
+        query_states = deepcopy(reps['qry'])
+        key_states = deepcopy(reps['key'])
+        value_states = deepcopy(reps['val'])
+
         # Convert the boolean attention mask to an attention bias.
         if attention_mask is not None:
             # attention mask in the form of attention bias
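
The added hunk implements the intervention named in the commit message: for each layer listed in interv_dict, and for each of the layer input ('lay'), query ('qry'), key ('key'), and value ('val') states present in that layer's entry, it exchanges the slice at (batch index, position, head) between the two batch items in swap_ids. Below is a minimal, self-contained sketch of the same swap, assuming the states are jax.numpy arrays of shape (batch, seq_len, num_heads, head_dim) as produced by the reshape above. Note that jax.numpy arrays are immutable, so the in-place item assignments in the diff only work on plain NumPy arrays; the functional .at[...].set() idiom shown here is the JAX equivalent. The helper name apply_interventions and the example values are hypothetical, not part of the commit.

import jax.numpy as jnp

def apply_interventions(reps, interv):
    """Swap per-head slices between two batch items, functionally.

    `reps` maps rep names ('lay', 'qry', 'key', 'val') to arrays; `interv`
    maps rep names to lists of (head_id, pos, (i, j)) triples, mirroring
    one layer's entry in the commit's interv_dict.
    """
    reps = dict(reps)  # copy the dict; the arrays themselves are never mutated
    for rep_name, swaps in interv.items():
        state = reps[rep_name]
        for head_id, pos, (i, j) in swaps:
            a = state[i, pos, head_id]  # read both slices before writing,
            b = state[j, pos, head_id]  # as the diff does via new_state/reps
            state = state.at[i, pos, head_id].set(b)
            state = state.at[j, pos, head_id].set(a)
        reps[rep_name] = state
    return reps

# Example (hypothetical values): swap head 3's query vectors at
# position 5 between batch items 0 and 1.
states = {name: jnp.zeros((2, 10, 12, 64)) for name in ('lay', 'qry', 'key', 'val')}
states = apply_interventions(states, {'qry': [(3, 5, (0, 1))]})

Because .at[...].set() returns a new array rather than mutating, no deepcopy is needed in this variant; reading both slices before the first write preserves the swap semantics that the diff obtains by writing into new_state while reading from the untouched reps[rep_name].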