asigalov61 commited on
Commit
7583533
·
verified ·
1 Parent(s): 6fade3e

Upload x_transformer_1_23_2.py

Browse files
Files changed (1) hide show
  1. x_transformer_1_23_2.py +2465 -0
x_transformer_1_23_2.py ADDED
@@ -0,0 +1,2465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #===================================================================================================================
2
+ #
3
+ # X Trasformer Module
4
+ #
5
+ # Partial x-transformers code With useful modifications
6
+ #
7
+ # Version 1.0
8
+ #
9
+ # Original source code courtesy of lucidrains
10
+ # https://github.com/lucidrains/x-transformers
11
+ #
12
+ # Original source code retrieved on 10/10/2023
13
+ #
14
+ # Project Los Angeles
15
+ # Tegridy Code 2023
16
+
17
+ #===================================================================================================================
18
+
19
+ # Critical dependencies
20
+ #
21
+ # !pip install torch
22
+ # !pip install einops
23
+
24
+ #===================================================================================================================
25
+
26
+ from functools import partial
27
+ from typing import Optional, Tuple
28
+
29
+ import torch
30
+ from torch import nn, einsum, Tensor
31
+ import torch.nn.functional as F
32
+ from torch.nn.attention import SDPBackend, sdpa_kernel
33
+
34
+ from collections import namedtuple
35
+ from functools import wraps
36
+ from packaging import version
37
+ from dataclasses import dataclass
38
+
39
+ from einops import rearrange, repeat
40
+
41
+ # constants
42
+
43
+ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
44
+
45
+ @dataclass
46
+ class Intermediates:
47
+ qk_similarities: Optional[Tensor] = None
48
+ pre_softmax_attn: Optional[Tensor] = None
49
+ post_softmax_attn: Optional[Tensor] = None
50
+ cached_kv: Optional[Tuple[Tensor, Tensor]] = None
51
+
52
+ def to_tuple(self):
53
+ return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
54
+
55
+ # helpers
56
+
57
+ def exists(val):
58
+ return val is not None
59
+
60
+ def default(val, d):
61
+ return val if exists(val) else d
62
+
63
+ def compact(arr):
64
+ return [*filter(exists, arr)]
65
+
66
+ def once(fn):
67
+ called = False
68
+ @wraps(fn)
69
+ def inner(x):
70
+ nonlocal called
71
+ if called:
72
+ return
73
+ called = True
74
+ return fn(x)
75
+ return inner
76
+
77
+ print_once = once(print)
78
+
79
+ # functions for creating causal mask
80
+ # need a special one for onnx cpu (no support for .triu)
81
+
82
+ def create_causal_mask(i, j, device):
83
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
84
+
85
+ def onnx_create_causal_mask(i, j, device):
86
+ r = torch.arange(i, device = device)
87
+ causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
88
+ causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
89
+ return causal_mask
90
+
91
+ # main class
92
+
93
+ class Attend(nn.Module):
94
+ def __init__(
95
+ self,
96
+ *,
97
+ dropout = 0.,
98
+ causal = False,
99
+ heads = None,
100
+ talking_heads = False,
101
+ sparse_topk = None,
102
+ scale = None,
103
+ qk_norm = False,
104
+ flash = False,
105
+ add_zero_kv = False,
106
+ onnxable = False
107
+ ):
108
+ super().__init__()
109
+ self.scale = scale
110
+ self.qk_norm = qk_norm
111
+
112
+ self.causal = causal
113
+ self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
114
+
115
+ self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
116
+
117
+ self.dropout = dropout
118
+ self.attn_dropout = nn.Dropout(dropout)
119
+
120
+ # talking heads
121
+
122
+ assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
123
+
124
+ self.talking_heads = talking_heads
125
+ if talking_heads:
126
+ self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
127
+ self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
128
+
129
+ # sparse topk
130
+
131
+ assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
132
+ self.sparse_topk = sparse_topk
133
+
134
+ # add a key / value token composed of zeros
135
+ # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
136
+
137
+ self.add_zero_kv = add_zero_kv
138
+
139
+ # flash attention
140
+
141
+ self.flash = flash
142
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
143
+
144
+ # determine efficient attention configs for cuda and cpu
145
+
146
+ self.cpu_config = EfficientAttentionConfig(True, True, True)
147
+ self.cuda_config = None
148
+
149
+ if not torch.cuda.is_available() or not flash:
150
+ return
151
+
152
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
153
+
154
+ major, minor = device_properties.major, device_properties.minor
155
+
156
+ if (major, minor) == (8, 0):
157
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
158
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
159
+ elif (major, minor) == (9, 0):
160
+ print_once('H100 GPU detected, using flash attention')
161
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
162
+ else:
163
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
164
+ self.cuda_config = EfficientAttentionConfig(False, True, True)
165
+
166
+ def flash_attn(
167
+ self,
168
+ q, k, v,
169
+ mask = None,
170
+ attn_bias = None
171
+ ):
172
+ batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
173
+
174
+ # Recommended for multi-query single-key-value attention by Tri Dao
175
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
176
+
177
+ if k.ndim == 3:
178
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
179
+
180
+ if v.ndim == 3:
181
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
182
+
183
+ # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
184
+
185
+ if self.qk_norm:
186
+ default_scale = q.shape[-1] ** -0.5
187
+ q = q * (self.scale / default_scale)
188
+
189
+ # Check if mask exists and expand to compatible shape
190
+ # The mask is B L, so it would have to be expanded to B H N L
191
+
192
+ causal = self.causal
193
+
194
+ # in the case of kv caching with one token (q_len == 1), just turn off causal masking
195
+ # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
196
+
197
+ if q_len == 1 and causal:
198
+ causal = False
199
+
200
+ # expand key padding mask
201
+
202
+ if exists(mask):
203
+ assert mask.ndim == 4
204
+ mask = mask.expand(batch, heads, q_len, k_len)
205
+
206
+ # handle kv cache - this should be bypassable in updated flash attention 2
207
+
208
+ if k_len > q_len and causal:
209
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
210
+ if not exists(mask):
211
+ mask = ~causal_mask
212
+ else:
213
+ mask = mask & ~causal_mask
214
+ causal = False
215
+
216
+ # manually handle causal mask, if another mask was given
217
+
218
+ row_is_entirely_masked = None
219
+
220
+ if exists(mask) and causal:
221
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
222
+ mask = mask & ~causal_mask
223
+
224
+ # protect against an entire row being masked out
225
+
226
+ row_is_entirely_masked = ~mask.any(dim = -1)
227
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
228
+
229
+ causal = False
230
+
231
+ # handle alibi positional bias
232
+ # convert from bool to float
233
+
234
+ if exists(attn_bias):
235
+ attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
236
+
237
+ # if mask given, the mask would already contain the causal mask from above logic
238
+ # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
239
+
240
+ mask_value = -torch.finfo(q.dtype).max
241
+
242
+ if exists(mask):
243
+ attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
244
+ elif causal:
245
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
246
+ attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
247
+ causal = False
248
+
249
+ # scaled_dot_product_attention handles attn_mask either as bool or additive bias
250
+ # make it an additive bias here
251
+
252
+ mask = attn_bias
253
+
254
+ # Check if there is a compatible device for flash attention
255
+
256
+ config = self.cuda_config if is_cuda else self.cpu_config
257
+
258
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
259
+
260
+ # Legacy code...
261
+ # with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
262
+
263
+ # New SDP kernel code...
264
+ # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
265
+ with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
266
+
267
+ out = F.scaled_dot_product_attention(
268
+ q, k, v,
269
+ attn_mask = mask,
270
+ dropout_p = self.dropout if self.training else 0.,
271
+ is_causal = causal
272
+ )
273
+
274
+ # for a row that is entirely masked out, should zero out the output of that row token
275
+
276
+ if exists(row_is_entirely_masked):
277
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
278
+
279
+ return out, Intermediates()
280
+
281
+ def forward(
282
+ self,
283
+ q, k, v,
284
+ mask = None,
285
+ attn_bias = None,
286
+ prev_attn = None
287
+ ):
288
+ """
289
+ einstein notation
290
+ b - batch
291
+ h - heads
292
+ n, i, j - sequence length (base sequence length, source, target)
293
+ d - feature dimension
294
+ """
295
+
296
+ n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
297
+
298
+ scale = default(self.scale, q.shape[-1] ** -0.5)
299
+
300
+ causal = self.causal
301
+
302
+ # handle kv cached decoding
303
+
304
+ if n == 1 and causal:
305
+ causal = False
306
+
307
+ # handle grouped multi-query attention
308
+
309
+ if kv_heads == 1:
310
+ k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
311
+ elif kv_heads < heads:
312
+ k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))
313
+
314
+ # handle zero kv, as means for allowing network to attend to nothing
315
+
316
+ if self.add_zero_kv:
317
+ k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))
318
+
319
+ if exists(mask):
320
+ mask = F.pad(mask, (1, 0), value = True)
321
+
322
+ if exists(attn_bias):
323
+ attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
324
+
325
+ if self.flash:
326
+ assert not exists(prev_attn), 'residual attention not compatible with flash attention'
327
+ return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
328
+
329
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
330
+
331
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
332
+
333
+ if exists(prev_attn):
334
+ dots = dots + prev_attn
335
+
336
+ qk_similarities = dots.clone()
337
+
338
+ if self.talking_heads:
339
+ dots = self.pre_softmax_talking_heads(dots)
340
+
341
+ if exists(attn_bias):
342
+ dots = dots + attn_bias
343
+
344
+ i, j, dtype = *dots.shape[-2:], dots.dtype
345
+
346
+ mask_value = -torch.finfo(dots.dtype).max
347
+
348
+ if exists(self.sparse_topk) and self.sparse_topk < j:
349
+ top_values, _ = dots.topk(self.sparse_topk, dim = -1)
350
+ sparse_topk_mask = dots < top_values[..., -1:]
351
+ mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
352
+
353
+ if exists(mask):
354
+ dots = dots.masked_fill(~mask, mask_value)
355
+
356
+ if causal:
357
+ causal_mask = self.create_causal_mask(i, j, device = device)
358
+ dots = dots.masked_fill(causal_mask, mask_value)
359
+
360
+ pre_softmax_attn = dots.clone()
361
+
362
+ attn = self.attn_fn(dots, dim = -1)
363
+ attn = attn.type(dtype)
364
+
365
+ post_softmax_attn = attn.clone()
366
+
367
+ attn = self.attn_dropout(attn)
368
+
369
+ if self.talking_heads:
370
+ attn = self.post_softmax_talking_heads(attn)
371
+
372
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
373
+
374
+ intermediates = Intermediates(
375
+ qk_similarities = qk_similarities,
376
+ pre_softmax_attn = pre_softmax_attn,
377
+ post_softmax_attn = post_softmax_attn
378
+ )
379
+
380
+ return out, intermediates
381
+
382
+ #===================================================================================================================
383
+
384
+ from math import ceil, log
385
+ from typing import Optional, Union, Tuple, Callable
386
+
387
+ import torch
388
+ from torch import nn, Tensor
389
+ from torch.nn import Module
390
+ import torch.nn.functional as F
391
+
392
+ from einops import rearrange, pack, unpack
393
+
394
+ def exists(val):
395
+ return val is not None
396
+
397
+ def default(val, d):
398
+ return val if exists(val) else d
399
+
400
+ def identity(t, *args, **kwargs):
401
+ return t
402
+
403
+ def cast_tuple(t, length = 1):
404
+ return t if isinstance(t, tuple) else (t,) * length
405
+
406
+ def eval_decorator(fn):
407
+ def inner(self, *args, **kwargs):
408
+ was_training = self.training
409
+ self.eval()
410
+ out = fn(self, *args, **kwargs)
411
+ self.train(was_training)
412
+ return out
413
+ return inner
414
+
415
+ # for variable lengthed prefixes
416
+
417
+ def align_right(t, lens, pad_id = 0):
418
+ batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
419
+
420
+ assert lens.ndim == 1 and lens.shape[0] == batch
421
+ assert lens.amax() <= seq_len
422
+
423
+ pad_lens = seq_len - lens
424
+ max_pad_len = pad_lens.amax()
425
+
426
+ batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
427
+ prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
428
+
429
+ t = F.pad(t, (max_pad_len, 0), value = 0)
430
+ offset = max_pad_len - pad_lens
431
+
432
+ aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
433
+ return aligned
434
+
435
+ # nucleus
436
+
437
+ def top_p(logits, thres = 0.9):
438
+ sorted_logits, sorted_indices = torch.sort(logits, descending = True)
439
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
440
+
441
+ sorted_indices_to_remove = cum_probs > thres
442
+ sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
443
+
444
+ sorted_logits[sorted_indices_to_remove] = float('-inf')
445
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
446
+
447
+ # topk
448
+
449
+ def top_k(logits, frac_num_tokens = 0.1, k = None):
450
+ num_tokens = logits.shape[-1]
451
+
452
+ k = default(k, ceil(frac_num_tokens * num_tokens))
453
+ k = min(k, num_tokens)
454
+
455
+ val, ind = torch.topk(logits, k)
456
+ probs = torch.full_like(logits, float('-inf'))
457
+ probs.scatter_(1, ind, val)
458
+ return probs
459
+
460
+ # top_a
461
+
462
+ def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
463
+ probs = F.softmax(logits, dim = -1)
464
+ max_probs = torch.amax(probs, dim = -1, keepdim = True)
465
+ limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
466
+ return torch.where(probs < limit, float('-inf'), logits)
467
+
468
+ # contrastive decoding function
469
+
470
+ def contrastive_decode_fn(
471
+ expert_logits,
472
+ amateur_logits,
473
+ alpha = 0.1,
474
+ beta = 0.5
475
+ ):
476
+ """
477
+ Appendix A Algorithm 2
478
+ https://arxiv.org/abs/2309.09117
479
+ """
480
+
481
+ cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
482
+ diffs = (1 + beta) * expert_logits - beta * amateur_logits
483
+ contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
484
+ return contrastive_decode_logits
485
+
486
+ # autoregressive wrapper class
487
+
488
+ class AutoregressiveWrapper(Module):
489
+ def __init__(
490
+ self,
491
+ net,
492
+ ignore_index = -100,
493
+ pad_value = 0,
494
+ mask_prob = 0.,
495
+ add_attn_z_loss = False
496
+ ):
497
+ super().__init__()
498
+ self.pad_value = pad_value
499
+ self.ignore_index = ignore_index
500
+
501
+ self.net = net
502
+ self.max_seq_len = net.max_seq_len
503
+
504
+ # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
505
+ assert mask_prob < 1.
506
+ self.mask_prob = mask_prob
507
+
508
+ # whether to add router z-loss
509
+ self.add_attn_z_loss = add_attn_z_loss
510
+
511
+ @torch.no_grad()
512
+ @eval_decorator
513
+ def generate(
514
+ self,
515
+ prompts,
516
+ seq_len,
517
+ eos_token = None,
518
+ temperature = 1.,
519
+ prompt_lens: Optional[Tensor] = None,
520
+ filter_logits_fn: Callable = top_k,
521
+ restrict_to_max_seq_len = True,
522
+ amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
523
+ filter_kwargs: dict = dict(),
524
+ contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
525
+ beta = 0.5,
526
+ alpha = 0.1
527
+ ),
528
+ cache_kv = True,
529
+ verbose=True,
530
+ return_prime=False,
531
+ **kwargs
532
+ ):
533
+ max_seq_len, device = self.max_seq_len, prompts.device
534
+
535
+ prompts, ps = pack([prompts], '* n')
536
+
537
+ b, t = prompts.shape
538
+
539
+ # handle variable lengthed prompts (prefixes)
540
+
541
+ seq_start_pos = None
542
+ if exists(prompt_lens):
543
+ prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
544
+ seq_start_pos = t - prompt_lens
545
+
546
+ # output from which sampled tokens appended to
547
+
548
+ out = prompts
549
+
550
+ if verbose:
551
+ print("Generating sequence of max length:", seq_len)
552
+
553
+ # kv caches
554
+
555
+ cache = None
556
+
557
+ # if doing contrastive decoding, turn off filter automatically
558
+
559
+ if exists(amateur_model):
560
+ amateur_model = cast_tuple(amateur_model)
561
+ contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
562
+
563
+ assert len(amateur_model) == len(contrastive_decode_kwargs)
564
+
565
+ amateur_caches = [None] * len(amateur_model)
566
+ filter_logits_fn = identity
567
+
568
+ for i, module in enumerate(amateur_model):
569
+ if isinstance(module, AutoregressiveWrapper):
570
+ amateur_model[i] = module.net
571
+
572
+ module.eval()
573
+
574
+ # sampling up to seq_len
575
+
576
+ for sl in range(seq_len):
577
+
578
+ if restrict_to_max_seq_len:
579
+ x = out[:, -max_seq_len:]
580
+
581
+ if exists(cache):
582
+ for inter in cache.attn_intermediates:
583
+ inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
584
+
585
+ logits, new_cache = self.net(
586
+ x,
587
+ return_intermediates = True,
588
+ cache = cache,
589
+ seq_start_pos = seq_start_pos,
590
+ **kwargs
591
+ )
592
+
593
+ if cache_kv and self.net.can_cache_kv:
594
+ cache = new_cache
595
+
596
+ logits = logits[:, -1]
597
+
598
+ # handle contrastive decoding, Li et al.
599
+ # https://arxiv.org/abs/2210.15097
600
+
601
+ if exists(amateur_model):
602
+ for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
603
+ amateur_logits, next_amateur_cache = amateur(
604
+ x,
605
+ return_intermediates = True,
606
+ cache = amateur_cache,
607
+ seq_start_pos = seq_start_pos,
608
+ **kwargs
609
+ )
610
+
611
+ amateur_logits = amateur_logits[:, -1]
612
+
613
+ assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
614
+ logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
615
+
616
+ if cache_kv and amateur.can_cache_kv:
617
+ amateur_caches[i] = next_amateur_cache
618
+
619
+ # filter by top_k, top_p (nucleus), top_a, or custom
620
+
621
+ filtered_logits = filter_logits_fn(logits, **filter_kwargs)
622
+
623
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
624
+
625
+ sample = torch.multinomial(probs, 1)
626
+
627
+ out = torch.cat((out, sample), dim=-1)
628
+
629
+ if verbose:
630
+ if sl % 32 == 0:
631
+ print(sl, '/', seq_len)
632
+
633
+ if exists(eos_token):
634
+ is_eos_tokens = (out == eos_token)
635
+
636
+ if is_eos_tokens.any(dim = -1).all():
637
+ # mask out everything after the eos tokens
638
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
639
+ mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
640
+ out = out.masked_fill(mask, self.pad_value)
641
+
642
+ if verbose:
643
+ print('Model called the end of sequence at:', sl, '/', seq_len)
644
+
645
+ break
646
+
647
+ if return_prime:
648
+ return out[:, :]
649
+
650
+ else:
651
+ return out[:, t:]
652
+
653
+ # out, = unpack(out, ps, '* n')
654
+
655
+ # return out
656
+
657
+ def compute_accuracy(self, logits, labels):
658
+ out = torch.argmax(logits, dim=-1)
659
+ out = out.flatten()
660
+ labels = labels.flatten()
661
+
662
+ mask = (labels != self.ignore_index) # can also be self.pad_value (your choice)
663
+ out = out[mask]
664
+ labels = labels[mask]
665
+
666
+ num_right = (out == labels)
667
+ num_right = torch.sum(num_right).type(torch.float32)
668
+
669
+ acc = num_right / len(labels)
670
+ return acc
671
+
672
+ def forward(self, x, **kwargs):
673
+ seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
674
+
675
+ inp, target = x[:, :-1], x[:, 1:]
676
+ inp = torch.where(inp == ignore_index, self.pad_value, inp)
677
+
678
+ if self.mask_prob > 0.:
679
+ rand = torch.randn(inp.shape, device = x.device)
680
+ rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
681
+ num_mask = min(int(seq * self.mask_prob), seq - 1)
682
+ indices = rand.topk(num_mask, dim = -1).indices
683
+ mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
684
+ kwargs.update(self_attn_kv_mask = mask)
685
+
686
+ logits, cache = self.net(
687
+ inp,
688
+ return_intermediates = True,
689
+ return_attn_z_loss = add_attn_z_loss,
690
+ **kwargs
691
+ )
692
+
693
+ acc = self.compute_accuracy(logits, target)
694
+
695
+ loss = F.cross_entropy(
696
+ rearrange(logits, 'b n c -> b c n'),
697
+ target,
698
+ ignore_index = ignore_index
699
+ )
700
+
701
+ if add_attn_z_loss:
702
+ loss = loss + cache.attn_z_loss
703
+
704
+ return loss, acc
705
+
706
+ #===============================================================================
707
+
708
+ import math
709
+ from random import random
710
+
711
+ import torch
712
+ from torch import nn, einsum, Tensor
713
+ import torch.nn.functional as F
714
+
715
+ from functools import partial, wraps
716
+ from inspect import isfunction
717
+ from collections import namedtuple
718
+ from dataclasses import dataclass
719
+ from typing import List, Callable, Optional
720
+
721
+ from einops import rearrange, repeat, reduce, pack, unpack
722
+ from einops.layers.torch import Rearrange
723
+
724
+ # constants
725
+
726
+ DEFAULT_DIM_HEAD = 64
727
+
728
+ @dataclass
729
+ class LayerIntermediates:
730
+ hiddens: Optional[List[Tensor]] = None
731
+ attn_intermediates: Optional[List[Intermediates]] = None
732
+ layer_hiddens: Optional[List[Tensor]] = None
733
+ attn_z_loss: Optional[Tensor] = None
734
+ mems: Optional[Tensor] = None
735
+
736
+ # helpers
737
+
738
+ def exists(val):
739
+ return val is not None
740
+
741
+ def default(val, d):
742
+ if exists(val):
743
+ return val
744
+ return d() if isfunction(d) else d
745
+
746
+ def cast_tuple(val, depth):
747
+ return val if isinstance(val, tuple) else (val,) * depth
748
+
749
+ def divisible_by(num, den):
750
+ return (num % den) == 0
751
+
752
+ def maybe(fn):
753
+ @wraps(fn)
754
+ def inner(x, *args, **kwargs):
755
+ if not exists(x):
756
+ return x
757
+ return fn(x, *args, **kwargs)
758
+ return inner
759
+
760
+ class always():
761
+ def __init__(self, val):
762
+ self.val = val
763
+ def __call__(self, *args, **kwargs):
764
+ return self.val
765
+
766
+ class not_equals():
767
+ def __init__(self, val):
768
+ self.val = val
769
+ def __call__(self, x, *args, **kwargs):
770
+ return x != self.val
771
+
772
+ class equals():
773
+ def __init__(self, val):
774
+ self.val = val
775
+ def __call__(self, x, *args, **kwargs):
776
+ return x == self.val
777
+
778
+ def Sequential(*modules):
779
+ return nn.Sequential(*filter(exists, modules))
780
+
781
+ # tensor helpers
782
+
783
+ def max_neg_value(tensor):
784
+ return -torch.finfo(tensor.dtype).max
785
+
786
+ def l2norm(t, groups = 1):
787
+ t = rearrange(t, '... (g d) -> ... g d', g = groups)
788
+ t = F.normalize(t, p = 2, dim = -1)
789
+ return rearrange(t, '... g d -> ... (g d)')
790
+
791
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
792
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
793
+ zeros = ((0, 0) * dims_from_right)
794
+ return F.pad(t, (*zeros, *pad), value = value)
795
+
796
+ def or_reduce(masks):
797
+ head, *body = masks
798
+ for rest in body:
799
+ head = head | rest
800
+ return head
801
+
802
+ # auxiliary loss helpers
803
+
804
+ def calc_z_loss(
805
+ pre_softmax_attns: List[Tensor],
806
+ mask = None,
807
+ weight = 1.
808
+ ):
809
+ # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
810
+ # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
811
+ # also used in PaLM as one of the measures
812
+
813
+ lse = 0.
814
+
815
+ for attn in pre_softmax_attns:
816
+ lse = lse + attn.logsumexp(dim = -1)
817
+
818
+ loss = torch.square(lse)
819
+ loss = reduce(loss, 'b h n -> b n', 'sum')
820
+
821
+ if not exists(mask):
822
+ return loss.mean() * weight
823
+
824
+ loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
825
+ return loss * weight
826
+
827
+ # init helpers
828
+
829
+ def init_zero_(layer):
830
+ nn.init.constant_(layer.weight, 0.)
831
+ if exists(layer.bias):
832
+ nn.init.constant_(layer.bias, 0.)
833
+
834
+ # keyword argument helpers
835
+
836
+ def pick_and_pop(keys, d):
837
+ values = list(map(lambda key: d.pop(key), keys))
838
+ return dict(zip(keys, values))
839
+
840
+ def group_dict_by_key(cond, d):
841
+ return_val = [dict(),dict()]
842
+ for key in d.keys():
843
+ match = bool(cond(key))
844
+ ind = int(not match)
845
+ return_val[ind][key] = d[key]
846
+ return (*return_val,)
847
+
848
+ def string_begins_with(prefix, str):
849
+ return str.startswith(prefix)
850
+
851
+ def group_by_key_prefix(prefix, d):
852
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
853
+
854
+ def groupby_prefix_and_trim(prefix, d):
855
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
856
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
857
+ return kwargs_without_prefix, kwargs
858
+
859
+ # structured dropout, more effective than traditional attention dropouts
860
+
861
+ def dropout_seq(seq, mask, dropout):
862
+ b, n, *_, device = *seq.shape, seq.device
863
+ logits = torch.randn(b, n, device = device)
864
+
865
+ if exists(mask):
866
+ mask_value = max_neg_value(logits)
867
+ logits = logits.masked_fill(~mask, mask_value)
868
+
869
+ keep_prob = 1. - dropout
870
+ num_keep = max(1, int(keep_prob * n))
871
+ keep_indices = logits.topk(num_keep, dim = 1).indices
872
+
873
+ batch_indices = torch.arange(b, device = device)
874
+ batch_indices = rearrange(batch_indices, 'b -> b 1')
875
+
876
+ seq = seq[batch_indices, keep_indices]
877
+
878
+ if exists(mask):
879
+ seq_counts = mask.sum(dim = -1)
880
+ seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
881
+ keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
882
+
883
+ mask = mask[batch_indices, keep_indices] & keep_mask
884
+
885
+ return seq, mask
886
+
887
+ # activations
888
+
889
+ class ReluSquared(nn.Module):
890
+ def forward(self, x):
891
+ return F.relu(x) ** 2
892
+
893
+ # embedding
894
+
895
+ class TokenEmbedding(nn.Module):
896
+ def __init__(self, dim, num_tokens, l2norm_embed = False):
897
+ super().__init__()
898
+ self.l2norm_embed = l2norm_embed
899
+ self.emb = nn.Embedding(num_tokens, dim)
900
+
901
+ def forward(self, x):
902
+ token_emb = self.emb(x)
903
+ return l2norm(token_emb) if self.l2norm_embed else token_emb
904
+
905
+ # positional embeddings
906
+
907
+ class AbsolutePositionalEmbedding(nn.Module):
908
+ def __init__(self, dim, max_seq_len, l2norm_embed = False):
909
+ super().__init__()
910
+ self.scale = dim ** -0.5 if not l2norm_embed else 1.
911
+ self.max_seq_len = max_seq_len
912
+ self.l2norm_embed = l2norm_embed
913
+ self.emb = nn.Embedding(max_seq_len, dim)
914
+
915
+ def forward(self, x, pos = None, seq_start_pos = None):
916
+ seq_len, device = x.shape[1], x.device
917
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
918
+
919
+ if not exists(pos):
920
+ pos = torch.arange(seq_len, device = device)
921
+
922
+ if exists(seq_start_pos):
923
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
924
+
925
+ pos_emb = self.emb(pos)
926
+ pos_emb = pos_emb * self.scale
927
+ return l2norm(pos_emb) if self.l2norm_embed else pos_emb
928
+
929
+ class ScaledSinusoidalEmbedding(nn.Module):
930
+ def __init__(self, dim, theta = 10000):
931
+ super().__init__()
932
+ assert divisible_by(dim, 2)
933
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
934
+
935
+ half_dim = dim // 2
936
+ freq_seq = torch.arange(half_dim).float() / half_dim
937
+ inv_freq = theta ** -freq_seq
938
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
939
+
940
+ def forward(self, x, pos = None, seq_start_pos = None):
941
+ seq_len, device = x.shape[1], x.device
942
+
943
+ if not exists(pos):
944
+ pos = torch.arange(seq_len, device = device)
945
+
946
+ if exists(seq_start_pos):
947
+ pos = pos - seq_start_pos[..., None]
948
+
949
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
950
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
951
+ return emb * self.scale
952
+
953
+ class RelativePositionBias(nn.Module):
954
+ def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
955
+ super().__init__()
956
+ self.scale = scale
957
+ self.causal = causal
958
+ self.num_buckets = num_buckets
959
+ self.max_distance = max_distance
960
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
961
+
962
+ @staticmethod
963
+ def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
964
+ ret = 0
965
+ n = -relative_position
966
+ if not causal:
967
+ num_buckets //= 2
968
+ ret += (n < 0).long() * num_buckets
969
+ n = torch.abs(n)
970
+ else:
971
+ n = torch.max(n, torch.zeros_like(n))
972
+
973
+ max_exact = num_buckets // 2
974
+ is_small = n < max_exact
975
+
976
+ val_if_large = max_exact + (
977
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
978
+ ).long()
979
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
980
+
981
+ ret += torch.where(is_small, n, val_if_large)
982
+ return ret
983
+
984
+ @property
985
+ def device(self):
986
+ return next(self.parameters()).device
987
+
988
+ def forward(self, i, j):
989
+ device = self.device
990
+ q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
991
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
992
+ rel_pos = k_pos[None, :] - q_pos[:, None]
993
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
994
+ values = self.relative_attention_bias(rp_bucket)
995
+ bias = rearrange(values, 'i j h -> h i j')
996
+ return bias * self.scale
997
+
998
+ class DynamicPositionBias(nn.Module):
999
+ def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
1000
+ super().__init__()
1001
+ assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
1002
+ self.log_distance = log_distance
1003
+
1004
+ self.mlp = nn.ModuleList([])
1005
+
1006
+ self.mlp.append(Sequential(
1007
+ nn.Linear(1, dim),
1008
+ nn.LayerNorm(dim) if norm else None,
1009
+ nn.SiLU()
1010
+ ))
1011
+
1012
+ for _ in range(depth - 1):
1013
+ self.mlp.append(Sequential(
1014
+ nn.Linear(dim, dim),
1015
+ nn.LayerNorm(dim) if norm else None,
1016
+ nn.SiLU()
1017
+ ))
1018
+
1019
+ self.mlp.append(nn.Linear(dim, heads))
1020
+
1021
+ @property
1022
+ def device(self):
1023
+ return next(self.parameters()).device
1024
+
1025
+ def forward(self, i, j):
1026
+ assert i == j
1027
+ n, device = j, self.device
1028
+
1029
+ # get the (n x n) matrix of distances
1030
+ seq_arange = torch.arange(n, device = device)
1031
+ context_arange = torch.arange(n, device = device)
1032
+ indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
1033
+ indices += (n - 1)
1034
+
1035
+ # input to continuous positions MLP
1036
+ pos = torch.arange(-n + 1, n, device = device).float()
1037
+ pos = rearrange(pos, '... -> ... 1')
1038
+
1039
+ if self.log_distance:
1040
+ pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
1041
+
1042
+ for layer in self.mlp:
1043
+ pos = layer(pos)
1044
+
1045
+ # get position biases
1046
+ bias = pos[indices]
1047
+ bias = rearrange(bias, 'i j h -> h i j')
1048
+ return bias
1049
+
1050
+ class AlibiPositionalBias(nn.Module):
1051
+ def __init__(self, heads, total_heads, **kwargs):
1052
+ super().__init__()
1053
+ self.heads = heads
1054
+ self.total_heads = total_heads
1055
+
1056
+ slopes = Tensor(self._get_slopes(heads))
1057
+ slopes = rearrange(slopes, 'h -> h 1 1')
1058
+ self.register_buffer('slopes', slopes, persistent = False)
1059
+ self.register_buffer('bias', None, persistent = False)
1060
+
1061
+ def get_bias(self, i, j, device):
1062
+ i_arange = torch.arange(j - i, j, device = device)
1063
+ j_arange = torch.arange(j, device = device)
1064
+ bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
1065
+ return bias
1066
+
1067
+ @staticmethod
1068
+ def _get_slopes(heads):
1069
+ def get_slopes_power_of_2(n):
1070
+ start = (2**(-2**-(math.log2(n)-3)))
1071
+ ratio = start
1072
+ return [start*ratio**i for i in range(n)]
1073
+
1074
+ if math.log2(heads).is_integer():
1075
+ return get_slopes_power_of_2(heads)
1076
+
1077
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
1078
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
1079
+
1080
+ @property
1081
+ def device(self):
1082
+ return next(self.buffers()).device
1083
+
1084
+ def forward(self, i, j):
1085
+ h, device = self.total_heads, self.device
1086
+
1087
+ if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
1088
+ return self.bias[..., -i:, -j:]
1089
+
1090
+ bias = self.get_bias(i, j, device)
1091
+ bias = bias * self.slopes
1092
+
1093
+ num_heads_unalibied = h - bias.shape[0]
1094
+ bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
1095
+ self.register_buffer('bias', bias, persistent = False)
1096
+
1097
+ return self.bias
1098
+
1099
+ class RotaryEmbedding(nn.Module):
1100
+ def __init__(
1101
+ self,
1102
+ dim,
1103
+ use_xpos = False,
1104
+ scale_base = 512,
1105
+ interpolation_factor = 1.,
1106
+ base = 10000,
1107
+ base_rescale_factor = 1.
1108
+ ):
1109
+ super().__init__()
1110
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
1111
+ # has some connection to NTK literature
1112
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
1113
+ base *= base_rescale_factor ** (dim / (dim - 2))
1114
+
1115
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
1116
+ self.register_buffer('inv_freq', inv_freq)
1117
+
1118
+ assert interpolation_factor >= 1.
1119
+ self.interpolation_factor = interpolation_factor
1120
+
1121
+ if not use_xpos:
1122
+ self.register_buffer('scale', None)
1123
+ return
1124
+
1125
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
1126
+
1127
+ self.scale_base = scale_base
1128
+ self.register_buffer('scale', scale)
1129
+
1130
+ def forward(self, seq_len):
1131
+ device = self.inv_freq.device
1132
+ t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
1133
+
1134
+ t = t / self.interpolation_factor
1135
+
1136
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
1137
+ freqs = torch.cat((freqs, freqs), dim = -1)
1138
+
1139
+ if not exists(self.scale):
1140
+ return freqs, 1.
1141
+
1142
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
1143
+ scale = self.scale ** rearrange(power, 'n -> n 1')
1144
+ scale = torch.cat((scale, scale), dim = -1)
1145
+
1146
+ return freqs, scale
1147
+
1148
+
1149
+ def rotate_half(x):
1150
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
1151
+ x1, x2 = x.unbind(dim = -2)
1152
+ return torch.cat((-x2, x1), dim = -1)
1153
+
1154
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
1155
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
1156
+ freqs = freqs[-seq_len:, :]
1157
+
1158
+ if t.ndim == 4 and freqs.ndim == 3:
1159
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
1160
+
1161
+ # partial rotary embeddings, Wang et al. GPT-J
1162
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
1163
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
1164
+ return torch.cat((t, t_unrotated), dim = -1)
1165
+
1166
+ # norms
1167
+
1168
+ class Scale(nn.Module):
1169
+ def __init__(self, value, fn):
1170
+ super().__init__()
1171
+ self.value = value
1172
+ self.fn = fn
1173
+
1174
+ def forward(self, x, **kwargs):
1175
+ out = self.fn(x, **kwargs)
1176
+ scale_fn = lambda t: t * self.value
1177
+
1178
+ if not isinstance(out, tuple):
1179
+ return scale_fn(out)
1180
+
1181
+ return (scale_fn(out[0]), *out[1:])
1182
+
1183
+ class ScaleNorm(nn.Module):
1184
+ def __init__(self, dim, eps = 1e-5):
1185
+ super().__init__()
1186
+ self.eps = eps
1187
+ self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
1188
+
1189
+ def forward(self, x):
1190
+ norm = torch.norm(x, dim = -1, keepdim = True)
1191
+ return x / norm.clamp(min = self.eps) * self.g
1192
+
1193
+ class RMSNorm(nn.Module):
1194
+ def __init__(self, dim):
1195
+ super().__init__()
1196
+ self.scale = dim ** 0.5
1197
+ self.g = nn.Parameter(torch.ones(dim))
1198
+
1199
+ def forward(self, x):
1200
+ return F.normalize(x, dim = -1) * self.scale * self.g
1201
+
1202
+ class SimpleRMSNorm(nn.Module):
1203
+ def __init__(self, dim):
1204
+ super().__init__()
1205
+ self.scale = dim ** 0.5
1206
+
1207
+ def forward(self, x):
1208
+ return F.normalize(x, dim = -1) * self.scale
1209
+
1210
+ # residual and residual gates
1211
+
1212
+ class Residual(nn.Module):
1213
+ def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
1214
+ super().__init__()
1215
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1216
+ self.scale_residual_constant = scale_residual_constant
1217
+
1218
+ def forward(self, x, residual):
1219
+ if exists(self.residual_scale):
1220
+ residual = residual * self.residual_scale
1221
+
1222
+ if self.scale_residual_constant != 1:
1223
+ residual = residual * self.scale_residual_constant
1224
+
1225
+ return x + residual
1226
+
1227
+ class GRUGating(nn.Module):
1228
+ def __init__(self, dim, scale_residual = False, **kwargs):
1229
+ super().__init__()
1230
+ self.gru = nn.GRUCell(dim, dim)
1231
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1232
+
1233
+ def forward(self, x, residual):
1234
+ if exists(self.residual_scale):
1235
+ residual = residual * self.residual_scale
1236
+
1237
+ gated_output = self.gru(
1238
+ rearrange(x, 'b n d -> (b n) d'),
1239
+ rearrange(residual, 'b n d -> (b n) d')
1240
+ )
1241
+
1242
+ return gated_output.reshape_as(x)
1243
+
1244
+ # token shifting
1245
+
1246
+ def shift(t, amount, mask = None):
1247
+ if amount == 0:
1248
+ return t
1249
+ else:
1250
+ amount = min(amount, t.shape[1])
1251
+
1252
+ if exists(mask):
1253
+ t = t.masked_fill(~mask[..., None], 0.)
1254
+
1255
+ return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
1256
+
1257
+ class ShiftTokens(nn.Module):
1258
+ def __init__(self, shifts, fn):
1259
+ super().__init__()
1260
+ self.fn = fn
1261
+ self.shifts = tuple(shifts)
1262
+
1263
+ def forward(self, x, **kwargs):
1264
+ mask = kwargs.get('mask', None)
1265
+ shifts = self.shifts
1266
+ segments = len(shifts)
1267
+ feats_per_shift = x.shape[-1] // segments
1268
+ splitted = x.split(feats_per_shift, dim = -1)
1269
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
1270
+ segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
1271
+ x = torch.cat((*segments_to_shift, *rest), dim = -1)
1272
+ return self.fn(x, **kwargs)
1273
+
1274
+ # feedforward
1275
+
1276
+ class GLU(nn.Module):
1277
+ def __init__(
1278
+ self,
1279
+ dim_in,
1280
+ dim_out,
1281
+ activation: Callable,
1282
+ mult_bias = False
1283
+ ):
1284
+ super().__init__()
1285
+ self.act = activation
1286
+ self.proj = nn.Linear(dim_in, dim_out * 2)
1287
+ self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.
1288
+
1289
+ def forward(self, x):
1290
+ x, gate = self.proj(x).chunk(2, dim = -1)
1291
+ return x * self.act(gate) * self.mult_bias
1292
+
1293
+ class FeedForward(nn.Module):
1294
+ def __init__(
1295
+ self,
1296
+ dim,
1297
+ dim_out = None,
1298
+ mult = 4,
1299
+ glu = False,
1300
+ glu_mult_bias = False,
1301
+ swish = False,
1302
+ relu_squared = False,
1303
+ post_act_ln = False,
1304
+ dropout = 0.,
1305
+ no_bias = False,
1306
+ zero_init_output = False
1307
+ ):
1308
+ super().__init__()
1309
+ inner_dim = int(dim * mult)
1310
+ dim_out = default(dim_out, dim)
1311
+
1312
+ if relu_squared:
1313
+ activation = ReluSquared()
1314
+ elif swish:
1315
+ activation = nn.SiLU()
1316
+ else:
1317
+ activation = nn.GELU()
1318
+
1319
+ if glu:
1320
+ project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
1321
+ else:
1322
+ project_in = nn.Sequential(
1323
+ nn.Linear(dim, inner_dim, bias = not no_bias),
1324
+ activation
1325
+ )
1326
+
1327
+ self.ff = Sequential(
1328
+ project_in,
1329
+ nn.LayerNorm(inner_dim) if post_act_ln else None,
1330
+ nn.Dropout(dropout),
1331
+ nn.Linear(inner_dim, dim_out, bias = not no_bias)
1332
+ )
1333
+
1334
+ # init last linear layer to 0
1335
+ if zero_init_output:
1336
+ init_zero_(self.ff[-1])
1337
+
1338
+ def forward(self, x):
1339
+ return self.ff(x)
1340
+
1341
+ # attention. it is all we need
1342
+
1343
+ class Attention(nn.Module):
1344
+ def __init__(
1345
+ self,
1346
+ dim,
1347
+ dim_head = DEFAULT_DIM_HEAD,
1348
+ heads = 8,
1349
+ causal = False,
1350
+ flash = False,
1351
+ talking_heads = False,
1352
+ head_scale = False,
1353
+ sparse_topk = None,
1354
+ num_mem_kv = 0,
1355
+ dropout = 0.,
1356
+ on_attn = False,
1357
+ gate_value_heads = False,
1358
+ gate_values = False,
1359
+ zero_init_output = False,
1360
+ max_attend_past = None,
1361
+ qk_norm = False,
1362
+ qk_norm_groups = 1,
1363
+ qk_norm_scale = 10,
1364
+ qk_norm_dim_scale = False,
1365
+ one_kv_head = False,
1366
+ kv_heads = None,
1367
+ shared_kv = False,
1368
+ value_dim_head = None,
1369
+ tensor_product = False, # https://arxiv.org/abs/2208.06061
1370
+ add_zero_kv = False, # same as add_zero_attn in pytorch
1371
+ rotary_embed_values = False,
1372
+ onnxable = False
1373
+ ):
1374
+ super().__init__()
1375
+ self.scale = dim_head ** -0.5
1376
+
1377
+ self.heads = heads
1378
+ self.causal = causal
1379
+ self.max_attend_past = max_attend_past
1380
+
1381
+ assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'
1382
+
1383
+ value_dim_head = default(value_dim_head, dim_head)
1384
+ kv_heads = default(kv_heads, heads)
1385
+
1386
+ kv_heads = 1 if one_kv_head else kv_heads
1387
+ assert divisible_by(heads, kv_heads)
1388
+
1389
+ self.kv_heads = kv_heads
1390
+
1391
+ q_dim = dim_head * heads
1392
+ k_dim = dim_head * kv_heads
1393
+ v_dim = value_dim_head * kv_heads
1394
+ out_dim = value_dim_head * heads
1395
+
1396
+ self.to_q = nn.Linear(dim, q_dim, bias = False)
1397
+ self.to_k = nn.Linear(dim, k_dim, bias = False)
1398
+
1399
+ # shared key / values, for further memory savings during inference
1400
+ assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
1401
+ self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
1402
+
1403
+ # relations projection from tp-attention
1404
+ self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
1405
+
1406
+ # add GLU gating for aggregated values, from alphafold2
1407
+ self.to_v_gate = None
1408
+ if gate_values:
1409
+ self.to_v_gate = nn.Linear(dim, out_dim)
1410
+ nn.init.constant_(self.to_v_gate.weight, 0)
1411
+ nn.init.constant_(self.to_v_gate.bias, 10)
1412
+
1413
+ # add per head gating of the output values, from 'Attend to nothing' paper
1414
+ self.to_v_head_gate = None
1415
+ if gate_value_heads:
1416
+ self.to_v_head_gate = nn.Linear(dim, heads)
1417
+ nn.init.constant_(self.to_v_head_gate.weight, 0)
1418
+ nn.init.constant_(self.to_v_head_gate.bias, 10)
1419
+
1420
+ # cosine sim attention
1421
+ self.qk_norm = qk_norm
1422
+ self.qk_norm_groups = qk_norm_groups
1423
+ self.qk_norm_scale = qk_norm_scale
1424
+
1425
+ # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
1426
+ self.qk_norm_dim_scale = qk_norm_dim_scale
1427
+
1428
+ self.qk_norm_q_scale = self.qk_norm_k_scale = 1
1429
+ if qk_norm and qk_norm_dim_scale:
1430
+ self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1431
+ self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1432
+
1433
+ assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
1434
+ assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
1435
+
1436
+ # attend class - includes core attention algorithm + talking heads
1437
+
1438
+ self.attend = Attend(
1439
+ heads = heads,
1440
+ causal = causal,
1441
+ talking_heads = talking_heads,
1442
+ dropout = dropout,
1443
+ sparse_topk = sparse_topk,
1444
+ qk_norm = qk_norm,
1445
+ scale = qk_norm_scale if qk_norm else self.scale,
1446
+ add_zero_kv = add_zero_kv,
1447
+ flash = flash,
1448
+ onnxable = onnxable
1449
+ )
1450
+
1451
+ # head scaling
1452
+ self.head_scale = head_scale
1453
+ if head_scale:
1454
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
1455
+
1456
+ # explicit topk sparse attention
1457
+ self.sparse_topk = sparse_topk
1458
+
1459
+ # add memory key / values
1460
+ self.num_mem_kv = num_mem_kv
1461
+ if num_mem_kv > 0:
1462
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1463
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1464
+
1465
+ # attention on attention
1466
+ self.attn_on_attn = on_attn
1467
+ self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
1468
+
1469
+ # whether to rotate positions into values, for absolute positions in addition to relative
1470
+ self.rotary_embed_values = rotary_embed_values
1471
+
1472
+ # init output projection 0
1473
+ if zero_init_output:
1474
+ init_zero_(self.to_out)
1475
+
1476
+ def forward(
1477
+ self,
1478
+ x,
1479
+ context = None,
1480
+ mask = None,
1481
+ context_mask = None,
1482
+ attn_mask = None,
1483
+ rel_pos = None,
1484
+ rotary_pos_emb = None,
1485
+ prev_attn = None,
1486
+ mem = None,
1487
+ return_intermediates = False,
1488
+ cache: Optional[Intermediates] = None,
1489
+ ):
1490
+ b, n, _, h, kv_h, head_scale, device, has_context = *x.shape, self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
1491
+ kv_input = default(context, x)
1492
+
1493
+ q_input = x
1494
+ k_input = kv_input
1495
+ v_input = kv_input
1496
+ r_input = x
1497
+
1498
+ if exists(mem):
1499
+ k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
1500
+ v_input, _ = pack([mem, v_input], 'b * d')
1501
+
1502
+ q = self.to_q(q_input)
1503
+ k = self.to_k(k_input)
1504
+ v = self.to_v(v_input) if exists(self.to_v) else k
1505
+ r = self.to_r(r_input) if exists(self.to_r) else None
1506
+
1507
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
1508
+
1509
+ k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
1510
+
1511
+ if exists(cache) and not has_context:
1512
+ ck, cv = cache.cached_kv
1513
+
1514
+ if exists(mem):
1515
+ mk, k = unpack(k, mem_packed_shape, 'b h * d')
1516
+ mv, v = unpack(v, mem_packed_shape, 'b h * d')
1517
+
1518
+ k = torch.cat((ck, k), dim = -2)
1519
+ v = torch.cat((cv, v), dim = -2)
1520
+
1521
+ if exists(mem):
1522
+ k = torch.cat((mk, k), dim = -2)
1523
+ v = torch.cat((mv, v), dim = -2)
1524
+
1525
+ if return_intermediates:
1526
+ mem_len = mem.shape[-2] if exists(mem) else 0
1527
+ cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
1528
+
1529
+ if self.qk_norm:
1530
+ qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
1531
+ q, k = map(qk_l2norm, (q, k))
1532
+ scale = self.qk_norm_scale
1533
+
1534
+ q = q * self.qk_norm_q_scale
1535
+ k = k * self.qk_norm_k_scale
1536
+
1537
+ if exists(rotary_pos_emb) and not has_context:
1538
+ freqs, xpos_scale = rotary_pos_emb
1539
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1540
+
1541
+ q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
1542
+ k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
1543
+
1544
+ if self.rotary_embed_values:
1545
+ v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
1546
+
1547
+ input_mask = context_mask
1548
+
1549
+ if not exists(input_mask) and not has_context:
1550
+ input_mask = mask
1551
+
1552
+ if self.num_mem_kv > 0:
1553
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
1554
+
1555
+ if self.qk_norm:
1556
+ mem_k = l2norm(mem_k)
1557
+ mem_k = mem_k * self.qk_norm_k_scale
1558
+
1559
+ k = torch.cat((mem_k, k), dim = -2)
1560
+ v = torch.cat((mem_v, v), dim = -2)
1561
+
1562
+ if exists(input_mask):
1563
+ input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
1564
+
1565
+ i, j = map(lambda t: t.shape[-2], (q, k))
1566
+
1567
+ # determine masking
1568
+
1569
+ mask_value = max_neg_value(q)
1570
+ masks = []
1571
+ final_attn_mask = None
1572
+
1573
+ if exists(input_mask):
1574
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
1575
+ masks.append(~input_mask)
1576
+
1577
+ if exists(attn_mask):
1578
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
1579
+ if attn_mask.ndim == 2:
1580
+ attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
1581
+ elif attn_mask.ndim == 3:
1582
+ attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
1583
+ masks.append(~attn_mask)
1584
+
1585
+ if exists(self.max_attend_past):
1586
+ range_q = torch.arange(j - i, j, device = device)
1587
+ range_k = torch.arange(j, device = device)
1588
+ dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
1589
+ max_attend_past_mask = dist > self.max_attend_past
1590
+ masks.append(max_attend_past_mask)
1591
+
1592
+ if len(masks) > 0:
1593
+ final_attn_mask = ~or_reduce(masks)
1594
+
1595
+ # prepare relative positional bias, if needed
1596
+
1597
+ attn_bias = None
1598
+ if exists(rel_pos):
1599
+ attn_bias = rel_pos(i, j)
1600
+
1601
+ # attention is all we need
1602
+
1603
+ out, intermediates = self.attend(
1604
+ q, k, v,
1605
+ mask = final_attn_mask,
1606
+ attn_bias = attn_bias,
1607
+ prev_attn = prev_attn
1608
+ )
1609
+
1610
+ # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
1611
+
1612
+ if exists(r):
1613
+ out = out * r + out
1614
+
1615
+ # normformer scaling of heads
1616
+
1617
+ if head_scale:
1618
+ out = out * self.head_scale_params
1619
+
1620
+ # per head gating, from https://arxiv.org/abs/2306.12929
1621
+
1622
+ if exists(self.to_v_head_gate):
1623
+ head_gate = self.to_v_head_gate(x)
1624
+ out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
1625
+
1626
+ # merge heads
1627
+
1628
+ out = rearrange(out, 'b h n d -> b n (h d)')
1629
+
1630
+ # alphafold2 styled gating of the values
1631
+
1632
+ if exists(self.to_v_gate):
1633
+ gates = self.to_v_gate(x)
1634
+ out = out * gates.sigmoid()
1635
+
1636
+ # combine the heads
1637
+
1638
+ out = self.to_out(out)
1639
+
1640
+ if exists(mask):
1641
+ mask = rearrange(mask, 'b n -> b n 1')
1642
+ out = out.masked_fill(~mask, 0.)
1643
+
1644
+ if not return_intermediates:
1645
+ return out
1646
+
1647
+ intermediates.cached_kv = cached_kv
1648
+
1649
+ return out, intermediates
1650
+
1651
+ class AttentionLayers(nn.Module):
1652
+ def __init__(
1653
+ self,
1654
+ dim,
1655
+ depth,
1656
+ heads = 8,
1657
+ causal = False,
1658
+ cross_attend = False,
1659
+ only_cross = False,
1660
+ use_scalenorm = False,
1661
+ use_rmsnorm = False,
1662
+ use_simple_rmsnorm = False,
1663
+ alibi_pos_bias = False,
1664
+ alibi_num_heads = None,
1665
+ rel_pos_bias = False,
1666
+ rel_pos_num_buckets = 32,
1667
+ rel_pos_max_distance = 128,
1668
+ dynamic_pos_bias = False,
1669
+ dynamic_pos_bias_log_distance = False,
1670
+ dynamic_pos_bias_mlp_depth = 2,
1671
+ dynamic_pos_bias_norm = False,
1672
+ rotary_pos_emb = False,
1673
+ rotary_emb_dim = None,
1674
+ rotary_xpos = False,
1675
+ rotary_interpolation_factor = 1.,
1676
+ rotary_xpos_scale_base = 512,
1677
+ rotary_base_rescale_factor = 1.,
1678
+ custom_layers = None,
1679
+ sandwich_coef = None,
1680
+ par_ratio = None,
1681
+ weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
1682
+ layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
1683
+ residual_attn = False,
1684
+ cross_residual_attn = False,
1685
+ macaron = False,
1686
+ pre_norm = True,
1687
+ pre_norm_has_final_norm = True,
1688
+ gate_residual = False,
1689
+ scale_residual = False,
1690
+ scale_residual_constant = 1.,
1691
+ shift_tokens = 0,
1692
+ sandwich_norm = False,
1693
+ resi_dual = False,
1694
+ resi_dual_scale = 1.,
1695
+ zero_init_branch_output = False,
1696
+ layer_dropout = 0.,
1697
+ cross_attn_tokens_dropout = 0.,
1698
+ **kwargs
1699
+ ):
1700
+ super().__init__()
1701
+ rotary_pos_emb = rotary_pos_emb or rotary_xpos
1702
+
1703
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1704
+ attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1705
+
1706
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1707
+
1708
+ self.dim = dim
1709
+ self.depth = depth
1710
+ self.causal = causal
1711
+ self.layers = nn.ModuleList([])
1712
+
1713
+ self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1714
+
1715
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1716
+
1717
+ assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1718
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
1719
+
1720
+ assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1721
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
1722
+
1723
+ # relative positional bias
1724
+
1725
+ flash_attn = attn_kwargs.get('flash', False)
1726
+ assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1727
+
1728
+ self.rel_pos = None
1729
+ if rel_pos_bias:
1730
+ assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1731
+ self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1732
+ elif dynamic_pos_bias:
1733
+ assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1734
+ self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1735
+ elif alibi_pos_bias:
1736
+ alibi_num_heads = default(alibi_num_heads, heads)
1737
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
1738
+ self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
1739
+
1740
+ assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1741
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1742
+
1743
+ if resi_dual:
1744
+ pre_norm = False
1745
+
1746
+ self.pre_norm = pre_norm
1747
+ self.sandwich_norm = sandwich_norm
1748
+
1749
+ self.resi_dual = resi_dual
1750
+ assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
1751
+ self.resi_dual_scale = resi_dual_scale
1752
+
1753
+ self.residual_attn = residual_attn
1754
+ self.cross_residual_attn = cross_residual_attn
1755
+ assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
1756
+
1757
+ self.cross_attend = cross_attend
1758
+
1759
+ assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
1760
+
1761
+ if use_scalenorm:
1762
+ norm_class = ScaleNorm
1763
+ elif use_rmsnorm:
1764
+ norm_class = RMSNorm
1765
+ elif use_simple_rmsnorm:
1766
+ norm_class = SimpleRMSNorm
1767
+ else:
1768
+ norm_class = nn.LayerNorm
1769
+
1770
+ norm_fn = partial(norm_class, dim)
1771
+
1772
+ if cross_attend and not only_cross:
1773
+ default_block = ('a', 'c', 'f')
1774
+ elif cross_attend and only_cross:
1775
+ default_block = ('c', 'f')
1776
+ else:
1777
+ default_block = ('a', 'f')
1778
+
1779
+ if macaron:
1780
+ default_block = ('f',) + default_block
1781
+
1782
+ # zero init
1783
+
1784
+ if zero_init_branch_output:
1785
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1786
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1787
+
1788
+ # setup weight tying, which is a special case of `layer_execute_order`
1789
+
1790
+ assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
1791
+
1792
+ if weight_tie_layers:
1793
+ assert not exists(layers_execute_order)
1794
+ layers_execute_order = tuple(range(len(default_block))) * depth
1795
+ depth = 1
1796
+
1797
+ # calculate layer block order
1798
+
1799
+ if exists(custom_layers):
1800
+ layer_types = custom_layers
1801
+ elif exists(par_ratio):
1802
+ par_depth = depth * len(default_block)
1803
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1804
+ default_block = tuple(filter(not_equals('f'), default_block))
1805
+ par_attn = par_depth // par_ratio
1806
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1807
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
1808
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1809
+ par_block = default_block + ('f',) * (par_width - len(default_block))
1810
+ par_head = par_block * par_attn
1811
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
1812
+ elif exists(sandwich_coef):
1813
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
1814
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1815
+ else:
1816
+ layer_types = default_block * depth
1817
+
1818
+ self.layer_types = layer_types
1819
+ self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
1820
+
1821
+ assert all([i < len(self.layer_types) for i in self.layers_execute_order])
1822
+
1823
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1824
+
1825
+ # stochastic depth
1826
+
1827
+ self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1828
+
1829
+ # structured dropout for cross attending
1830
+
1831
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1832
+
1833
+ # calculate token shifting
1834
+
1835
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1836
+
1837
+ # whether it has post norm
1838
+
1839
+ self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
1840
+
1841
+ # iterate and construct layers
1842
+
1843
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1844
+ is_last_layer = ind == (len(self.layer_types) - 1)
1845
+
1846
+ if layer_type == 'a':
1847
+ layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1848
+ elif layer_type == 'c':
1849
+ layer = Attention(dim, heads = heads, **attn_kwargs)
1850
+ elif layer_type == 'f':
1851
+ layer = FeedForward(dim, **ff_kwargs)
1852
+ layer = layer if not macaron else Scale(0.5, layer)
1853
+ else:
1854
+ raise Exception(f'invalid layer type {layer_type}')
1855
+
1856
+ if layer_shift_tokens > 0:
1857
+ shift_range_upper = layer_shift_tokens + 1
1858
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1859
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1860
+
1861
+ residual_fn = GRUGating if gate_residual else Residual
1862
+ residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1863
+
1864
+ pre_branch_norm = norm_fn() if pre_norm else None
1865
+ post_branch_norm = norm_fn() if sandwich_norm else None
1866
+ post_main_norm = norm_fn() if not pre_norm else None
1867
+
1868
+ norms = nn.ModuleList([
1869
+ pre_branch_norm,
1870
+ post_branch_norm,
1871
+ post_main_norm
1872
+ ])
1873
+
1874
+ self.layers.append(nn.ModuleList([
1875
+ norms,
1876
+ layer,
1877
+ residual
1878
+ ]))
1879
+
1880
+ def forward(
1881
+ self,
1882
+ x,
1883
+ context = None,
1884
+ mask = None,
1885
+ context_mask = None,
1886
+ attn_mask = None,
1887
+ self_attn_kv_mask = None,
1888
+ mems = None,
1889
+ seq_start_pos: Optional[Tensor] = None,
1890
+ cache: Optional[LayerIntermediates] = None,
1891
+ cache_age = 1,
1892
+ return_hiddens = False
1893
+ ):
1894
+ assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
1895
+
1896
+ # initialize accums
1897
+
1898
+ hiddens = []
1899
+ layer_hiddens = []
1900
+ intermediates = []
1901
+
1902
+ prev_attn = None
1903
+ prev_cross_attn = None
1904
+
1905
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1906
+
1907
+ # handle left padded sequences
1908
+
1909
+ if exists(seq_start_pos):
1910
+ seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
1911
+ left_pad_mask = seq_arange >= seq_start_pos[..., None]
1912
+
1913
+ if exists(self_attn_kv_mask):
1914
+ self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
1915
+ else:
1916
+ self_attn_kv_mask = left_pad_mask
1917
+
1918
+ # rotary positions
1919
+
1920
+ rotary_pos_emb = None
1921
+
1922
+ if exists(self.rotary_pos_emb):
1923
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1924
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
1925
+
1926
+ # assume cached key / values
1927
+
1928
+ attn_cache = []
1929
+
1930
+ if exists(cache):
1931
+ assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
1932
+
1933
+ if cache_age > 0:
1934
+ x = x[:, -cache_age:] # for spec decoding, may be greater than 1
1935
+
1936
+ attn_cache = cache.attn_intermediates
1937
+
1938
+ iter_attn_cache = iter(attn_cache)
1939
+
1940
+ # outer residual - for resiDual paper
1941
+
1942
+ outer_residual = x * self.resi_dual_scale
1943
+
1944
+ # get layers to be executed
1945
+
1946
+ layer_variables = (
1947
+ self.layer_types,
1948
+ self.layers,
1949
+ self.layer_dropouts
1950
+ )
1951
+
1952
+ layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
1953
+
1954
+ # go through the attention and feedforward layers
1955
+
1956
+ for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
1957
+ is_last = ind == (len(self.layers) - 1)
1958
+
1959
+ if self.training and layer_dropout > 0. and random() < layer_dropout:
1960
+ continue
1961
+
1962
+ if layer_type == 'a':
1963
+ if return_hiddens:
1964
+ hiddens.append(x)
1965
+ layer_mem = mems.pop(0) if mems else None
1966
+
1967
+ if layer_type == 'c':
1968
+ if self.training and self.cross_attn_tokens_dropout > 0.:
1969
+ context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1970
+
1971
+ inner_residual = x
1972
+
1973
+ if return_hiddens:
1974
+ layer_hiddens.append(x)
1975
+
1976
+ pre_norm, post_branch_norm, post_main_norm = norm
1977
+
1978
+ if exists(pre_norm):
1979
+ x = pre_norm(x)
1980
+
1981
+ if layer_type == 'a':
1982
+ out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
1983
+ elif layer_type == 'c':
1984
+ out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
1985
+ elif layer_type == 'f':
1986
+ out = block(x)
1987
+
1988
+ if self.resi_dual:
1989
+ outer_residual = outer_residual + out * self.resi_dual_scale
1990
+
1991
+ if exists(post_branch_norm):
1992
+ out = post_branch_norm(out)
1993
+
1994
+ x = residual_fn(out, inner_residual)
1995
+
1996
+ if layer_type in ('a', 'c') and return_hiddens:
1997
+ intermediates.append(inter)
1998
+
1999
+ if layer_type == 'a' and self.residual_attn:
2000
+ prev_attn = inter.pre_softmax_attn
2001
+ elif layer_type == 'c' and self.cross_residual_attn:
2002
+ prev_cross_attn = inter.pre_softmax_attn
2003
+
2004
+ if exists(post_main_norm):
2005
+ x = post_main_norm(x)
2006
+
2007
+ if return_hiddens:
2008
+ layer_hiddens.append(x)
2009
+
2010
+ if self.resi_dual:
2011
+ x = x + self.final_norm(outer_residual)
2012
+ else:
2013
+ x = self.final_norm(x)
2014
+
2015
+ if not return_hiddens:
2016
+ return x
2017
+
2018
+ intermediates = LayerIntermediates(
2019
+ hiddens = hiddens,
2020
+ attn_intermediates = intermediates,
2021
+ layer_hiddens = layer_hiddens
2022
+ )
2023
+
2024
+ return x, intermediates
2025
+
2026
+ class Encoder(AttentionLayers):
2027
+ def __init__(self, **kwargs):
2028
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
2029
+ super().__init__(causal = False, **kwargs)
2030
+
2031
+ class Decoder(AttentionLayers):
2032
+ def __init__(self, **kwargs):
2033
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
2034
+ super().__init__(causal = True, **kwargs)
2035
+
2036
+ class CrossAttender(AttentionLayers):
2037
+ def __init__(self, **kwargs):
2038
+ super().__init__(cross_attend = True, only_cross = True, **kwargs)
2039
+
2040
+ class ViTransformerWrapper(nn.Module):
2041
+ def __init__(
2042
+ self,
2043
+ *,
2044
+ image_size,
2045
+ patch_size,
2046
+ attn_layers,
2047
+ channels = 3,
2048
+ num_classes = None,
2049
+ post_emb_norm = False,
2050
+ num_register_tokens = 0,
2051
+ emb_dropout = 0.
2052
+ ):
2053
+ super().__init__()
2054
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
2055
+ assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
2056
+ dim = attn_layers.dim
2057
+ num_patches = (image_size // patch_size) ** 2
2058
+ patch_dim = channels * patch_size ** 2
2059
+
2060
+ self.patch_size = patch_size
2061
+
2062
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
2063
+
2064
+ has_register_tokens = num_register_tokens > 0
2065
+ self.has_register_tokens = has_register_tokens
2066
+
2067
+ if has_register_tokens:
2068
+ self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
2069
+
2070
+ self.patch_to_embedding = nn.Sequential(
2071
+ nn.LayerNorm(patch_dim),
2072
+ nn.Linear(patch_dim, dim),
2073
+ nn.LayerNorm(dim)
2074
+ )
2075
+
2076
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2077
+ self.dropout = nn.Dropout(emb_dropout)
2078
+
2079
+ self.attn_layers = attn_layers
2080
+
2081
+ self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
2082
+
2083
+ def forward(
2084
+ self,
2085
+ img,
2086
+ return_embeddings = False
2087
+ ):
2088
+ b, p = img.shape[0], self.patch_size
2089
+
2090
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
2091
+ x = self.patch_to_embedding(x)
2092
+ n = x.shape[1]
2093
+
2094
+ x = x + self.pos_embedding[:, :n]
2095
+
2096
+ x = self.post_emb_norm(x)
2097
+ x = self.dropout(x)
2098
+
2099
+ if self.has_register_tokens:
2100
+ r = repeat(self.register_tokens, 'n d -> b n d', b = b)
2101
+ x, ps = pack((x, r), 'b * d')
2102
+
2103
+ x = self.attn_layers(x)
2104
+
2105
+ if self.has_register_tokens:
2106
+ x, _ = unpack(x, ps, 'b * d')
2107
+
2108
+ if not exists(self.mlp_head) or return_embeddings:
2109
+ return x
2110
+
2111
+ x = x.mean(dim = -2)
2112
+ return self.mlp_head(x)
2113
+
2114
+ class TransformerWrapper(nn.Module):
2115
+ def __init__(
2116
+ self,
2117
+ *,
2118
+ num_tokens,
2119
+ max_seq_len,
2120
+ attn_layers,
2121
+ emb_dim = None,
2122
+ max_mem_len = 0,
2123
+ shift_mem_down = 0,
2124
+ emb_dropout = 0.,
2125
+ post_emb_norm = False,
2126
+ num_memory_tokens = None,
2127
+ memory_tokens_interspersed_every = None,
2128
+ tie_embedding = False,
2129
+ logits_dim = None,
2130
+ use_abs_pos_emb = True,
2131
+ scaled_sinu_pos_emb = False,
2132
+ l2norm_embed = False,
2133
+ emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
2134
+ attn_z_loss_weight = 1e-4,
2135
+ ):
2136
+ super().__init__()
2137
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2138
+
2139
+ dim = attn_layers.dim
2140
+ emb_dim = default(emb_dim, dim)
2141
+ self.emb_dim = emb_dim
2142
+ self.num_tokens = num_tokens
2143
+
2144
+ self.max_seq_len = max_seq_len
2145
+ self.max_mem_len = max_mem_len
2146
+ self.shift_mem_down = shift_mem_down
2147
+
2148
+ self.l2norm_embed = l2norm_embed
2149
+ self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
2150
+
2151
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2152
+ self.pos_emb = always(0)
2153
+ elif scaled_sinu_pos_emb:
2154
+ self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
2155
+ else:
2156
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
2157
+
2158
+ self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
2159
+
2160
+ self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
2161
+ self.emb_dropout = nn.Dropout(emb_dropout)
2162
+
2163
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
2164
+ self.attn_layers = attn_layers
2165
+
2166
+ self.init_()
2167
+
2168
+ logits_dim = default(logits_dim, num_tokens)
2169
+ self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
2170
+
2171
+ # memory tokens (like [cls]) from Memory Transformers paper
2172
+
2173
+ num_memory_tokens = default(num_memory_tokens, 0)
2174
+ self.num_memory_tokens = num_memory_tokens
2175
+ if num_memory_tokens > 0:
2176
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
2177
+
2178
+ self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
2179
+
2180
+ # whether can do cached kv decoding
2181
+
2182
+ self.can_cache_kv = self.num_memory_tokens == 0
2183
+
2184
+ def init_(self):
2185
+ if self.l2norm_embed:
2186
+ nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
2187
+ if not isinstance(self.pos_emb, always):
2188
+ nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
2189
+ return
2190
+
2191
+ nn.init.kaiming_normal_(self.token_emb.emb.weight)
2192
+
2193
+ def forward(
2194
+ self,
2195
+ x,
2196
+ return_embeddings = False,
2197
+ return_logits_and_embeddings = False,
2198
+ return_intermediates = False,
2199
+ mask = None,
2200
+ return_mems = False,
2201
+ return_attn = False,
2202
+ mems = None,
2203
+ pos = None,
2204
+ prepend_embeds = None,
2205
+ sum_embeds = None,
2206
+ return_attn_z_loss = False,
2207
+ attn_z_loss_weight = 1e-4,
2208
+ seq_start_pos = None,
2209
+ cache: Optional[LayerIntermediates] = None,
2210
+ **kwargs
2211
+ ):
2212
+ b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
2213
+ return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
2214
+
2215
+ # absolute positional embedding
2216
+
2217
+ external_pos_emb = exists(pos) and pos.dtype != torch.long
2218
+ pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
2219
+ x = self.token_emb(x) + pos_emb
2220
+
2221
+ # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
2222
+
2223
+ if exists(sum_embeds):
2224
+ x = x + sum_embeds
2225
+
2226
+ # post embedding norm, purportedly leads to greater stabilization
2227
+
2228
+ x = self.post_emb_norm(x)
2229
+
2230
+ # whether to append embeds, as in PaLI, for image embeddings
2231
+
2232
+ if exists(prepend_embeds):
2233
+ prepend_seq, prepend_dim = prepend_embeds.shape[1:]
2234
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
2235
+
2236
+ x = torch.cat((prepend_embeds, x), dim = -2)
2237
+
2238
+ # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
2239
+
2240
+ if emb_frac_gradient < 1:
2241
+ assert emb_frac_gradient > 0
2242
+ x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
2243
+
2244
+ # embedding dropout
2245
+
2246
+ x = self.emb_dropout(x)
2247
+
2248
+ x = self.project_emb(x)
2249
+
2250
+ if has_memory_tokens:
2251
+ mem_every = self.memory_tokens_interspersed_every
2252
+
2253
+ if exists(mem_every):
2254
+ assert mem_every > 0
2255
+ assert isinstance(self.attn_layers, Decoder), 'only for decoder'
2256
+ next_seq_len = math.ceil(n / mem_every) * mem_every
2257
+
2258
+ x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
2259
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)
2260
+
2261
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
2262
+ x, mem_packed_shape = pack((mem, x), 'b * d')
2263
+
2264
+ # auto-handle masking after appending memory tokens
2265
+ if not exists(mem_every) and exists(mask):
2266
+ mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
2267
+
2268
+ if exists(mem_every):
2269
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2270
+
2271
+ if self.shift_mem_down and exists(mems):
2272
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
2273
+ mems = [*mems_r, *mems_l]
2274
+
2275
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
2276
+
2277
+ if has_memory_tokens:
2278
+ if exists(mem_every):
2279
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
2280
+
2281
+ mem, x = unpack(x, mem_packed_shape, 'b * d')
2282
+
2283
+ if exists(mem_every):
2284
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2285
+
2286
+ x = x[:, :n]
2287
+
2288
+ if return_logits_and_embeddings:
2289
+ out = (self.to_logits(x), x)
2290
+ elif return_embeddings:
2291
+ out = x
2292
+ else:
2293
+ out = self.to_logits(x)
2294
+
2295
+ if return_attn_z_loss:
2296
+ pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
2297
+ intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
2298
+ return_intermediates = True
2299
+
2300
+ if return_mems:
2301
+ hiddens = intermediates.hiddens
2302
+ new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
2303
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
2304
+
2305
+ if not return_intermediates:
2306
+ return out, new_mems
2307
+
2308
+ intermediates.mems = new_mems
2309
+
2310
+ if return_intermediates:
2311
+ return out, intermediates
2312
+
2313
+ if return_attn:
2314
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2315
+ return out, attn_maps
2316
+
2317
+ return out
2318
+
2319
+ class ContinuousTransformerWrapper(nn.Module):
2320
+ def __init__(
2321
+ self,
2322
+ *,
2323
+ max_seq_len,
2324
+ attn_layers,
2325
+ dim_in = None,
2326
+ dim_out = None,
2327
+ emb_dim = None,
2328
+ max_mem_len = 0,
2329
+ post_emb_norm = False,
2330
+ emb_dropout = 0.,
2331
+ use_abs_pos_emb = True,
2332
+ scaled_sinu_pos_emb = False
2333
+ ):
2334
+ super().__init__()
2335
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2336
+
2337
+ dim = attn_layers.dim
2338
+
2339
+ self.max_seq_len = max_seq_len
2340
+
2341
+ self.max_mem_len = max_mem_len
2342
+
2343
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2344
+ self.pos_emb = always(0)
2345
+ elif scaled_sinu_pos_emb:
2346
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
2347
+ else:
2348
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
2349
+
2350
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2351
+ self.emb_dropout = nn.Dropout(emb_dropout)
2352
+
2353
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
2354
+
2355
+ self.attn_layers = attn_layers
2356
+
2357
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
2358
+
2359
+ def forward(
2360
+ self,
2361
+ x,
2362
+ return_embeddings = False,
2363
+ return_intermediates = False,
2364
+ return_mems = False,
2365
+ mask = None,
2366
+ return_attn = False,
2367
+ mems = None,
2368
+ pos = None,
2369
+ prepend_embeds = None,
2370
+ **kwargs
2371
+ ):
2372
+ x = self.project_in(x)
2373
+ x = x + self.pos_emb(x, pos = pos)
2374
+
2375
+ x = self.post_emb_norm(x)
2376
+
2377
+ # whether to append embeds, as in PaLI, for image embeddings
2378
+
2379
+ if exists(prepend_embeds):
2380
+ _, prepend_dim = prepend_embeds.shape[1:]
2381
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
2382
+
2383
+ x = torch.cat((prepend_embeds, x), dim = -2)
2384
+
2385
+ x = self.emb_dropout(x)
2386
+
2387
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
2388
+
2389
+ out = self.project_out(x) if not return_embeddings else x
2390
+
2391
+ if return_intermediates:
2392
+ return out, intermediates
2393
+
2394
+ if return_mems:
2395
+ hiddens = intermediates.hiddens
2396
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
2397
+ return out, new_mems
2398
+
2399
+ if return_attn:
2400
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2401
+ return out, attn_maps
2402
+
2403
+ return out
2404
+
2405
+ class XTransformer(nn.Module):
2406
+ def __init__(
2407
+ self,
2408
+ *,
2409
+ dim,
2410
+ tie_token_emb = False,
2411
+ ignore_index = -100,
2412
+ pad_value = 0,
2413
+ cross_attn_tokens_dropout = 0.,
2414
+ **kwargs
2415
+ ):
2416
+ super().__init__()
2417
+ enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
2418
+ dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
2419
+
2420
+ assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
2421
+ enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
2422
+ enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
2423
+ enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
2424
+ enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
2425
+ enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)
2426
+
2427
+ dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
2428
+ dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
2429
+ dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
2430
+ dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)
2431
+
2432
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # how many tokens from the encoder to dropout when cross attending from decoder - seen in a couple papers, including Perceiver AR - this will also be very effective regularization when cross attending to very long memories
2433
+
2434
+ self.encoder = TransformerWrapper(
2435
+ **enc_transformer_kwargs,
2436
+ attn_layers = Encoder(dim = dim, **enc_kwargs)
2437
+ )
2438
+
2439
+ self.decoder = TransformerWrapper(
2440
+ **dec_transformer_kwargs,
2441
+ attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
2442
+ )
2443
+
2444
+ if tie_token_emb:
2445
+ self.decoder.token_emb = self.encoder.token_emb
2446
+
2447
+ self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)
2448
+
2449
+ @torch.no_grad()
2450
+ def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
2451
+ encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
2452
+ return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)
2453
+
2454
+ def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
2455
+
2456
+ if exists(src_prepend_embeds) and exists(mask):
2457
+ mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
2458
+
2459
+ enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
2460
+
2461
+ if self.training and self.cross_attn_tokens_dropout > 0:
2462
+ enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
2463
+
2464
+ out = self.decoder(tgt, context = enc, context_mask = mask)
2465
+ return out