asigalov61 commited on
Commit
e3ed184
·
verified ·
1 Parent(s): 0c340e5

Upload 7 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2 filter=lfs diff=lfs merge=lfs -text
SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd41a4639c9e7a96413b4b22540d48e6741e24bcdabcb2eff22cd65929df3cfa
3
+ size 553961496
TMIDIX.py ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/spaces/asigalov61/Melody2Song-Seq2Seq-Music-Transformer
2
+
3
+ import os
4
+ import time as reqtime
5
+ import datetime
6
+ from pytz import timezone
7
+
8
+ import torch
9
+
10
+ import spaces
11
+ import gradio as gr
12
+
13
+ from x_transformer_1_23_2 import *
14
+ import random
15
+ import tqdm
16
+
17
+ from midi_to_colab_audio import midi_to_colab_audio
18
+ import TMIDIX
19
+
20
+ import matplotlib.pyplot as plt
21
+
22
+ in_space = os.getenv("SYSTEM") == "spaces"
23
+
24
+ # =================================================================================================
25
+
26
+ @spaces.GPU
27
+ def GenerateSong(input_melody_seed_number):
28
+ print('=' * 70)
29
+ print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
30
+ start_time = reqtime.time()
31
+
32
+ print('Loading model...')
33
+
34
+ SEQ_LEN = 2560
35
+ PAD_IDX = 514
36
+ DEVICE = 'cuda' # 'cuda'
37
+
38
+ # instantiate the model
39
+
40
+ model = TransformerWrapper(
41
+ num_tokens = PAD_IDX+1,
42
+ max_seq_len = SEQ_LEN,
43
+ attn_layers = Decoder(dim = 1024, depth = 24, heads = 16, attn_flash = True)
44
+ )
45
+
46
+ model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
47
+
48
+ model.to(DEVICE)
49
+ print('=' * 70)
50
+
51
+ print('Loading model checkpoint...')
52
+
53
+ model.load_state_dict(
54
+ torch.load('Melody2Song_Seq2Seq_Music_Transformer_Trained_Model_28482_steps_0.719_loss_0.7865_acc.pth',
55
+ map_location=DEVICE))
56
+ print('=' * 70)
57
+
58
+ model.eval()
59
+
60
+ if DEVICE == 'cpu':
61
+ dtype = torch.bfloat16
62
+ else:
63
+ dtype = torch.bfloat16
64
+
65
+ ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
66
+
67
+ print('Done!')
68
+ print('=' * 70)
69
+ seed_melody = seed_melodies_data[input_melody_seed_number]
70
+ print('Input melody seed number:', input_melody_seed_number)
71
+ print('-' * 70)
72
+
73
+ #==================================================================
74
+
75
+ print('=' * 70)
76
+
77
+ print('Sample output events', seed_melody[:16])
78
+ print('=' * 70)
79
+ print('Generating...')
80
+
81
+ x = (torch.tensor(seed_melody, dtype=torch.long, device='cuda')[None, ...])
82
+
83
+ with ctx:
84
+ out = model.generate(x,
85
+ 1536,
86
+ temperature=0.9,
87
+ return_prime=False,
88
+ verbose=False)
89
+
90
+ output = out[0].tolist()
91
+
92
+ print('=' * 70)
93
+ print('Done!')
94
+ print('=' * 70)
95
+
96
+ #===============================================================================
97
+ print('Rendering results...')
98
+
99
+ print('=' * 70)
100
+ print('Sample INTs', output[:15])
101
+ print('=' * 70)
102
+
103
+ out1 = output
104
+
105
+ if len(out1) != 0:
106
+
107
+ song = out1
108
+ song_f = []
109
+
110
+ time = 0
111
+ dur = 0
112
+ vel = 90
113
+ pitch = 0
114
+ channel = 0
115
+
116
+ patches = [0] * 16
117
+ patches[3] = 40
118
+
119
+ for ss in song:
120
+
121
+ if 0 < ss < 128:
122
+
123
+ time += (ss * 32)
124
+
125
+ if 128 < ss < 256:
126
+
127
+ dur = (ss-128) * 32
128
+
129
+ if 256 < ss < 512:
130
+
131
+ pitch = (ss-256) % 128
132
+
133
+ channel = (ss-256) // 128
134
+
135
+ if channel == 1:
136
+ channel = 3
137
+ vel = 110 + (pitch % 12)
138
+ song_f.append(['note', time, dur, channel, pitch, vel, 40])
139
+
140
+ else:
141
+ vel = 80 + (pitch % 12)
142
+ channel = 0
143
+ song_f.append(['note', time, dur, channel, pitch, vel, 0])
144
+
145
+ fn1 = "Melody2Song-Seq2Seq-Music-Transformer-Composition"
146
+
147
+ detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
148
+ output_signature = 'Melody2Song Seq2Seq Music Transformer',
149
+ output_file_name = fn1,
150
+ track_name='Project Los Angeles',
151
+ list_of_MIDI_patches=patches
152
+ )
153
+
154
+ new_fn = fn1+'.mid'
155
+
156
+
157
+ audio = midi_to_colab_audio(new_fn,
158
+ soundfont_path=soundfont,
159
+ sample_rate=16000,
160
+ volume_scale=10,
161
+ output_for_gradio=True
162
+ )
163
+
164
+ print('Done!')
165
+ print('=' * 70)
166
+
167
+ #========================================================
168
+
169
+ output_midi_title = str(fn1)
170
+ output_midi_summary = str(song_f[:3])
171
+ output_midi = str(new_fn)
172
+ output_audio = (16000, audio)
173
+
174
+ output_plot = TMIDIX.plot_ms_SONG(song_f, plot_title=output_midi, return_plt=True)
175
+
176
+ print('Output MIDI file name:', output_midi)
177
+ print('Output MIDI title:', output_midi_title)
178
+ print('Output MIDI summary:', output_midi_summary)
179
+ print('=' * 70)
180
+
181
+
182
+ #========================================================
183
+
184
+ print('-' * 70)
185
+ print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
186
+ print('-' * 70)
187
+ print('Req execution time:', (reqtime.time() - start_time), 'sec')
188
+
189
+ return output_midi_title, output_midi_summary, output_midi, output_audio, output_plot
190
+
191
+ # =================================================================================================
192
+
193
+ if __name__ == "__main__":
194
+
195
+ PDT = timezone('US/Pacific')
196
+
197
+ print('=' * 70)
198
+ print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
199
+ print('=' * 70)
200
+
201
+ soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"
202
+
203
+ print('Loading seed meldoies data...')
204
+ seed_melodies_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('Melody2Song_Seq2Seq_Music_Transformer_Seed_Melodies_Data')
205
+ print('=' * 70)
206
+
207
+ app = gr.Blocks()
208
+ with app:
209
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Melody2Song Seq2Seq Music Transformer</h1>")
210
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Generate unique songs from melodies with seq2seq music transformer</h1>")
211
+ gr.Markdown(
212
+ "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Melody2Song-Seq2Seq-Music-Transformer&style=flat)\n\n")
213
+
214
+ input_melody_seed_number = gr.Slider(0, 203664, value=0, step=1, label="Select seed melody number")
215
+
216
+ run_btn = gr.Button("generate", variant="primary")
217
+
218
+ gr.Markdown("## Generation results")
219
+
220
+ output_midi_title = gr.Textbox(label="Output MIDI title")
221
+ output_midi_summary = gr.Textbox(label="Output MIDI summary")
222
+ output_audio = gr.Audio(label="Output MIDI audio", format="wav", elem_id="midi_audio")
223
+ output_plot = gr.Plot(label="Output MIDI score plot")
224
+ output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])
225
+
226
+ run_event = run_btn.click(GenerateSong, [input_melody_seed_number],
227
+ [output_midi_title, output_midi_summary, output_midi, output_audio, output_plot])
228
+
229
+ app.queue().launch()
midi_to_colab_audio.py ADDED
The diff for this file is too large to render. See raw diff
 
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ fluidsynth
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ gradio
3
+ einops
x_transformer_1_23_2.py ADDED
@@ -0,0 +1,2464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #===================================================================================================================
2
+ #
3
+ # X Trasformer Module
4
+ #
5
+ # Partial x-transformers code With useful modifications
6
+ #
7
+ # Version 1.0
8
+ #
9
+ # Original source code courtesy of lucidrains
10
+ # https://github.com/lucidrains/x-transformers
11
+ #
12
+ # Original source code retrieved on 10/10/2023
13
+ #
14
+ # Project Los Angeles
15
+ # Tegridy Code 2023
16
+
17
+ #===================================================================================================================
18
+
19
+ # Critical dependencies
20
+ #
21
+ # !pip install torch
22
+ # !pip install einops
23
+
24
+ #===================================================================================================================
25
+
26
+ from functools import partial
27
+ from typing import Optional, Tuple
28
+
29
+ import torch
30
+ from torch import nn, einsum, Tensor
31
+ import torch.nn.functional as F
32
+ # from torch.nn.attention import SDPBackend, sdpa_kernel
33
+
34
+ from collections import namedtuple
35
+ from functools import wraps
36
+ from packaging import version
37
+ from dataclasses import dataclass
38
+
39
+ from einops import rearrange, repeat
40
+
41
+ # constants
42
+
43
+ EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
44
+
45
+ @dataclass
46
+ class Intermediates:
47
+ qk_similarities: Optional[Tensor] = None
48
+ pre_softmax_attn: Optional[Tensor] = None
49
+ post_softmax_attn: Optional[Tensor] = None
50
+ cached_kv: Optional[Tuple[Tensor, Tensor]] = None
51
+
52
+ def to_tuple(self):
53
+ return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)
54
+
55
+ # helpers
56
+
57
+ def exists(val):
58
+ return val is not None
59
+
60
+ def default(val, d):
61
+ return val if exists(val) else d
62
+
63
+ def compact(arr):
64
+ return [*filter(exists, arr)]
65
+
66
+ def once(fn):
67
+ called = False
68
+ @wraps(fn)
69
+ def inner(x):
70
+ nonlocal called
71
+ if called:
72
+ return
73
+ called = True
74
+ return fn(x)
75
+ return inner
76
+
77
+ print_once = once(print)
78
+
79
+ # functions for creating causal mask
80
+ # need a special one for onnx cpu (no support for .triu)
81
+
82
+ def create_causal_mask(i, j, device):
83
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
84
+
85
+ def onnx_create_causal_mask(i, j, device):
86
+ r = torch.arange(i, device = device)
87
+ causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
88
+ causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
89
+ return causal_mask
90
+
91
+ # main class
92
+
93
+ class Attend(nn.Module):
94
+ def __init__(
95
+ self,
96
+ *,
97
+ dropout = 0.,
98
+ causal = False,
99
+ heads = None,
100
+ talking_heads = False,
101
+ sparse_topk = None,
102
+ scale = None,
103
+ qk_norm = False,
104
+ flash = False,
105
+ add_zero_kv = False,
106
+ onnxable = False
107
+ ):
108
+ super().__init__()
109
+ self.scale = scale
110
+ self.qk_norm = qk_norm
111
+
112
+ self.causal = causal
113
+ self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask
114
+
115
+ self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
116
+
117
+ self.dropout = dropout
118
+ self.attn_dropout = nn.Dropout(dropout)
119
+
120
+ # talking heads
121
+
122
+ assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
123
+
124
+ self.talking_heads = talking_heads
125
+ if talking_heads:
126
+ self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
127
+ self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
128
+
129
+ # sparse topk
130
+
131
+ assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
132
+ self.sparse_topk = sparse_topk
133
+
134
+ # add a key / value token composed of zeros
135
+ # in case this helps controlling outliers, proposed by https://www.evanmiller.org/attention-is-off-by-one.html
136
+
137
+ self.add_zero_kv = add_zero_kv
138
+
139
+ # flash attention
140
+
141
+ self.flash = flash
142
+ assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
143
+
144
+ # determine efficient attention configs for cuda and cpu
145
+
146
+ self.cpu_config = EfficientAttentionConfig(True, True, True)
147
+ self.cuda_config = None
148
+
149
+ if not torch.cuda.is_available() or not flash:
150
+ return
151
+
152
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
153
+
154
+ major, minor = device_properties.major, device_properties.minor
155
+
156
+ if (major, minor) == (8, 0):
157
+ print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
158
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
159
+ elif (major, minor) == (9, 0):
160
+ print_once('H100 GPU detected, using flash attention')
161
+ self.cuda_config = EfficientAttentionConfig(True, False, False)
162
+ else:
163
+ print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
164
+ self.cuda_config = EfficientAttentionConfig(False, True, True)
165
+
166
+ def flash_attn(
167
+ self,
168
+ q, k, v,
169
+ mask = None,
170
+ attn_bias = None
171
+ ):
172
+ batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
173
+
174
+ # Recommended for multi-query single-key-value attention by Tri Dao
175
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
176
+
177
+ if k.ndim == 3:
178
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
179
+
180
+ if v.ndim == 3:
181
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
182
+
183
+ # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
184
+
185
+ if self.qk_norm:
186
+ default_scale = q.shape[-1] ** -0.5
187
+ q = q * (self.scale / default_scale)
188
+
189
+ # Check if mask exists and expand to compatible shape
190
+ # The mask is B L, so it would have to be expanded to B H N L
191
+
192
+ causal = self.causal
193
+
194
+ # in the case of kv caching with one token (q_len == 1), just turn off causal masking
195
+ # in speculative decoding, this may go up to 5-6, so right aligned causal mask will be needed there
196
+
197
+ if q_len == 1 and causal:
198
+ causal = False
199
+
200
+ # expand key padding mask
201
+
202
+ if exists(mask):
203
+ assert mask.ndim == 4
204
+ mask = mask.expand(batch, heads, q_len, k_len)
205
+
206
+ # handle kv cache - this should be bypassable in updated flash attention 2
207
+
208
+ if k_len > q_len and causal:
209
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
210
+ if not exists(mask):
211
+ mask = ~causal_mask
212
+ else:
213
+ mask = mask & ~causal_mask
214
+ causal = False
215
+
216
+ # manually handle causal mask, if another mask was given
217
+
218
+ row_is_entirely_masked = None
219
+
220
+ if exists(mask) and causal:
221
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
222
+ mask = mask & ~causal_mask
223
+
224
+ # protect against an entire row being masked out
225
+
226
+ row_is_entirely_masked = ~mask.any(dim = -1)
227
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
228
+
229
+ causal = False
230
+
231
+ # handle alibi positional bias
232
+ # convert from bool to float
233
+
234
+ if exists(attn_bias):
235
+ attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
236
+
237
+ # if mask given, the mask would already contain the causal mask from above logic
238
+ # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
239
+
240
+ mask_value = -torch.finfo(q.dtype).max
241
+
242
+ if exists(mask):
243
+ attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
244
+ elif causal:
245
+ causal_mask = self.create_causal_mask(q_len, k_len, device = device)
246
+ attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
247
+ causal = False
248
+
249
+ # scaled_dot_product_attention handles attn_mask either as bool or additive bias
250
+ # make it an additive bias here
251
+
252
+ mask = attn_bias
253
+
254
+ # Check if there is a compatible device for flash attention
255
+
256
+ config = self.cuda_config if is_cuda else self.cpu_config
257
+
258
+ # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
259
+
260
+ # Legacy code...
261
+ with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=True):
262
+
263
+ # New SDP kernel code...
264
+ # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
265
+
266
+ out = F.scaled_dot_product_attention(
267
+ q, k, v,
268
+ attn_mask = mask,
269
+ dropout_p = self.dropout if self.training else 0.,
270
+ is_causal = causal
271
+ )
272
+
273
+ # for a row that is entirely masked out, should zero out the output of that row token
274
+
275
+ if exists(row_is_entirely_masked):
276
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
277
+
278
+ return out, Intermediates()
279
+
280
+ def forward(
281
+ self,
282
+ q, k, v,
283
+ mask = None,
284
+ attn_bias = None,
285
+ prev_attn = None
286
+ ):
287
+ """
288
+ einstein notation
289
+ b - batch
290
+ h - heads
291
+ n, i, j - sequence length (base sequence length, source, target)
292
+ d - feature dimension
293
+ """
294
+
295
+ n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device
296
+
297
+ scale = default(self.scale, q.shape[-1] ** -0.5)
298
+
299
+ causal = self.causal
300
+
301
+ # handle kv cached decoding
302
+
303
+ if n == 1 and causal:
304
+ causal = False
305
+
306
+ # handle grouped multi-query attention
307
+
308
+ if kv_heads == 1:
309
+ k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
310
+ elif kv_heads < heads:
311
+ k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))
312
+
313
+ # handle zero kv, as means for allowing network to attend to nothing
314
+
315
+ if self.add_zero_kv:
316
+ k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))
317
+
318
+ if exists(mask):
319
+ mask = F.pad(mask, (1, 0), value = True)
320
+
321
+ if exists(attn_bias):
322
+ attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
323
+
324
+ if self.flash:
325
+ assert not exists(prev_attn), 'residual attention not compatible with flash attention'
326
+ return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
327
+
328
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
329
+
330
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
331
+
332
+ if exists(prev_attn):
333
+ dots = dots + prev_attn
334
+
335
+ qk_similarities = dots.clone()
336
+
337
+ if self.talking_heads:
338
+ dots = self.pre_softmax_talking_heads(dots)
339
+
340
+ if exists(attn_bias):
341
+ dots = dots + attn_bias
342
+
343
+ i, j, dtype = *dots.shape[-2:], dots.dtype
344
+
345
+ mask_value = -torch.finfo(dots.dtype).max
346
+
347
+ if exists(self.sparse_topk) and self.sparse_topk < j:
348
+ top_values, _ = dots.topk(self.sparse_topk, dim = -1)
349
+ sparse_topk_mask = dots < top_values[..., -1:]
350
+ mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask
351
+
352
+ if exists(mask):
353
+ dots = dots.masked_fill(~mask, mask_value)
354
+
355
+ if causal:
356
+ causal_mask = self.create_causal_mask(i, j, device = device)
357
+ dots = dots.masked_fill(causal_mask, mask_value)
358
+
359
+ pre_softmax_attn = dots.clone()
360
+
361
+ attn = self.attn_fn(dots, dim = -1)
362
+ attn = attn.type(dtype)
363
+
364
+ post_softmax_attn = attn.clone()
365
+
366
+ attn = self.attn_dropout(attn)
367
+
368
+ if self.talking_heads:
369
+ attn = self.post_softmax_talking_heads(attn)
370
+
371
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
372
+
373
+ intermediates = Intermediates(
374
+ qk_similarities = qk_similarities,
375
+ pre_softmax_attn = pre_softmax_attn,
376
+ post_softmax_attn = post_softmax_attn
377
+ )
378
+
379
+ return out, intermediates
380
+
381
+ #===================================================================================================================
382
+
383
+ from math import ceil, log
384
+ from typing import Optional, Union, Tuple, Callable
385
+
386
+ import torch
387
+ from torch import nn, Tensor
388
+ from torch.nn import Module
389
+ import torch.nn.functional as F
390
+
391
+ from einops import rearrange, pack, unpack
392
+
393
+ def exists(val):
394
+ return val is not None
395
+
396
+ def default(val, d):
397
+ return val if exists(val) else d
398
+
399
+ def identity(t, *args, **kwargs):
400
+ return t
401
+
402
+ def cast_tuple(t, length = 1):
403
+ return t if isinstance(t, tuple) else (t,) * length
404
+
405
+ def eval_decorator(fn):
406
+ def inner(self, *args, **kwargs):
407
+ was_training = self.training
408
+ self.eval()
409
+ out = fn(self, *args, **kwargs)
410
+ self.train(was_training)
411
+ return out
412
+ return inner
413
+
414
+ # for variable lengthed prefixes
415
+
416
+ def align_right(t, lens, pad_id = 0):
417
+ batch, seq_len, device, dtype = *t.shape, t.device, t.dtype
418
+
419
+ assert lens.ndim == 1 and lens.shape[0] == batch
420
+ assert lens.amax() <= seq_len
421
+
422
+ pad_lens = seq_len - lens
423
+ max_pad_len = pad_lens.amax()
424
+
425
+ batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
426
+ prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
427
+
428
+ t = F.pad(t, (max_pad_len, 0), value = 0)
429
+ offset = max_pad_len - pad_lens
430
+
431
+ aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
432
+ return aligned
433
+
434
+ # nucleus
435
+
436
+ def top_p(logits, thres = 0.9):
437
+ sorted_logits, sorted_indices = torch.sort(logits, descending = True)
438
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim = -1), dim = -1)
439
+
440
+ sorted_indices_to_remove = cum_probs > thres
441
+ sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, -1), value = False)
442
+
443
+ sorted_logits[sorted_indices_to_remove] = float('-inf')
444
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
445
+
446
+ # topk
447
+
448
+ def top_k(logits, frac_num_tokens = 0.1, k = None):
449
+ num_tokens = logits.shape[-1]
450
+
451
+ k = default(k, ceil(frac_num_tokens * num_tokens))
452
+ k = min(k, num_tokens)
453
+
454
+ val, ind = torch.topk(logits, k)
455
+ probs = torch.full_like(logits, float('-inf'))
456
+ probs.scatter_(1, ind, val)
457
+ return probs
458
+
459
+ # top_a
460
+
461
+ def top_a(logits, min_p_pow = 2.0, min_p_ratio = 0.02):
462
+ probs = F.softmax(logits, dim = -1)
463
+ max_probs = torch.amax(probs, dim = -1, keepdim = True)
464
+ limit = torch.pow(max_probs, min_p_pow) * min_p_ratio
465
+ return torch.where(probs < limit, float('-inf'), logits)
466
+
467
+ # contrastive decoding function
468
+
469
+ def contrastive_decode_fn(
470
+ expert_logits,
471
+ amateur_logits,
472
+ alpha = 0.1,
473
+ beta = 0.5
474
+ ):
475
+ """
476
+ Appendix A Algorithm 2
477
+ https://arxiv.org/abs/2309.09117
478
+ """
479
+
480
+ cutoff = log(alpha) + expert_logits.amax(dim = -1, keepdim = True)
481
+ diffs = (1 + beta) * expert_logits - beta * amateur_logits
482
+ contrastive_decode_logits = diffs.masked_fill(expert_logits < cutoff, -torch.finfo(expert_logits.dtype).max)
483
+ return contrastive_decode_logits
484
+
485
+ # autoregressive wrapper class
486
+
487
+ class AutoregressiveWrapper(Module):
488
+ def __init__(
489
+ self,
490
+ net,
491
+ ignore_index = -100,
492
+ pad_value = 0,
493
+ mask_prob = 0.,
494
+ add_attn_z_loss = False
495
+ ):
496
+ super().__init__()
497
+ self.pad_value = pad_value
498
+ self.ignore_index = ignore_index
499
+
500
+ self.net = net
501
+ self.max_seq_len = net.max_seq_len
502
+
503
+ # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
504
+ assert mask_prob < 1.
505
+ self.mask_prob = mask_prob
506
+
507
+ # whether to add router z-loss
508
+ self.add_attn_z_loss = add_attn_z_loss
509
+
510
+ @torch.no_grad()
511
+ @eval_decorator
512
+ def generate(
513
+ self,
514
+ prompts,
515
+ seq_len,
516
+ eos_token = None,
517
+ temperature = 1.,
518
+ prompt_lens: Optional[Tensor] = None,
519
+ filter_logits_fn: Callable = top_k,
520
+ restrict_to_max_seq_len = True,
521
+ amateur_model: Optional[Union[Module, Tuple[Module]]] = None,
522
+ filter_kwargs: dict = dict(),
523
+ contrastive_decode_kwargs: Union[dict, Tuple[dict]] = dict(
524
+ beta = 0.5,
525
+ alpha = 0.1
526
+ ),
527
+ cache_kv = True,
528
+ verbose=True,
529
+ return_prime=False,
530
+ **kwargs
531
+ ):
532
+ max_seq_len, device = self.max_seq_len, prompts.device
533
+
534
+ prompts, ps = pack([prompts], '* n')
535
+
536
+ b, t = prompts.shape
537
+
538
+ # handle variable lengthed prompts (prefixes)
539
+
540
+ seq_start_pos = None
541
+ if exists(prompt_lens):
542
+ prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
543
+ seq_start_pos = t - prompt_lens
544
+
545
+ # output from which sampled tokens appended to
546
+
547
+ out = prompts
548
+
549
+ if verbose:
550
+ print("Generating sequence of max length:", seq_len)
551
+
552
+ # kv caches
553
+
554
+ cache = None
555
+
556
+ # if doing contrastive decoding, turn off filter automatically
557
+
558
+ if exists(amateur_model):
559
+ amateur_model = cast_tuple(amateur_model)
560
+ contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
561
+
562
+ assert len(amateur_model) == len(contrastive_decode_kwargs)
563
+
564
+ amateur_caches = [None] * len(amateur_model)
565
+ filter_logits_fn = identity
566
+
567
+ for i, module in enumerate(amateur_model):
568
+ if isinstance(module, AutoregressiveWrapper):
569
+ amateur_model[i] = module.net
570
+
571
+ module.eval()
572
+
573
+ # sampling up to seq_len
574
+
575
+ for sl in range(seq_len):
576
+
577
+ if restrict_to_max_seq_len:
578
+ x = out[:, -max_seq_len:]
579
+
580
+ if exists(cache):
581
+ for inter in cache.attn_intermediates:
582
+ inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
583
+
584
+ logits, new_cache = self.net(
585
+ x,
586
+ return_intermediates = True,
587
+ cache = cache,
588
+ seq_start_pos = seq_start_pos,
589
+ **kwargs
590
+ )
591
+
592
+ if cache_kv and self.net.can_cache_kv:
593
+ cache = new_cache
594
+
595
+ logits = logits[:, -1]
596
+
597
+ # handle contrastive decoding, Li et al.
598
+ # https://arxiv.org/abs/2210.15097
599
+
600
+ if exists(amateur_model):
601
+ for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
602
+ amateur_logits, next_amateur_cache = amateur(
603
+ x,
604
+ return_intermediates = True,
605
+ cache = amateur_cache,
606
+ seq_start_pos = seq_start_pos,
607
+ **kwargs
608
+ )
609
+
610
+ amateur_logits = amateur_logits[:, -1]
611
+
612
+ assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
613
+ logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
614
+
615
+ if cache_kv and amateur.can_cache_kv:
616
+ amateur_caches[i] = next_amateur_cache
617
+
618
+ # filter by top_k, top_p (nucleus), top_a, or custom
619
+
620
+ filtered_logits = filter_logits_fn(logits, **filter_kwargs)
621
+
622
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
623
+
624
+ sample = torch.multinomial(probs, 1)
625
+
626
+ out = torch.cat((out, sample), dim=-1)
627
+
628
+ if verbose:
629
+ if sl % 32 == 0:
630
+ print(sl, '/', seq_len)
631
+
632
+ if exists(eos_token):
633
+ is_eos_tokens = (out == eos_token)
634
+
635
+ if is_eos_tokens.any(dim = -1).all():
636
+ # mask out everything after the eos tokens
637
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
638
+ mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
639
+ out = out.masked_fill(mask, self.pad_value)
640
+
641
+ if verbose:
642
+ print('Model called the end of sequence at:', sl, '/', seq_len)
643
+
644
+ break
645
+
646
+ if return_prime:
647
+ return out[:, :]
648
+
649
+ else:
650
+ return out[:, t:]
651
+
652
+ # out, = unpack(out, ps, '* n')
653
+
654
+ # return out
655
+
656
+ def compute_accuracy(self, logits, labels):
657
+ out = torch.argmax(logits, dim=-1)
658
+ out = out.flatten()
659
+ labels = labels.flatten()
660
+
661
+ mask = (labels != self.ignore_index) # can also be self.pad_value (your choice)
662
+ out = out[mask]
663
+ labels = labels[mask]
664
+
665
+ num_right = (out == labels)
666
+ num_right = torch.sum(num_right).type(torch.float32)
667
+
668
+ acc = num_right / len(labels)
669
+ return acc
670
+
671
+ def forward(self, x, **kwargs):
672
+ seq, ignore_index, add_attn_z_loss = x.shape[1], self.ignore_index, self.add_attn_z_loss
673
+
674
+ inp, target = x[:, :-1], x[:, 1:]
675
+ inp = torch.where(inp == ignore_index, self.pad_value, inp)
676
+
677
+ if self.mask_prob > 0.:
678
+ rand = torch.randn(inp.shape, device = x.device)
679
+ rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
680
+ num_mask = min(int(seq * self.mask_prob), seq - 1)
681
+ indices = rand.topk(num_mask, dim = -1).indices
682
+ mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
683
+ kwargs.update(self_attn_kv_mask = mask)
684
+
685
+ logits, cache = self.net(
686
+ inp,
687
+ return_intermediates = True,
688
+ return_attn_z_loss = add_attn_z_loss,
689
+ **kwargs
690
+ )
691
+
692
+ acc = self.compute_accuracy(logits, target)
693
+
694
+ loss = F.cross_entropy(
695
+ rearrange(logits, 'b n c -> b c n'),
696
+ target,
697
+ ignore_index = ignore_index
698
+ )
699
+
700
+ if add_attn_z_loss:
701
+ loss = loss + cache.attn_z_loss
702
+
703
+ return loss, acc
704
+
705
+ #===============================================================================
706
+
707
+ import math
708
+ from random import random
709
+
710
+ import torch
711
+ from torch import nn, einsum, Tensor
712
+ import torch.nn.functional as F
713
+
714
+ from functools import partial, wraps
715
+ from inspect import isfunction
716
+ from collections import namedtuple
717
+ from dataclasses import dataclass
718
+ from typing import List, Callable, Optional
719
+
720
+ from einops import rearrange, repeat, reduce, pack, unpack
721
+ from einops.layers.torch import Rearrange
722
+
723
+ # constants
724
+
725
+ DEFAULT_DIM_HEAD = 64
726
+
727
+ @dataclass
728
+ class LayerIntermediates:
729
+ hiddens: Optional[List[Tensor]] = None
730
+ attn_intermediates: Optional[List[Intermediates]] = None
731
+ layer_hiddens: Optional[List[Tensor]] = None
732
+ attn_z_loss: Optional[Tensor] = None
733
+ mems: Optional[Tensor] = None
734
+
735
+ # helpers
736
+
737
+ def exists(val):
738
+ return val is not None
739
+
740
+ def default(val, d):
741
+ if exists(val):
742
+ return val
743
+ return d() if isfunction(d) else d
744
+
745
+ def cast_tuple(val, depth):
746
+ return val if isinstance(val, tuple) else (val,) * depth
747
+
748
+ def divisible_by(num, den):
749
+ return (num % den) == 0
750
+
751
+ def maybe(fn):
752
+ @wraps(fn)
753
+ def inner(x, *args, **kwargs):
754
+ if not exists(x):
755
+ return x
756
+ return fn(x, *args, **kwargs)
757
+ return inner
758
+
759
+ class always():
760
+ def __init__(self, val):
761
+ self.val = val
762
+ def __call__(self, *args, **kwargs):
763
+ return self.val
764
+
765
+ class not_equals():
766
+ def __init__(self, val):
767
+ self.val = val
768
+ def __call__(self, x, *args, **kwargs):
769
+ return x != self.val
770
+
771
+ class equals():
772
+ def __init__(self, val):
773
+ self.val = val
774
+ def __call__(self, x, *args, **kwargs):
775
+ return x == self.val
776
+
777
+ def Sequential(*modules):
778
+ return nn.Sequential(*filter(exists, modules))
779
+
780
+ # tensor helpers
781
+
782
+ def max_neg_value(tensor):
783
+ return -torch.finfo(tensor.dtype).max
784
+
785
+ def l2norm(t, groups = 1):
786
+ t = rearrange(t, '... (g d) -> ... g d', g = groups)
787
+ t = F.normalize(t, p = 2, dim = -1)
788
+ return rearrange(t, '... g d -> ... (g d)')
789
+
790
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
791
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
792
+ zeros = ((0, 0) * dims_from_right)
793
+ return F.pad(t, (*zeros, *pad), value = value)
794
+
795
+ def or_reduce(masks):
796
+ head, *body = masks
797
+ for rest in body:
798
+ head = head | rest
799
+ return head
800
+
801
+ # auxiliary loss helpers
802
+
803
+ def calc_z_loss(
804
+ pre_softmax_attns: List[Tensor],
805
+ mask = None,
806
+ weight = 1.
807
+ ):
808
+ # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906
809
+ # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects
810
+ # also used in PaLM as one of the measures
811
+
812
+ lse = 0.
813
+
814
+ for attn in pre_softmax_attns:
815
+ lse = lse + attn.logsumexp(dim = -1)
816
+
817
+ loss = torch.square(lse)
818
+ loss = reduce(loss, 'b h n -> b n', 'sum')
819
+
820
+ if not exists(mask):
821
+ return loss.mean() * weight
822
+
823
+ loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5)
824
+ return loss * weight
825
+
826
+ # init helpers
827
+
828
+ def init_zero_(layer):
829
+ nn.init.constant_(layer.weight, 0.)
830
+ if exists(layer.bias):
831
+ nn.init.constant_(layer.bias, 0.)
832
+
833
+ # keyword argument helpers
834
+
835
+ def pick_and_pop(keys, d):
836
+ values = list(map(lambda key: d.pop(key), keys))
837
+ return dict(zip(keys, values))
838
+
839
+ def group_dict_by_key(cond, d):
840
+ return_val = [dict(),dict()]
841
+ for key in d.keys():
842
+ match = bool(cond(key))
843
+ ind = int(not match)
844
+ return_val[ind][key] = d[key]
845
+ return (*return_val,)
846
+
847
+ def string_begins_with(prefix, str):
848
+ return str.startswith(prefix)
849
+
850
+ def group_by_key_prefix(prefix, d):
851
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
852
+
853
+ def groupby_prefix_and_trim(prefix, d):
854
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
855
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
856
+ return kwargs_without_prefix, kwargs
857
+
858
+ # structured dropout, more effective than traditional attention dropouts
859
+
860
+ def dropout_seq(seq, mask, dropout):
861
+ b, n, *_, device = *seq.shape, seq.device
862
+ logits = torch.randn(b, n, device = device)
863
+
864
+ if exists(mask):
865
+ mask_value = max_neg_value(logits)
866
+ logits = logits.masked_fill(~mask, mask_value)
867
+
868
+ keep_prob = 1. - dropout
869
+ num_keep = max(1, int(keep_prob * n))
870
+ keep_indices = logits.topk(num_keep, dim = 1).indices
871
+
872
+ batch_indices = torch.arange(b, device = device)
873
+ batch_indices = rearrange(batch_indices, 'b -> b 1')
874
+
875
+ seq = seq[batch_indices, keep_indices]
876
+
877
+ if exists(mask):
878
+ seq_counts = mask.sum(dim = -1)
879
+ seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
880
+ keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
881
+
882
+ mask = mask[batch_indices, keep_indices] & keep_mask
883
+
884
+ return seq, mask
885
+
886
+ # activations
887
+
888
+ class ReluSquared(nn.Module):
889
+ def forward(self, x):
890
+ return F.relu(x) ** 2
891
+
892
+ # embedding
893
+
894
+ class TokenEmbedding(nn.Module):
895
+ def __init__(self, dim, num_tokens, l2norm_embed = False):
896
+ super().__init__()
897
+ self.l2norm_embed = l2norm_embed
898
+ self.emb = nn.Embedding(num_tokens, dim)
899
+
900
+ def forward(self, x):
901
+ token_emb = self.emb(x)
902
+ return l2norm(token_emb) if self.l2norm_embed else token_emb
903
+
904
+ # positional embeddings
905
+
906
+ class AbsolutePositionalEmbedding(nn.Module):
907
+ def __init__(self, dim, max_seq_len, l2norm_embed = False):
908
+ super().__init__()
909
+ self.scale = dim ** -0.5 if not l2norm_embed else 1.
910
+ self.max_seq_len = max_seq_len
911
+ self.l2norm_embed = l2norm_embed
912
+ self.emb = nn.Embedding(max_seq_len, dim)
913
+
914
+ def forward(self, x, pos = None, seq_start_pos = None):
915
+ seq_len, device = x.shape[1], x.device
916
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
917
+
918
+ if not exists(pos):
919
+ pos = torch.arange(seq_len, device = device)
920
+
921
+ if exists(seq_start_pos):
922
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
923
+
924
+ pos_emb = self.emb(pos)
925
+ pos_emb = pos_emb * self.scale
926
+ return l2norm(pos_emb) if self.l2norm_embed else pos_emb
927
+
928
+ class ScaledSinusoidalEmbedding(nn.Module):
929
+ def __init__(self, dim, theta = 10000):
930
+ super().__init__()
931
+ assert divisible_by(dim, 2)
932
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
933
+
934
+ half_dim = dim // 2
935
+ freq_seq = torch.arange(half_dim).float() / half_dim
936
+ inv_freq = theta ** -freq_seq
937
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
938
+
939
+ def forward(self, x, pos = None, seq_start_pos = None):
940
+ seq_len, device = x.shape[1], x.device
941
+
942
+ if not exists(pos):
943
+ pos = torch.arange(seq_len, device = device)
944
+
945
+ if exists(seq_start_pos):
946
+ pos = pos - seq_start_pos[..., None]
947
+
948
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
949
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
950
+ return emb * self.scale
951
+
952
+ class RelativePositionBias(nn.Module):
953
+ def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
954
+ super().__init__()
955
+ self.scale = scale
956
+ self.causal = causal
957
+ self.num_buckets = num_buckets
958
+ self.max_distance = max_distance
959
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
960
+
961
+ @staticmethod
962
+ def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
963
+ ret = 0
964
+ n = -relative_position
965
+ if not causal:
966
+ num_buckets //= 2
967
+ ret += (n < 0).long() * num_buckets
968
+ n = torch.abs(n)
969
+ else:
970
+ n = torch.max(n, torch.zeros_like(n))
971
+
972
+ max_exact = num_buckets // 2
973
+ is_small = n < max_exact
974
+
975
+ val_if_large = max_exact + (
976
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
977
+ ).long()
978
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
979
+
980
+ ret += torch.where(is_small, n, val_if_large)
981
+ return ret
982
+
983
+ @property
984
+ def device(self):
985
+ return next(self.parameters()).device
986
+
987
+ def forward(self, i, j):
988
+ device = self.device
989
+ q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
990
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
991
+ rel_pos = k_pos[None, :] - q_pos[:, None]
992
+ rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
993
+ values = self.relative_attention_bias(rp_bucket)
994
+ bias = rearrange(values, 'i j h -> h i j')
995
+ return bias * self.scale
996
+
997
+ class DynamicPositionBias(nn.Module):
998
+ def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
999
+ super().__init__()
1000
+ assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
1001
+ self.log_distance = log_distance
1002
+
1003
+ self.mlp = nn.ModuleList([])
1004
+
1005
+ self.mlp.append(Sequential(
1006
+ nn.Linear(1, dim),
1007
+ nn.LayerNorm(dim) if norm else None,
1008
+ nn.SiLU()
1009
+ ))
1010
+
1011
+ for _ in range(depth - 1):
1012
+ self.mlp.append(Sequential(
1013
+ nn.Linear(dim, dim),
1014
+ nn.LayerNorm(dim) if norm else None,
1015
+ nn.SiLU()
1016
+ ))
1017
+
1018
+ self.mlp.append(nn.Linear(dim, heads))
1019
+
1020
+ @property
1021
+ def device(self):
1022
+ return next(self.parameters()).device
1023
+
1024
+ def forward(self, i, j):
1025
+ assert i == j
1026
+ n, device = j, self.device
1027
+
1028
+ # get the (n x n) matrix of distances
1029
+ seq_arange = torch.arange(n, device = device)
1030
+ context_arange = torch.arange(n, device = device)
1031
+ indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
1032
+ indices += (n - 1)
1033
+
1034
+ # input to continuous positions MLP
1035
+ pos = torch.arange(-n + 1, n, device = device).float()
1036
+ pos = rearrange(pos, '... -> ... 1')
1037
+
1038
+ if self.log_distance:
1039
+ pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
1040
+
1041
+ for layer in self.mlp:
1042
+ pos = layer(pos)
1043
+
1044
+ # get position biases
1045
+ bias = pos[indices]
1046
+ bias = rearrange(bias, 'i j h -> h i j')
1047
+ return bias
1048
+
1049
+ class AlibiPositionalBias(nn.Module):
1050
+ def __init__(self, heads, total_heads, **kwargs):
1051
+ super().__init__()
1052
+ self.heads = heads
1053
+ self.total_heads = total_heads
1054
+
1055
+ slopes = Tensor(self._get_slopes(heads))
1056
+ slopes = rearrange(slopes, 'h -> h 1 1')
1057
+ self.register_buffer('slopes', slopes, persistent = False)
1058
+ self.register_buffer('bias', None, persistent = False)
1059
+
1060
+ def get_bias(self, i, j, device):
1061
+ i_arange = torch.arange(j - i, j, device = device)
1062
+ j_arange = torch.arange(j, device = device)
1063
+ bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
1064
+ return bias
1065
+
1066
+ @staticmethod
1067
+ def _get_slopes(heads):
1068
+ def get_slopes_power_of_2(n):
1069
+ start = (2**(-2**-(math.log2(n)-3)))
1070
+ ratio = start
1071
+ return [start*ratio**i for i in range(n)]
1072
+
1073
+ if math.log2(heads).is_integer():
1074
+ return get_slopes_power_of_2(heads)
1075
+
1076
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
1077
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
1078
+
1079
+ @property
1080
+ def device(self):
1081
+ return next(self.buffers()).device
1082
+
1083
+ def forward(self, i, j):
1084
+ h, device = self.total_heads, self.device
1085
+
1086
+ if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
1087
+ return self.bias[..., -i:, -j:]
1088
+
1089
+ bias = self.get_bias(i, j, device)
1090
+ bias = bias * self.slopes
1091
+
1092
+ num_heads_unalibied = h - bias.shape[0]
1093
+ bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
1094
+ self.register_buffer('bias', bias, persistent = False)
1095
+
1096
+ return self.bias
1097
+
1098
+ class RotaryEmbedding(nn.Module):
1099
+ def __init__(
1100
+ self,
1101
+ dim,
1102
+ use_xpos = False,
1103
+ scale_base = 512,
1104
+ interpolation_factor = 1.,
1105
+ base = 10000,
1106
+ base_rescale_factor = 1.
1107
+ ):
1108
+ super().__init__()
1109
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
1110
+ # has some connection to NTK literature
1111
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
1112
+ base *= base_rescale_factor ** (dim / (dim - 2))
1113
+
1114
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
1115
+ self.register_buffer('inv_freq', inv_freq)
1116
+
1117
+ assert interpolation_factor >= 1.
1118
+ self.interpolation_factor = interpolation_factor
1119
+
1120
+ if not use_xpos:
1121
+ self.register_buffer('scale', None)
1122
+ return
1123
+
1124
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
1125
+
1126
+ self.scale_base = scale_base
1127
+ self.register_buffer('scale', scale)
1128
+
1129
+ def forward(self, seq_len):
1130
+ device = self.inv_freq.device
1131
+ t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
1132
+
1133
+ t = t / self.interpolation_factor
1134
+
1135
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
1136
+ freqs = torch.cat((freqs, freqs), dim = -1)
1137
+
1138
+ if not exists(self.scale):
1139
+ return freqs, 1.
1140
+
1141
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
1142
+ scale = self.scale ** rearrange(power, 'n -> n 1')
1143
+ scale = torch.cat((scale, scale), dim = -1)
1144
+
1145
+ return freqs, scale
1146
+
1147
+
1148
+ def rotate_half(x):
1149
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
1150
+ x1, x2 = x.unbind(dim = -2)
1151
+ return torch.cat((-x2, x1), dim = -1)
1152
+
1153
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
1154
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
1155
+ freqs = freqs[-seq_len:, :]
1156
+
1157
+ if t.ndim == 4 and freqs.ndim == 3:
1158
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
1159
+
1160
+ # partial rotary embeddings, Wang et al. GPT-J
1161
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
1162
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
1163
+ return torch.cat((t, t_unrotated), dim = -1)
1164
+
1165
+ # norms
1166
+
1167
+ class Scale(nn.Module):
1168
+ def __init__(self, value, fn):
1169
+ super().__init__()
1170
+ self.value = value
1171
+ self.fn = fn
1172
+
1173
+ def forward(self, x, **kwargs):
1174
+ out = self.fn(x, **kwargs)
1175
+ scale_fn = lambda t: t * self.value
1176
+
1177
+ if not isinstance(out, tuple):
1178
+ return scale_fn(out)
1179
+
1180
+ return (scale_fn(out[0]), *out[1:])
1181
+
1182
+ class ScaleNorm(nn.Module):
1183
+ def __init__(self, dim, eps = 1e-5):
1184
+ super().__init__()
1185
+ self.eps = eps
1186
+ self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
1187
+
1188
+ def forward(self, x):
1189
+ norm = torch.norm(x, dim = -1, keepdim = True)
1190
+ return x / norm.clamp(min = self.eps) * self.g
1191
+
1192
+ class RMSNorm(nn.Module):
1193
+ def __init__(self, dim):
1194
+ super().__init__()
1195
+ self.scale = dim ** 0.5
1196
+ self.g = nn.Parameter(torch.ones(dim))
1197
+
1198
+ def forward(self, x):
1199
+ return F.normalize(x, dim = -1) * self.scale * self.g
1200
+
1201
+ class SimpleRMSNorm(nn.Module):
1202
+ def __init__(self, dim):
1203
+ super().__init__()
1204
+ self.scale = dim ** 0.5
1205
+
1206
+ def forward(self, x):
1207
+ return F.normalize(x, dim = -1) * self.scale
1208
+
1209
+ # residual and residual gates
1210
+
1211
+ class Residual(nn.Module):
1212
+ def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
1213
+ super().__init__()
1214
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1215
+ self.scale_residual_constant = scale_residual_constant
1216
+
1217
+ def forward(self, x, residual):
1218
+ if exists(self.residual_scale):
1219
+ residual = residual * self.residual_scale
1220
+
1221
+ if self.scale_residual_constant != 1:
1222
+ residual = residual * self.scale_residual_constant
1223
+
1224
+ return x + residual
1225
+
1226
+ class GRUGating(nn.Module):
1227
+ def __init__(self, dim, scale_residual = False, **kwargs):
1228
+ super().__init__()
1229
+ self.gru = nn.GRUCell(dim, dim)
1230
+ self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
1231
+
1232
+ def forward(self, x, residual):
1233
+ if exists(self.residual_scale):
1234
+ residual = residual * self.residual_scale
1235
+
1236
+ gated_output = self.gru(
1237
+ rearrange(x, 'b n d -> (b n) d'),
1238
+ rearrange(residual, 'b n d -> (b n) d')
1239
+ )
1240
+
1241
+ return gated_output.reshape_as(x)
1242
+
1243
+ # token shifting
1244
+
1245
+ def shift(t, amount, mask = None):
1246
+ if amount == 0:
1247
+ return t
1248
+ else:
1249
+ amount = min(amount, t.shape[1])
1250
+
1251
+ if exists(mask):
1252
+ t = t.masked_fill(~mask[..., None], 0.)
1253
+
1254
+ return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
1255
+
1256
+ class ShiftTokens(nn.Module):
1257
+ def __init__(self, shifts, fn):
1258
+ super().__init__()
1259
+ self.fn = fn
1260
+ self.shifts = tuple(shifts)
1261
+
1262
+ def forward(self, x, **kwargs):
1263
+ mask = kwargs.get('mask', None)
1264
+ shifts = self.shifts
1265
+ segments = len(shifts)
1266
+ feats_per_shift = x.shape[-1] // segments
1267
+ splitted = x.split(feats_per_shift, dim = -1)
1268
+ segments_to_shift, rest = splitted[:segments], splitted[segments:]
1269
+ segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
1270
+ x = torch.cat((*segments_to_shift, *rest), dim = -1)
1271
+ return self.fn(x, **kwargs)
1272
+
1273
+ # feedforward
1274
+
1275
+ class GLU(nn.Module):
1276
+ def __init__(
1277
+ self,
1278
+ dim_in,
1279
+ dim_out,
1280
+ activation: Callable,
1281
+ mult_bias = False
1282
+ ):
1283
+ super().__init__()
1284
+ self.act = activation
1285
+ self.proj = nn.Linear(dim_in, dim_out * 2)
1286
+ self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.
1287
+
1288
+ def forward(self, x):
1289
+ x, gate = self.proj(x).chunk(2, dim = -1)
1290
+ return x * self.act(gate) * self.mult_bias
1291
+
1292
+ class FeedForward(nn.Module):
1293
+ def __init__(
1294
+ self,
1295
+ dim,
1296
+ dim_out = None,
1297
+ mult = 4,
1298
+ glu = False,
1299
+ glu_mult_bias = False,
1300
+ swish = False,
1301
+ relu_squared = False,
1302
+ post_act_ln = False,
1303
+ dropout = 0.,
1304
+ no_bias = False,
1305
+ zero_init_output = False
1306
+ ):
1307
+ super().__init__()
1308
+ inner_dim = int(dim * mult)
1309
+ dim_out = default(dim_out, dim)
1310
+
1311
+ if relu_squared:
1312
+ activation = ReluSquared()
1313
+ elif swish:
1314
+ activation = nn.SiLU()
1315
+ else:
1316
+ activation = nn.GELU()
1317
+
1318
+ if glu:
1319
+ project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias)
1320
+ else:
1321
+ project_in = nn.Sequential(
1322
+ nn.Linear(dim, inner_dim, bias = not no_bias),
1323
+ activation
1324
+ )
1325
+
1326
+ self.ff = Sequential(
1327
+ project_in,
1328
+ nn.LayerNorm(inner_dim) if post_act_ln else None,
1329
+ nn.Dropout(dropout),
1330
+ nn.Linear(inner_dim, dim_out, bias = not no_bias)
1331
+ )
1332
+
1333
+ # init last linear layer to 0
1334
+ if zero_init_output:
1335
+ init_zero_(self.ff[-1])
1336
+
1337
+ def forward(self, x):
1338
+ return self.ff(x)
1339
+
1340
+ # attention. it is all we need
1341
+
1342
+ class Attention(nn.Module):
1343
+ def __init__(
1344
+ self,
1345
+ dim,
1346
+ dim_head = DEFAULT_DIM_HEAD,
1347
+ heads = 8,
1348
+ causal = False,
1349
+ flash = False,
1350
+ talking_heads = False,
1351
+ head_scale = False,
1352
+ sparse_topk = None,
1353
+ num_mem_kv = 0,
1354
+ dropout = 0.,
1355
+ on_attn = False,
1356
+ gate_value_heads = False,
1357
+ gate_values = False,
1358
+ zero_init_output = False,
1359
+ max_attend_past = None,
1360
+ qk_norm = False,
1361
+ qk_norm_groups = 1,
1362
+ qk_norm_scale = 10,
1363
+ qk_norm_dim_scale = False,
1364
+ one_kv_head = False,
1365
+ kv_heads = None,
1366
+ shared_kv = False,
1367
+ value_dim_head = None,
1368
+ tensor_product = False, # https://arxiv.org/abs/2208.06061
1369
+ add_zero_kv = False, # same as add_zero_attn in pytorch
1370
+ rotary_embed_values = False,
1371
+ onnxable = False
1372
+ ):
1373
+ super().__init__()
1374
+ self.scale = dim_head ** -0.5
1375
+
1376
+ self.heads = heads
1377
+ self.causal = causal
1378
+ self.max_attend_past = max_attend_past
1379
+
1380
+ assert not (exists(kv_heads) and one_kv_head), 'either attn_one_kv_head is set to True (in which case kv_heads is set to 1), or attn_kv_heads is set, but not both'
1381
+
1382
+ value_dim_head = default(value_dim_head, dim_head)
1383
+ kv_heads = default(kv_heads, heads)
1384
+
1385
+ kv_heads = 1 if one_kv_head else kv_heads
1386
+ assert divisible_by(heads, kv_heads)
1387
+
1388
+ self.kv_heads = kv_heads
1389
+
1390
+ q_dim = dim_head * heads
1391
+ k_dim = dim_head * kv_heads
1392
+ v_dim = value_dim_head * kv_heads
1393
+ out_dim = value_dim_head * heads
1394
+
1395
+ self.to_q = nn.Linear(dim, q_dim, bias = False)
1396
+ self.to_k = nn.Linear(dim, k_dim, bias = False)
1397
+
1398
+ # shared key / values, for further memory savings during inference
1399
+ assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
1400
+ self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
1401
+
1402
+ # relations projection from tp-attention
1403
+ self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
1404
+
1405
+ # add GLU gating for aggregated values, from alphafold2
1406
+ self.to_v_gate = None
1407
+ if gate_values:
1408
+ self.to_v_gate = nn.Linear(dim, out_dim)
1409
+ nn.init.constant_(self.to_v_gate.weight, 0)
1410
+ nn.init.constant_(self.to_v_gate.bias, 10)
1411
+
1412
+ # add per head gating of the output values, from 'Attend to nothing' paper
1413
+ self.to_v_head_gate = None
1414
+ if gate_value_heads:
1415
+ self.to_v_head_gate = nn.Linear(dim, heads)
1416
+ nn.init.constant_(self.to_v_head_gate.weight, 0)
1417
+ nn.init.constant_(self.to_v_head_gate.bias, 10)
1418
+
1419
+ # cosine sim attention
1420
+ self.qk_norm = qk_norm
1421
+ self.qk_norm_groups = qk_norm_groups
1422
+ self.qk_norm_scale = qk_norm_scale
1423
+
1424
+ # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
1425
+ self.qk_norm_dim_scale = qk_norm_dim_scale
1426
+
1427
+ self.qk_norm_q_scale = self.qk_norm_k_scale = 1
1428
+ if qk_norm and qk_norm_dim_scale:
1429
+ self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1430
+ self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
1431
+
1432
+ assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
1433
+ assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
1434
+
1435
+ # attend class - includes core attention algorithm + talking heads
1436
+
1437
+ self.attend = Attend(
1438
+ heads = heads,
1439
+ causal = causal,
1440
+ talking_heads = talking_heads,
1441
+ dropout = dropout,
1442
+ sparse_topk = sparse_topk,
1443
+ qk_norm = qk_norm,
1444
+ scale = qk_norm_scale if qk_norm else self.scale,
1445
+ add_zero_kv = add_zero_kv,
1446
+ flash = flash,
1447
+ onnxable = onnxable
1448
+ )
1449
+
1450
+ # head scaling
1451
+ self.head_scale = head_scale
1452
+ if head_scale:
1453
+ self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
1454
+
1455
+ # explicit topk sparse attention
1456
+ self.sparse_topk = sparse_topk
1457
+
1458
+ # add memory key / values
1459
+ self.num_mem_kv = num_mem_kv
1460
+ if num_mem_kv > 0:
1461
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1462
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
1463
+
1464
+ # attention on attention
1465
+ self.attn_on_attn = on_attn
1466
+ self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
1467
+
1468
+ # whether to rotate positions into values, for absolute positions in addition to relative
1469
+ self.rotary_embed_values = rotary_embed_values
1470
+
1471
+ # init output projection 0
1472
+ if zero_init_output:
1473
+ init_zero_(self.to_out)
1474
+
1475
+ def forward(
1476
+ self,
1477
+ x,
1478
+ context = None,
1479
+ mask = None,
1480
+ context_mask = None,
1481
+ attn_mask = None,
1482
+ rel_pos = None,
1483
+ rotary_pos_emb = None,
1484
+ prev_attn = None,
1485
+ mem = None,
1486
+ return_intermediates = False,
1487
+ cache: Optional[Intermediates] = None,
1488
+ ):
1489
+ b, n, _, h, kv_h, head_scale, device, has_context = *x.shape, self.heads, self.kv_heads, self.head_scale, x.device, exists(context)
1490
+ kv_input = default(context, x)
1491
+
1492
+ q_input = x
1493
+ k_input = kv_input
1494
+ v_input = kv_input
1495
+ r_input = x
1496
+
1497
+ if exists(mem):
1498
+ k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
1499
+ v_input, _ = pack([mem, v_input], 'b * d')
1500
+
1501
+ q = self.to_q(q_input)
1502
+ k = self.to_k(k_input)
1503
+ v = self.to_v(v_input) if exists(self.to_v) else k
1504
+ r = self.to_r(r_input) if exists(self.to_r) else None
1505
+
1506
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
1507
+
1508
+ k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h), (k, v, r))
1509
+
1510
+ if exists(cache) and not has_context:
1511
+ ck, cv = cache.cached_kv
1512
+
1513
+ if exists(mem):
1514
+ mk, k = unpack(k, mem_packed_shape, 'b h * d')
1515
+ mv, v = unpack(v, mem_packed_shape, 'b h * d')
1516
+
1517
+ k = torch.cat((ck, k), dim = -2)
1518
+ v = torch.cat((cv, v), dim = -2)
1519
+
1520
+ if exists(mem):
1521
+ k = torch.cat((mk, k), dim = -2)
1522
+ v = torch.cat((mv, v), dim = -2)
1523
+
1524
+ if return_intermediates:
1525
+ mem_len = mem.shape[-2] if exists(mem) else 0
1526
+ cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
1527
+
1528
+ if self.qk_norm:
1529
+ qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
1530
+ q, k = map(qk_l2norm, (q, k))
1531
+ scale = self.qk_norm_scale
1532
+
1533
+ q = q * self.qk_norm_q_scale
1534
+ k = k * self.qk_norm_k_scale
1535
+
1536
+ if exists(rotary_pos_emb) and not has_context:
1537
+ freqs, xpos_scale = rotary_pos_emb
1538
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1539
+
1540
+ q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
1541
+ k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
1542
+
1543
+ if self.rotary_embed_values:
1544
+ v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
1545
+
1546
+ input_mask = context_mask
1547
+
1548
+ if not exists(input_mask) and not has_context:
1549
+ input_mask = mask
1550
+
1551
+ if self.num_mem_kv > 0:
1552
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
1553
+
1554
+ if self.qk_norm:
1555
+ mem_k = l2norm(mem_k)
1556
+ mem_k = mem_k * self.qk_norm_k_scale
1557
+
1558
+ k = torch.cat((mem_k, k), dim = -2)
1559
+ v = torch.cat((mem_v, v), dim = -2)
1560
+
1561
+ if exists(input_mask):
1562
+ input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
1563
+
1564
+ i, j = map(lambda t: t.shape[-2], (q, k))
1565
+
1566
+ # determine masking
1567
+
1568
+ mask_value = max_neg_value(q)
1569
+ masks = []
1570
+ final_attn_mask = None
1571
+
1572
+ if exists(input_mask):
1573
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
1574
+ masks.append(~input_mask)
1575
+
1576
+ if exists(attn_mask):
1577
+ assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
1578
+ if attn_mask.ndim == 2:
1579
+ attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
1580
+ elif attn_mask.ndim == 3:
1581
+ attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
1582
+ masks.append(~attn_mask)
1583
+
1584
+ if exists(self.max_attend_past):
1585
+ range_q = torch.arange(j - i, j, device = device)
1586
+ range_k = torch.arange(j, device = device)
1587
+ dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
1588
+ max_attend_past_mask = dist > self.max_attend_past
1589
+ masks.append(max_attend_past_mask)
1590
+
1591
+ if len(masks) > 0:
1592
+ final_attn_mask = ~or_reduce(masks)
1593
+
1594
+ # prepare relative positional bias, if needed
1595
+
1596
+ attn_bias = None
1597
+ if exists(rel_pos):
1598
+ attn_bias = rel_pos(i, j)
1599
+
1600
+ # attention is all we need
1601
+
1602
+ out, intermediates = self.attend(
1603
+ q, k, v,
1604
+ mask = final_attn_mask,
1605
+ attn_bias = attn_bias,
1606
+ prev_attn = prev_attn
1607
+ )
1608
+
1609
+ # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
1610
+
1611
+ if exists(r):
1612
+ out = out * r + out
1613
+
1614
+ # normformer scaling of heads
1615
+
1616
+ if head_scale:
1617
+ out = out * self.head_scale_params
1618
+
1619
+ # per head gating, from https://arxiv.org/abs/2306.12929
1620
+
1621
+ if exists(self.to_v_head_gate):
1622
+ head_gate = self.to_v_head_gate(x)
1623
+ out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
1624
+
1625
+ # merge heads
1626
+
1627
+ out = rearrange(out, 'b h n d -> b n (h d)')
1628
+
1629
+ # alphafold2 styled gating of the values
1630
+
1631
+ if exists(self.to_v_gate):
1632
+ gates = self.to_v_gate(x)
1633
+ out = out * gates.sigmoid()
1634
+
1635
+ # combine the heads
1636
+
1637
+ out = self.to_out(out)
1638
+
1639
+ if exists(mask):
1640
+ mask = rearrange(mask, 'b n -> b n 1')
1641
+ out = out.masked_fill(~mask, 0.)
1642
+
1643
+ if not return_intermediates:
1644
+ return out
1645
+
1646
+ intermediates.cached_kv = cached_kv
1647
+
1648
+ return out, intermediates
1649
+
1650
+ class AttentionLayers(nn.Module):
1651
+ def __init__(
1652
+ self,
1653
+ dim,
1654
+ depth,
1655
+ heads = 8,
1656
+ causal = False,
1657
+ cross_attend = False,
1658
+ only_cross = False,
1659
+ use_scalenorm = False,
1660
+ use_rmsnorm = False,
1661
+ use_simple_rmsnorm = False,
1662
+ alibi_pos_bias = False,
1663
+ alibi_num_heads = None,
1664
+ rel_pos_bias = False,
1665
+ rel_pos_num_buckets = 32,
1666
+ rel_pos_max_distance = 128,
1667
+ dynamic_pos_bias = False,
1668
+ dynamic_pos_bias_log_distance = False,
1669
+ dynamic_pos_bias_mlp_depth = 2,
1670
+ dynamic_pos_bias_norm = False,
1671
+ rotary_pos_emb = False,
1672
+ rotary_emb_dim = None,
1673
+ rotary_xpos = False,
1674
+ rotary_interpolation_factor = 1.,
1675
+ rotary_xpos_scale_base = 512,
1676
+ rotary_base_rescale_factor = 1.,
1677
+ custom_layers = None,
1678
+ sandwich_coef = None,
1679
+ par_ratio = None,
1680
+ weight_tie_layers = False, # Albert - https://arxiv.org/abs/1909.11942
1681
+ layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
1682
+ residual_attn = False,
1683
+ cross_residual_attn = False,
1684
+ macaron = False,
1685
+ pre_norm = True,
1686
+ pre_norm_has_final_norm = True,
1687
+ gate_residual = False,
1688
+ scale_residual = False,
1689
+ scale_residual_constant = 1.,
1690
+ shift_tokens = 0,
1691
+ sandwich_norm = False,
1692
+ resi_dual = False,
1693
+ resi_dual_scale = 1.,
1694
+ zero_init_branch_output = False,
1695
+ layer_dropout = 0.,
1696
+ cross_attn_tokens_dropout = 0.,
1697
+ **kwargs
1698
+ ):
1699
+ super().__init__()
1700
+ rotary_pos_emb = rotary_pos_emb or rotary_xpos
1701
+
1702
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1703
+ attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1704
+
1705
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1706
+
1707
+ self.dim = dim
1708
+ self.depth = depth
1709
+ self.causal = causal
1710
+ self.layers = nn.ModuleList([])
1711
+
1712
+ self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1713
+
1714
+ rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1715
+
1716
+ assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1717
+ self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
1718
+
1719
+ assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1720
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
1721
+
1722
+ # relative positional bias
1723
+
1724
+ flash_attn = attn_kwargs.get('flash', False)
1725
+ assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1726
+
1727
+ self.rel_pos = None
1728
+ if rel_pos_bias:
1729
+ assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1730
+ self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1731
+ elif dynamic_pos_bias:
1732
+ assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1733
+ self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1734
+ elif alibi_pos_bias:
1735
+ alibi_num_heads = default(alibi_num_heads, heads)
1736
+ assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
1737
+ self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
1738
+
1739
+ assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1740
+ assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1741
+
1742
+ if resi_dual:
1743
+ pre_norm = False
1744
+
1745
+ self.pre_norm = pre_norm
1746
+ self.sandwich_norm = sandwich_norm
1747
+
1748
+ self.resi_dual = resi_dual
1749
+ assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
1750
+ self.resi_dual_scale = resi_dual_scale
1751
+
1752
+ self.residual_attn = residual_attn
1753
+ self.cross_residual_attn = cross_residual_attn
1754
+ assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
1755
+
1756
+ self.cross_attend = cross_attend
1757
+
1758
+ assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
1759
+
1760
+ if use_scalenorm:
1761
+ norm_class = ScaleNorm
1762
+ elif use_rmsnorm:
1763
+ norm_class = RMSNorm
1764
+ elif use_simple_rmsnorm:
1765
+ norm_class = SimpleRMSNorm
1766
+ else:
1767
+ norm_class = nn.LayerNorm
1768
+
1769
+ norm_fn = partial(norm_class, dim)
1770
+
1771
+ if cross_attend and not only_cross:
1772
+ default_block = ('a', 'c', 'f')
1773
+ elif cross_attend and only_cross:
1774
+ default_block = ('c', 'f')
1775
+ else:
1776
+ default_block = ('a', 'f')
1777
+
1778
+ if macaron:
1779
+ default_block = ('f',) + default_block
1780
+
1781
+ # zero init
1782
+
1783
+ if zero_init_branch_output:
1784
+ attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1785
+ ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1786
+
1787
+ # setup weight tying, which is a special case of `layer_execute_order`
1788
+
1789
+ assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
1790
+
1791
+ if weight_tie_layers:
1792
+ assert not exists(layers_execute_order)
1793
+ layers_execute_order = tuple(range(len(default_block))) * depth
1794
+ depth = 1
1795
+
1796
+ # calculate layer block order
1797
+
1798
+ if exists(custom_layers):
1799
+ layer_types = custom_layers
1800
+ elif exists(par_ratio):
1801
+ par_depth = depth * len(default_block)
1802
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1803
+ default_block = tuple(filter(not_equals('f'), default_block))
1804
+ par_attn = par_depth // par_ratio
1805
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1806
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
1807
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1808
+ par_block = default_block + ('f',) * (par_width - len(default_block))
1809
+ par_head = par_block * par_attn
1810
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
1811
+ elif exists(sandwich_coef):
1812
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
1813
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1814
+ else:
1815
+ layer_types = default_block * depth
1816
+
1817
+ self.layer_types = layer_types
1818
+ self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
1819
+
1820
+ assert all([i < len(self.layer_types) for i in self.layers_execute_order])
1821
+
1822
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1823
+
1824
+ # stochastic depth
1825
+
1826
+ self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1827
+
1828
+ # structured dropout for cross attending
1829
+
1830
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1831
+
1832
+ # calculate token shifting
1833
+
1834
+ shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1835
+
1836
+ # whether it has post norm
1837
+
1838
+ self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
1839
+
1840
+ # iterate and construct layers
1841
+
1842
+ for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1843
+ is_last_layer = ind == (len(self.layer_types) - 1)
1844
+
1845
+ if layer_type == 'a':
1846
+ layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1847
+ elif layer_type == 'c':
1848
+ layer = Attention(dim, heads = heads, **attn_kwargs)
1849
+ elif layer_type == 'f':
1850
+ layer = FeedForward(dim, **ff_kwargs)
1851
+ layer = layer if not macaron else Scale(0.5, layer)
1852
+ else:
1853
+ raise Exception(f'invalid layer type {layer_type}')
1854
+
1855
+ if layer_shift_tokens > 0:
1856
+ shift_range_upper = layer_shift_tokens + 1
1857
+ shift_range_lower = -layer_shift_tokens if not causal else 0
1858
+ layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1859
+
1860
+ residual_fn = GRUGating if gate_residual else Residual
1861
+ residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1862
+
1863
+ pre_branch_norm = norm_fn() if pre_norm else None
1864
+ post_branch_norm = norm_fn() if sandwich_norm else None
1865
+ post_main_norm = norm_fn() if not pre_norm else None
1866
+
1867
+ norms = nn.ModuleList([
1868
+ pre_branch_norm,
1869
+ post_branch_norm,
1870
+ post_main_norm
1871
+ ])
1872
+
1873
+ self.layers.append(nn.ModuleList([
1874
+ norms,
1875
+ layer,
1876
+ residual
1877
+ ]))
1878
+
1879
+ def forward(
1880
+ self,
1881
+ x,
1882
+ context = None,
1883
+ mask = None,
1884
+ context_mask = None,
1885
+ attn_mask = None,
1886
+ self_attn_kv_mask = None,
1887
+ mems = None,
1888
+ seq_start_pos: Optional[Tensor] = None,
1889
+ cache: Optional[LayerIntermediates] = None,
1890
+ cache_age = 1,
1891
+ return_hiddens = False
1892
+ ):
1893
+ assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
1894
+
1895
+ # initialize accums
1896
+
1897
+ hiddens = []
1898
+ layer_hiddens = []
1899
+ intermediates = []
1900
+
1901
+ prev_attn = None
1902
+ prev_cross_attn = None
1903
+
1904
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1905
+
1906
+ # handle left padded sequences
1907
+
1908
+ if exists(seq_start_pos):
1909
+ seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
1910
+ left_pad_mask = seq_arange >= seq_start_pos[..., None]
1911
+
1912
+ if exists(self_attn_kv_mask):
1913
+ self_attn_kv_mask = self_attn_kv_mask & left_pad_mask
1914
+ else:
1915
+ self_attn_kv_mask = left_pad_mask
1916
+
1917
+ # rotary positions
1918
+
1919
+ rotary_pos_emb = None
1920
+
1921
+ if exists(self.rotary_pos_emb):
1922
+ max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1923
+ rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
1924
+
1925
+ # assume cached key / values
1926
+
1927
+ attn_cache = []
1928
+
1929
+ if exists(cache):
1930
+ assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
1931
+
1932
+ if cache_age > 0:
1933
+ x = x[:, -cache_age:] # for spec decoding, may be greater than 1
1934
+
1935
+ attn_cache = cache.attn_intermediates
1936
+
1937
+ iter_attn_cache = iter(attn_cache)
1938
+
1939
+ # outer residual - for resiDual paper
1940
+
1941
+ outer_residual = x * self.resi_dual_scale
1942
+
1943
+ # get layers to be executed
1944
+
1945
+ layer_variables = (
1946
+ self.layer_types,
1947
+ self.layers,
1948
+ self.layer_dropouts
1949
+ )
1950
+
1951
+ layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
1952
+
1953
+ # go through the attention and feedforward layers
1954
+
1955
+ for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
1956
+ is_last = ind == (len(self.layers) - 1)
1957
+
1958
+ if self.training and layer_dropout > 0. and random() < layer_dropout:
1959
+ continue
1960
+
1961
+ if layer_type == 'a':
1962
+ if return_hiddens:
1963
+ hiddens.append(x)
1964
+ layer_mem = mems.pop(0) if mems else None
1965
+
1966
+ if layer_type == 'c':
1967
+ if self.training and self.cross_attn_tokens_dropout > 0.:
1968
+ context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1969
+
1970
+ inner_residual = x
1971
+
1972
+ if return_hiddens:
1973
+ layer_hiddens.append(x)
1974
+
1975
+ pre_norm, post_branch_norm, post_main_norm = norm
1976
+
1977
+ if exists(pre_norm):
1978
+ x = pre_norm(x)
1979
+
1980
+ if layer_type == 'a':
1981
+ out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
1982
+ elif layer_type == 'c':
1983
+ out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
1984
+ elif layer_type == 'f':
1985
+ out = block(x)
1986
+
1987
+ if self.resi_dual:
1988
+ outer_residual = outer_residual + out * self.resi_dual_scale
1989
+
1990
+ if exists(post_branch_norm):
1991
+ out = post_branch_norm(out)
1992
+
1993
+ x = residual_fn(out, inner_residual)
1994
+
1995
+ if layer_type in ('a', 'c') and return_hiddens:
1996
+ intermediates.append(inter)
1997
+
1998
+ if layer_type == 'a' and self.residual_attn:
1999
+ prev_attn = inter.pre_softmax_attn
2000
+ elif layer_type == 'c' and self.cross_residual_attn:
2001
+ prev_cross_attn = inter.pre_softmax_attn
2002
+
2003
+ if exists(post_main_norm):
2004
+ x = post_main_norm(x)
2005
+
2006
+ if return_hiddens:
2007
+ layer_hiddens.append(x)
2008
+
2009
+ if self.resi_dual:
2010
+ x = x + self.final_norm(outer_residual)
2011
+ else:
2012
+ x = self.final_norm(x)
2013
+
2014
+ if not return_hiddens:
2015
+ return x
2016
+
2017
+ intermediates = LayerIntermediates(
2018
+ hiddens = hiddens,
2019
+ attn_intermediates = intermediates,
2020
+ layer_hiddens = layer_hiddens
2021
+ )
2022
+
2023
+ return x, intermediates
2024
+
2025
+ class Encoder(AttentionLayers):
2026
+ def __init__(self, **kwargs):
2027
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
2028
+ super().__init__(causal = False, **kwargs)
2029
+
2030
+ class Decoder(AttentionLayers):
2031
+ def __init__(self, **kwargs):
2032
+ assert 'causal' not in kwargs, 'cannot set causality on decoder'
2033
+ super().__init__(causal = True, **kwargs)
2034
+
2035
+ class CrossAttender(AttentionLayers):
2036
+ def __init__(self, **kwargs):
2037
+ super().__init__(cross_attend = True, only_cross = True, **kwargs)
2038
+
2039
+ class ViTransformerWrapper(nn.Module):
2040
+ def __init__(
2041
+ self,
2042
+ *,
2043
+ image_size,
2044
+ patch_size,
2045
+ attn_layers,
2046
+ channels = 3,
2047
+ num_classes = None,
2048
+ post_emb_norm = False,
2049
+ num_register_tokens = 0,
2050
+ emb_dropout = 0.
2051
+ ):
2052
+ super().__init__()
2053
+ assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
2054
+ assert divisible_by(image_size, patch_size), 'image dimensions must be divisible by the patch size'
2055
+ dim = attn_layers.dim
2056
+ num_patches = (image_size // patch_size) ** 2
2057
+ patch_dim = channels * patch_size ** 2
2058
+
2059
+ self.patch_size = patch_size
2060
+
2061
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
2062
+
2063
+ has_register_tokens = num_register_tokens > 0
2064
+ self.has_register_tokens = has_register_tokens
2065
+
2066
+ if has_register_tokens:
2067
+ self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
2068
+
2069
+ self.patch_to_embedding = nn.Sequential(
2070
+ nn.LayerNorm(patch_dim),
2071
+ nn.Linear(patch_dim, dim),
2072
+ nn.LayerNorm(dim)
2073
+ )
2074
+
2075
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2076
+ self.dropout = nn.Dropout(emb_dropout)
2077
+
2078
+ self.attn_layers = attn_layers
2079
+
2080
+ self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
2081
+
2082
+ def forward(
2083
+ self,
2084
+ img,
2085
+ return_embeddings = False
2086
+ ):
2087
+ b, p = img.shape[0], self.patch_size
2088
+
2089
+ x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
2090
+ x = self.patch_to_embedding(x)
2091
+ n = x.shape[1]
2092
+
2093
+ x = x + self.pos_embedding[:, :n]
2094
+
2095
+ x = self.post_emb_norm(x)
2096
+ x = self.dropout(x)
2097
+
2098
+ if self.has_register_tokens:
2099
+ r = repeat(self.register_tokens, 'n d -> b n d', b = b)
2100
+ x, ps = pack((x, r), 'b * d')
2101
+
2102
+ x = self.attn_layers(x)
2103
+
2104
+ if self.has_register_tokens:
2105
+ x, _ = unpack(x, ps, 'b * d')
2106
+
2107
+ if not exists(self.mlp_head) or return_embeddings:
2108
+ return x
2109
+
2110
+ x = x.mean(dim = -2)
2111
+ return self.mlp_head(x)
2112
+
2113
+ class TransformerWrapper(nn.Module):
2114
+ def __init__(
2115
+ self,
2116
+ *,
2117
+ num_tokens,
2118
+ max_seq_len,
2119
+ attn_layers,
2120
+ emb_dim = None,
2121
+ max_mem_len = 0,
2122
+ shift_mem_down = 0,
2123
+ emb_dropout = 0.,
2124
+ post_emb_norm = False,
2125
+ num_memory_tokens = None,
2126
+ memory_tokens_interspersed_every = None,
2127
+ tie_embedding = False,
2128
+ logits_dim = None,
2129
+ use_abs_pos_emb = True,
2130
+ scaled_sinu_pos_emb = False,
2131
+ l2norm_embed = False,
2132
+ emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
2133
+ attn_z_loss_weight = 1e-4,
2134
+ ):
2135
+ super().__init__()
2136
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2137
+
2138
+ dim = attn_layers.dim
2139
+ emb_dim = default(emb_dim, dim)
2140
+ self.emb_dim = emb_dim
2141
+ self.num_tokens = num_tokens
2142
+
2143
+ self.max_seq_len = max_seq_len
2144
+ self.max_mem_len = max_mem_len
2145
+ self.shift_mem_down = shift_mem_down
2146
+
2147
+ self.l2norm_embed = l2norm_embed
2148
+ self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
2149
+
2150
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2151
+ self.pos_emb = always(0)
2152
+ elif scaled_sinu_pos_emb:
2153
+ self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
2154
+ else:
2155
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
2156
+
2157
+ self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
2158
+
2159
+ self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
2160
+ self.emb_dropout = nn.Dropout(emb_dropout)
2161
+
2162
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
2163
+ self.attn_layers = attn_layers
2164
+
2165
+ self.init_()
2166
+
2167
+ logits_dim = default(logits_dim, num_tokens)
2168
+ self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t()
2169
+
2170
+ # memory tokens (like [cls]) from Memory Transformers paper
2171
+
2172
+ num_memory_tokens = default(num_memory_tokens, 0)
2173
+ self.num_memory_tokens = num_memory_tokens
2174
+ if num_memory_tokens > 0:
2175
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
2176
+
2177
+ self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
2178
+
2179
+ # whether can do cached kv decoding
2180
+
2181
+ self.can_cache_kv = self.num_memory_tokens == 0
2182
+
2183
+ def init_(self):
2184
+ if self.l2norm_embed:
2185
+ nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
2186
+ if not isinstance(self.pos_emb, always):
2187
+ nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
2188
+ return
2189
+
2190
+ nn.init.kaiming_normal_(self.token_emb.emb.weight)
2191
+
2192
+ def forward(
2193
+ self,
2194
+ x,
2195
+ return_embeddings = False,
2196
+ return_logits_and_embeddings = False,
2197
+ return_intermediates = False,
2198
+ mask = None,
2199
+ return_mems = False,
2200
+ return_attn = False,
2201
+ mems = None,
2202
+ pos = None,
2203
+ prepend_embeds = None,
2204
+ sum_embeds = None,
2205
+ return_attn_z_loss = False,
2206
+ attn_z_loss_weight = 1e-4,
2207
+ seq_start_pos = None,
2208
+ cache: Optional[LayerIntermediates] = None,
2209
+ **kwargs
2210
+ ):
2211
+ b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
2212
+ return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
2213
+
2214
+ # absolute positional embedding
2215
+
2216
+ external_pos_emb = exists(pos) and pos.dtype != torch.long
2217
+ pos_emb = self.pos_emb(x, pos = pos, seq_start_pos = seq_start_pos) if not external_pos_emb else pos
2218
+ x = self.token_emb(x) + pos_emb
2219
+
2220
+ # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
2221
+
2222
+ if exists(sum_embeds):
2223
+ x = x + sum_embeds
2224
+
2225
+ # post embedding norm, purportedly leads to greater stabilization
2226
+
2227
+ x = self.post_emb_norm(x)
2228
+
2229
+ # whether to append embeds, as in PaLI, for image embeddings
2230
+
2231
+ if exists(prepend_embeds):
2232
+ prepend_seq, prepend_dim = prepend_embeds.shape[1:]
2233
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
2234
+
2235
+ x = torch.cat((prepend_embeds, x), dim = -2)
2236
+
2237
+ # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
2238
+
2239
+ if emb_frac_gradient < 1:
2240
+ assert emb_frac_gradient > 0
2241
+ x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
2242
+
2243
+ # embedding dropout
2244
+
2245
+ x = self.emb_dropout(x)
2246
+
2247
+ x = self.project_emb(x)
2248
+
2249
+ if has_memory_tokens:
2250
+ mem_every = self.memory_tokens_interspersed_every
2251
+
2252
+ if exists(mem_every):
2253
+ assert mem_every > 0
2254
+ assert isinstance(self.attn_layers, Decoder), 'only for decoder'
2255
+ next_seq_len = math.ceil(n / mem_every) * mem_every
2256
+
2257
+ x = pad_at_dim(x, (0, next_seq_len - n), dim = -2, value = 0.)
2258
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = mem_every)
2259
+
2260
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b = x.shape[0])
2261
+ x, mem_packed_shape = pack((mem, x), 'b * d')
2262
+
2263
+ # auto-handle masking after appending memory tokens
2264
+ if not exists(mem_every) and exists(mask):
2265
+ mask = pad_at_dim(mask, (num_mems, 0), dim = -1, value = True)
2266
+
2267
+ if exists(mem_every):
2268
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2269
+
2270
+ if self.shift_mem_down and exists(mems):
2271
+ mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
2272
+ mems = [*mems_r, *mems_l]
2273
+
2274
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, cache = cache, return_hiddens = True, seq_start_pos = seq_start_pos, **kwargs)
2275
+
2276
+ if has_memory_tokens:
2277
+ if exists(mem_every):
2278
+ x = rearrange(x, 'b (n m) d -> (b n) m d', m = (mem_every + num_mems))
2279
+
2280
+ mem, x = unpack(x, mem_packed_shape, 'b * d')
2281
+
2282
+ if exists(mem_every):
2283
+ x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
2284
+
2285
+ x = x[:, :n]
2286
+
2287
+ if return_logits_and_embeddings:
2288
+ out = (self.to_logits(x), x)
2289
+ elif return_embeddings:
2290
+ out = x
2291
+ else:
2292
+ out = self.to_logits(x)
2293
+
2294
+ if return_attn_z_loss:
2295
+ pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates))
2296
+ intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight)
2297
+ return_intermediates = True
2298
+
2299
+ if return_mems:
2300
+ hiddens = intermediates.hiddens
2301
+ new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
2302
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
2303
+
2304
+ if not return_intermediates:
2305
+ return out, new_mems
2306
+
2307
+ intermediates.mems = new_mems
2308
+
2309
+ if return_intermediates:
2310
+ return out, intermediates
2311
+
2312
+ if return_attn:
2313
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2314
+ return out, attn_maps
2315
+
2316
+ return out
2317
+
2318
+ class ContinuousTransformerWrapper(nn.Module):
2319
+ def __init__(
2320
+ self,
2321
+ *,
2322
+ max_seq_len,
2323
+ attn_layers,
2324
+ dim_in = None,
2325
+ dim_out = None,
2326
+ emb_dim = None,
2327
+ max_mem_len = 0,
2328
+ post_emb_norm = False,
2329
+ emb_dropout = 0.,
2330
+ use_abs_pos_emb = True,
2331
+ scaled_sinu_pos_emb = False
2332
+ ):
2333
+ super().__init__()
2334
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
2335
+
2336
+ dim = attn_layers.dim
2337
+
2338
+ self.max_seq_len = max_seq_len
2339
+
2340
+ self.max_mem_len = max_mem_len
2341
+
2342
+ if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
2343
+ self.pos_emb = always(0)
2344
+ elif scaled_sinu_pos_emb:
2345
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
2346
+ else:
2347
+ self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
2348
+
2349
+ self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
2350
+ self.emb_dropout = nn.Dropout(emb_dropout)
2351
+
2352
+ self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
2353
+
2354
+ self.attn_layers = attn_layers
2355
+
2356
+ self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
2357
+
2358
+ def forward(
2359
+ self,
2360
+ x,
2361
+ return_embeddings = False,
2362
+ return_intermediates = False,
2363
+ return_mems = False,
2364
+ mask = None,
2365
+ return_attn = False,
2366
+ mems = None,
2367
+ pos = None,
2368
+ prepend_embeds = None,
2369
+ **kwargs
2370
+ ):
2371
+ x = self.project_in(x)
2372
+ x = x + self.pos_emb(x, pos = pos)
2373
+
2374
+ x = self.post_emb_norm(x)
2375
+
2376
+ # whether to append embeds, as in PaLI, for image embeddings
2377
+
2378
+ if exists(prepend_embeds):
2379
+ _, prepend_dim = prepend_embeds.shape[1:]
2380
+ assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
2381
+
2382
+ x = torch.cat((prepend_embeds, x), dim = -2)
2383
+
2384
+ x = self.emb_dropout(x)
2385
+
2386
+ x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
2387
+
2388
+ out = self.project_out(x) if not return_embeddings else x
2389
+
2390
+ if return_intermediates:
2391
+ return out, intermediates
2392
+
2393
+ if return_mems:
2394
+ hiddens = intermediates.hiddens
2395
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), hiddens))
2396
+ return out, new_mems
2397
+
2398
+ if return_attn:
2399
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
2400
+ return out, attn_maps
2401
+
2402
+ return out
2403
+
2404
+ class XTransformer(nn.Module):
2405
+ def __init__(
2406
+ self,
2407
+ *,
2408
+ dim,
2409
+ tie_token_emb = False,
2410
+ ignore_index = -100,
2411
+ pad_value = 0,
2412
+ cross_attn_tokens_dropout = 0.,
2413
+ **kwargs
2414
+ ):
2415
+ super().__init__()
2416
+ enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
2417
+ dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
2418
+
2419
+ assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
2420
+ enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
2421
+ enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
2422
+ enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
2423
+ enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
2424
+ enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)
2425
+
2426
+ dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
2427
+ dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
2428
+ dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
2429
+ dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)
2430
+
2431
+ self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # how many tokens from the encoder to dropout when cross attending from decoder - seen in a couple papers, including Perceiver AR - this will also be very effective regularization when cross attending to very long memories
2432
+
2433
+ self.encoder = TransformerWrapper(
2434
+ **enc_transformer_kwargs,
2435
+ attn_layers = Encoder(dim = dim, **enc_kwargs)
2436
+ )
2437
+
2438
+ self.decoder = TransformerWrapper(
2439
+ **dec_transformer_kwargs,
2440
+ attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
2441
+ )
2442
+
2443
+ if tie_token_emb:
2444
+ self.decoder.token_emb = self.encoder.token_emb
2445
+
2446
+ self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)
2447
+
2448
+ @torch.no_grad()
2449
+ def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
2450
+ encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
2451
+ return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)
2452
+
2453
+ def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
2454
+
2455
+ if exists(src_prepend_embeds) and exists(mask):
2456
+ mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
2457
+
2458
+ enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
2459
+
2460
+ if self.training and self.cross_attn_tokens_dropout > 0:
2461
+ enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
2462
+
2463
+ out = self.decoder(tgt, context = enc, context_mask = mask)
2464
+ return out