File size: 2,979 Bytes
ebfe870 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import numpy as np
import torch
import torch.nn.functional as F
import math
def _swap_head_slices(rep, head_dim, swaps):
    """Return a copy of ``rep`` with per-head column slices exchanged between batch rows.

    ``rep`` is indexed as ``rep[batch, seq, features]``. Each entry of ``swaps``
    is ``(head_id, pos, swap_ids)``: inside the feature columns
    ``[head_dim * head_id, head_dim * (head_id + 1))`` and at sequence
    index/indices ``pos``, the values of batch rows ``swap_ids[0]`` and
    ``swap_ids[1]`` are exchanged. Every value is read from the untouched
    ``rep``, so one swap never reads another swap's output; where two entries
    write the same element, the later entry wins.
    """
    swapped = rep.clone()
    for head_id, pos, swap_ids in swaps:
        cols = slice(head_dim * head_id, head_dim * (head_id + 1))
        row_a, row_b = swap_ids[0], swap_ids[1]
        # ``swapped[row, :, cols]`` is a view, so assigning into it writes ``swapped``.
        swapped[row_a, :, cols][pos, :] = rep[row_b, :, cols][pos, :]
        swapped[row_b, :, cols][pos, :] = rep[row_a, :, cols][pos, :]
    return swapped


@torch.no_grad()
def SkeletonBertLayer(layer_id, layer, hidden, interventions, attention_mask=None):
    """Recompute one BERT encoder layer by hand, swapping activations between batch rows.

    The input hidden states and their query/key/value projections can have
    per-head slices exchanged between two batch rows before the layer's
    attention and feed-forward sub-blocks are re-applied. No dropout is
    applied, so the result matches the layer's eval-mode computation.

    Args:
        layer_id: Key into ``interventions`` selecting this layer's swaps.
        layer: Encoder layer exposing ``attention.self`` (with
            ``num_attention_heads``, ``attention_head_size``, ``query``,
            ``key``, ``value``), ``attention.output.{dense,LayerNorm}``,
            ``intermediate`` and ``output.{dense,LayerNorm}`` -- presumably a
            Hugging Face ``BertLayer``; confirm against callers.
        hidden: Input hidden states, indexed as (batch, seq, hidden_size).
        interventions: ``interventions[layer_id][rep_type]`` for each
            ``rep_type`` in ``'lay'``, ``'qry'``, ``'key'``, ``'val'`` is an
            iterable of ``(head_id, pos, swap_ids)`` swap specifications.
        attention_mask: Optional (batch, seq) keep-mask (1 = attend,
            0 = ignore, e.g. padding). ``None`` (default) attends to every
            position, which is the original behaviour.

    Returns:
        The layer's output hidden states, same shape as ``hidden``.

    Raises:
        AssertionError: if the head layout does not match ``hidden``'s width.
        ValueError: if ``attention_mask`` is given but is not 2-D.
    """
    attention_layer = layer.attention.self
    num_heads = attention_layer.num_attention_heads
    head_dim = attention_layer.attention_head_size
    assert num_heads * head_dim == hidden.shape[2]

    # NOTE: Q/K/V are projected from the *unswapped* input, so a 'lay' swap
    # only affects the attention residual connection further down.
    reps = {
        'lay': hidden,
        'qry': attention_layer.query(hidden),
        'key': attention_layer.key(hidden),
        'val': attention_layer.value(hidden),
    }
    for rep_type in ('qry', 'key', 'val'):
        assert reps[rep_type].shape == hidden.shape

    # Apply the swaps; each helper call returns a fresh tensor, so no extra copies are needed.
    layer_interventions = interventions[layer_id]
    for rep_type in ('lay', 'qry', 'key', 'val'):
        reps[rep_type] = _swap_head_slices(reps[rep_type], head_dim, layer_interventions[rep_type])
    hidden, qry, key, val = reps['lay'], reps['qry'], reps['key'], reps['val']

    def split_heads(t):
        # (batch, seq, num_heads * head_dim) -> (batch, num_heads, seq, head_dim)
        return t.view(*t.shape[:-1], num_heads, head_dim).permute(0, 2, 1, 3)

    split_qry, split_key, split_val = split_heads(qry), split_heads(key), split_heads(val)

    # Scaled dot-product attention scores: (batch, num_heads, seq_q, seq_k).
    scores = split_qry @ split_key.transpose(-1, -2) / math.sqrt(head_dim)
    if attention_mask is not None:
        if attention_mask.dim() != 2:
            raise ValueError("attention_mask must have shape (batch, seq)")
        keep = attention_mask.to(device=scores.device, dtype=scores.dtype)[:, None, None, :]
        # Masked keys get the most negative finite value, so softmax gives them zero weight.
        scores = scores + (1.0 - keep) * torch.finfo(scores.dtype).min
    attn_mat = F.softmax(scores, dim=-1)

    # Merge heads back to (batch, seq, hidden_size).
    z_rep = (attn_mat @ split_val).permute(0, 2, 1, 3).reshape(*hidden.size())

    attn_output = layer.attention.output
    hidden_post_attn = attn_output.LayerNorm(attn_output.dense(z_rep) + hidden)  # residual + layer norm
    hidden_post_interm = layer.intermediate(hidden_post_attn)  # feed-forward expansion
    return layer.output.LayerNorm(layer.output.dense(hidden_post_interm) + hidden_post_attn)  # residual + layer norm
def SkeletonBertForMaskedLM(model, input_ids, interventions):
    """Run a masked-LM BERT model through ``SkeletonBertLayer`` with swap interventions.

    The embeddings of ``model.bert`` are computed first. Then each of the
    ``model.config.num_hidden_layers`` encoder layers is re-applied through
    ``SkeletonBertLayer`` so that ``interventions`` can act on it. The final
    hidden states go through ``model.cls`` to produce logits. Everything runs
    under ``torch.no_grad()``.

    Args:
        model: Model exposing ``bert`` (with ``embeddings`` and
            ``encoder.layer``), ``cls`` and ``config.num_hidden_layers``.
        input_ids: Token ids passed straight to ``model.bert.embeddings``.
        interventions: Per-layer swap specifications, forwarded unchanged to
            ``SkeletonBertLayer``.

    Returns:
        dict with ``'logits'`` (output of ``model.cls``) and
        ``'hidden_states'`` (a list: the embedding output followed by each
        layer's output, so ``num_hidden_layers + 1`` entries).
    """
    bert = model.bert
    mlm_head = model.cls
    states_per_layer = []
    with torch.no_grad():
        current = bert.embeddings(input_ids)
        states_per_layer.append(current)
        for idx in range(model.config.num_hidden_layers):
            encoder_layer = bert.encoder.layer[idx]
            current = SkeletonBertLayer(idx, encoder_layer, current, interventions)
            states_per_layer.append(current)
        prediction_scores = mlm_head(current)
    result = {'logits': prediction_scores, 'hidden_states': states_per_layer}
    return result
|