taka-yamakoshi committed on
Commit
b218eb4
1 Parent(s): e87e116
Files changed (1) hide show
  1. skeleton_modeling_albert.py +7 -2
skeleton_modeling_albert.py CHANGED
@@ -10,11 +10,16 @@ def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
10
  attention_layer = layer.attention
11
  num_heads = attention_layer.num_attention_heads
12
  head_dim = attention_layer.attention_head_size
 
13
 
14
  qry = attention_layer.query(hidden)
15
  key = attention_layer.key(hidden)
16
  val = attention_layer.value(hidden)
17
 
 
 
 
 
18
  # swap representations
19
  interv_layer = interventions.pop(layer_id,None)
20
  if interv_layer is not None:
@@ -29,8 +34,8 @@ def SkeletonAlbertLayer(layer_id,layer,hidden,interventions):
29
  if interv_rep is not None:
30
  new_state = reps[rep_type].clone()
31
  for head_id, pos, swap_ids in interv_rep:
32
- new_state[swap_ids[0],pos,head_id] = reps[rep_name][swap_ids[1],pos,head_id]
33
- new_state[swap_ids[1],pos,head_id] = reps[rep_name][swap_ids[0],pos,head_id]
34
  reps[rep_type] = new_state.clone()
35
 
36
  hidden = reps['lay'].clone()
 
10
  attention_layer = layer.attention
11
  num_heads = attention_layer.num_attention_heads
12
  head_dim = attention_layer.attention_head_size
13
+ assert num_heads*head_dim == hidden.shape[2]
14
 
15
  qry = attention_layer.query(hidden)
16
  key = attention_layer.key(hidden)
17
  val = attention_layer.value(hidden)
18
 
19
+ assert qry.shape == hidden.shape
20
+ assert key.shape == hidden.shape
21
+ assert val.shape == hidden.shape
22
+
23
  # swap representations
24
  interv_layer = interventions.pop(layer_id,None)
25
  if interv_layer is not None:
 
34
  if interv_rep is not None:
35
  new_state = reps[rep_type].clone()
36
  for head_id, pos, swap_ids in interv_rep:
37
+ new_state[swap_ids[0],pos,head_dim*head_id:head_dim*(head_id+1)] = reps[rep_type][swap_ids[1],pos,head_dim*head_id:head_dim*(head_id+1)]
38
+ new_state[swap_ids[1],pos,head_dim*head_id:head_dim*(head_id+1)] = reps[rep_type][swap_ids[0],pos,head_dim*head_id:head_dim*(head_id+1)]
39
  reps[rep_type] = new_state.clone()
40
 
41
  hidden = reps['lay'].clone()