Start-GPT committed on
Commit 142c9a5
1 parent: 9248a10

Create mask_att.py

Files changed (1)
  1. server/utils/mask_att.py +80 -0
server/utils/mask_att.py ADDED
@@ -0,0 +1,80 @@
import numpy as np

SEP = '[SEP]'
CLS = '[CLS]'
MASK = '[MASK]'

def drop_bad_inds(arr, left_drop, right_drop):
    """Given the 4d attention array of shape (n_layer, n_head, n_left_text, n_right_text),
    return that array with the indices flagged in `left_drop` removed from the
    n_left_text axis and those flagged in `right_drop` removed from the n_right_text axis.
    """
    arr = arr[:, :, ~left_drop, :]

    # Keys and queries don't match in the final dimension
    if arr.shape[-1] == len(right_drop):
        arr = arr[:, :, :, ~right_drop]

    return arr
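
# A hypothetical example of drop_bad_inds (shapes chosen for illustration only):
#
#   att = np.zeros((12, 12, 5, 5))
#   left_drop = np.array([True, False, False, False, True])   # marks [CLS]/[SEP]
#   right_drop = np.array([True, False, False, False, True])
#   drop_bad_inds(att, left_drop, right_drop).shape  # -> (12, 12, 3, 3)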

def strip_attention(attention):
    """Given an attention output of the BERT model,
    return the same object without the CLS and SEP token weightings.
    NOTE: Not currently fixing key and query
    """
    attention_out = {}

    # Iterate through sentence combinations
    # Need queries, keys, att, left_text, right_text
    for k, v in attention.items():
        stripped_resp = {}

        left_tokens = np.array(v['left_text'])
        right_tokens = np.array(v['right_text'])
        att = np.array(v['att'])
        # key = np.array(v['keys'])
        # quer = np.array(v['queries'])

        # Boolean masks marking the special tokens to remove
        left_drop = (left_tokens == CLS) | (left_tokens == SEP)
        right_drop = (right_tokens == CLS) | (right_tokens == SEP)

        att_out = drop_bad_inds(att, left_drop, right_drop)
        # key_out = drop_bad_inds(key, left_drop, right_drop)
        # quer_out = drop_bad_inds(quer, left_drop, right_drop)
        left_out = left_tokens[~left_drop]
        right_out = right_tokens[~right_drop]

        # assert att_out.shape[:3] == key_out.shape[:3] == quer_out.shape[:3]
        assert att_out.shape[2] == len(left_out)
        assert att_out.shape[3] == len(right_out)

        stripped_resp['att'] = att_out.tolist()
        stripped_resp['keys'] = v['keys']
        stripped_resp['queries'] = v['queries']
        stripped_resp['left_text'] = left_out.tolist()
        stripped_resp['right_text'] = right_out.tolist()

        attention_out[k] = stripped_resp

    return attention_out
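
# A minimal sketch of the expected payload (sample values are hypothetical):
#
#   attention = {
#       'aa': {
#           'att': np.random.rand(12, 12, 4, 4).tolist(),
#           'keys': [], 'queries': [],
#           'left_text': ['[CLS]', 'the', 'cat', '[SEP]'],
#           'right_text': ['[CLS]', 'the', 'cat', '[SEP]'],
#       }
#   }
#   stripped = strip_attention(attention)
#   stripped['aa']['left_text']  # -> ['the', 'cat']
#   np.array(stripped['aa']['att']).shape  # -> (12, 12, 2, 2)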

def mask_attention(deets, maskA, maskB):
    """`deets` has the form:
    (tokens_a, tokens_b, query_tensor.data.numpy(), key_tensor.data.numpy(), attn_tensor.data.numpy())
    Take the first two entries and mask them according to `maskA` and `maskB`,
    which are lists of indices to replace with the [MASK] token.
    """
    deets = list(deets)  # deets may arrive as a tuple; make it mutable

    # Object dtype keeps '[MASK]' from being truncated by a fixed-width string dtype
    tokens_a = np.array(deets[0], dtype=object)
    tokens_a[maskA] = MASK

    tokens_b = np.array(deets[1], dtype=object)
    tokens_b[maskB] = MASK  # the original read `maskb`, a NameError

    deets[0] = tokens_a.tolist()
    deets[1] = tokens_b.tolist()

    return deets
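
# A quick usage sketch (token lists and indices are hypothetical; q, k, att
# stand in for the query/key/attention tensor payloads):
#
#   deets = (['[CLS]', 'the', 'cat', '[SEP]'],
#            ['[CLS]', 'a', 'dog', '[SEP]'], q, k, att)
#   masked = mask_attention(deets, maskA=[1], maskB=[2])
#   masked[0]  # -> ['[CLS]', '[MASK]', 'cat', '[SEP]']
#   masked[1]  # -> ['[CLS]', 'a', '[MASK]', '[SEP]']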