DeepLearning101 commited on
Commit
437e42f
1 Parent(s): f4b6e70

Upload 6 files

Browse files
models/basic_modules/adapter.py ADDED
@@ -0,0 +1,1060 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom models for few-shot learning specific operations."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import transformers
6
+ import torch.nn.functional as F
7
+ from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
8
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertForSequenceClassification, BertModel, \
9
+ BertOnlyMLMHead
10
+ from transformers.models.roberta.modeling_roberta import *
11
+ from transformers.models.bert.modeling_bert import *
12
+ from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2PreTrainedModel, DebertaV2Model, StableDropout, \
13
+ ContextPooler, DebertaV2OnlyMLMHead
14
+ from transformers.models.deberta.modeling_deberta import DebertaPreTrainedModel, DebertaModel, StableDropout, \
15
+ ContextPooler, DebertaOnlyMLMHead
16
+ from transformers.modeling_outputs import SequenceClassifierOutput
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ import logging
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # adapter_choice: LiST, houlsby, lora
24
+
25
+ # add by wjn
26
+ def init_adapter(model, std=0.0002):
27
+ with torch.no_grad():
28
+ for name, param in model.named_parameters():
29
+ init_value = 0
30
+ if "adapter_proj" in name:
31
+ if std > 0:
32
+ init_value += torch.normal(0, std, size=param.size())
33
+ param.copy_(init_value)
34
+ return model
35
+
36
+
37
+ # Adapter Layer
38
+ class AdapeterLayer(nn.Module):
39
+ def __init__(self, n_in, n_out=None, adapter_dim=128, adapter_choice="LiST"):
40
+ super(AdapeterLayer, self).__init__()
41
+ if not n_out:
42
+ n_out = n_in
43
+
44
+ self.adapter_choice = adapter_choice
45
+ self.act_fun = None
46
+
47
+ if self.adapter_choice == "LiST":
48
+ self.adapter_dim = adapter_dim
49
+ self.adapter_proj_1 = nn.Linear(n_out, adapter_dim, bias=False)
50
+ nn.init.normal_(self.adapter_proj_1.weight, std=0.02)
51
+ self.adapter_proj_2 = nn.Linear(adapter_dim, n_out, bias=False)
52
+ nn.init.normal_(self.adapter_proj_2.weight, std=0.02)
53
+
54
+ elif self.adapter_choice == "houlsby":
55
+ self.adapter_dim = adapter_dim
56
+ self.adapter_proj_1 = nn.Linear(n_out, adapter_dim, bias=False)
57
+ nn.init.normal_(self.adapter_proj_1.weight, std=0.02)
58
+ self.act_fun = torch.nn.ReLU()
59
+ self.adapter_proj_2 = nn.Linear(adapter_dim, n_out, bias=False)
60
+ nn.init.normal_(self.adapter_proj_2.weight, std=0.02)
61
+
62
+ else:
63
+ self.adapter_dim = adapter_dim
64
+ self.adapter_proj_1 = nn.Linear(n_out, n_out, bias=False)
65
+
66
+
67
+ def forward(self, x):
68
+ if self.adapter_choice == "LiST":
69
+ result = torch.matmul(x, self.adapter_proj_1.weight.type_as(x).T)
70
+ result = torch.matmul(result, self.adapter_proj_2.weight.type_as(x).T)
71
+ return result + x
72
+ elif self.adapter_choice == "houlsby":
73
+ result = torch.matmul(x, self.adapter_proj_1.weight.type_as(x).T)
74
+ if self.act_fun is not None:
75
+ result = self.act_fun(result)
76
+ result = torch.matmul(result, self.adapter_proj_2.weight.type_as(x).T)
77
+ return result + x
78
+
79
+ else:
80
+ result = torch.matmul(x, self.adapter_proj_1.weight.type_as(x).T)
81
+ return result
82
+
83
+
84
+ ## ======== Adapter For RoBERTa ========
85
+ class RobertaAdaOutput(nn.Module):
86
+ def __init__(self, config):
87
+ super().__init__()
88
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
89
+ self.config = config
90
+ self.adaptation_layer = AdapeterLayer(n_in=config.intermediate_size, n_out=config.hidden_size,
91
+ adapter_dim=config.adapter_dim, adapter_choice=config.adapter_choice)
92
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
93
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
94
+
95
+ def forward(self, hidden_states, input_tensor):
96
+ hidden_states = self.dense(hidden_states)
97
+ hidden_states = self.adaptation_layer(hidden_states)
98
+ hidden_states = self.dropout(hidden_states)
99
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
100
+ return hidden_states
101
+
102
+
103
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
104
+ class RobertaAdaSelfOutput(nn.Module):
105
+ def __init__(self, config):
106
+ super().__init__()
107
+ self.config = config
108
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
109
+ self.adaptation_layer = AdapeterLayer(n_in=config.intermediate_size, n_out=config.hidden_size,
110
+ adapter_dim=config.adapter_dim, adapter_choice=config.adapter_choice)
111
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
112
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
113
+
114
+
115
+ def forward(self, hidden_states, input_tensor):
116
+ hidden_states = self.dense(hidden_states)
117
+ hidden_states = self.adaptation_layer(hidden_states)
118
+ hidden_states = self.dropout(hidden_states)
119
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
120
+ return hidden_states
121
+
122
+
123
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
124
+ class RobertaAdaAttention(nn.Module):
125
+ def __init__(self, config):
126
+ super().__init__()
127
+ self.self = RobertaSelfAttention(config)
128
+ self.output = RobertaAdaSelfOutput(config)
129
+ self.pruned_heads = set()
130
+
131
+
132
+ def prune_heads(self, heads):
133
+ if len(heads) == 0:
134
+ return
135
+ heads, index = find_pruneable_heads_and_indices(
136
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
137
+ )
138
+
139
+ # Prune linear layers
140
+ self.self.query = prune_linear_layer(self.self.query, index)
141
+ self.self.key = prune_linear_layer(self.self.key, index)
142
+ self.self.value = prune_linear_layer(self.self.value, index)
143
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
144
+
145
+ # Update hyper params and store pruned heads
146
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
147
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
148
+ self.pruned_heads = self.pruned_heads.union(heads)
149
+
150
+
151
+ def forward(
152
+ self,
153
+ hidden_states,
154
+ attention_mask=None,
155
+ head_mask=None,
156
+ encoder_hidden_states=None,
157
+ encoder_attention_mask=None,
158
+ past_key_value=None,
159
+ output_attentions=False,
160
+ ):
161
+ self_outputs = self.self(
162
+ hidden_states,
163
+ attention_mask,
164
+ head_mask,
165
+ encoder_hidden_states,
166
+ encoder_attention_mask,
167
+ past_key_value,
168
+ output_attentions,
169
+ )
170
+ attention_output = self.output(self_outputs[0], hidden_states)
171
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
172
+ return outputs
173
+
174
+
175
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
176
+ class RobertaAdaLayer(nn.Module):
177
+ def __init__(self, config):
178
+ super().__init__()
179
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
180
+ self.seq_len_dim = 1
181
+ self.attention = RobertaAdaAttention(config)
182
+ self.is_decoder = config.is_decoder
183
+ self.add_cross_attention = config.add_cross_attention
184
+ if self.add_cross_attention:
185
+ assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
186
+ self.crossattention = RobertaAttention(config)
187
+ self.intermediate = RobertaIntermediate(config)
188
+ self.output = RobertaAdaOutput(config)
189
+
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states,
194
+ attention_mask=None,
195
+ head_mask=None,
196
+ encoder_hidden_states=None,
197
+ encoder_attention_mask=None,
198
+ past_key_value=None,
199
+ output_attentions=False,
200
+ ):
201
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
202
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
203
+ self_attention_outputs = self.attention(
204
+ hidden_states,
205
+ attention_mask,
206
+ head_mask,
207
+ output_attentions=output_attentions,
208
+ past_key_value=self_attn_past_key_value,
209
+ )
210
+ attention_output = self_attention_outputs[0]
211
+
212
+ # if decoder, the last output is tuple of self-attn cache
213
+ if self.is_decoder:
214
+ outputs = self_attention_outputs[1:-1]
215
+ present_key_value = self_attention_outputs[-1]
216
+ else:
217
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
218
+
219
+ cross_attn_present_key_value = None
220
+
221
+ if self.is_decoder and encoder_hidden_states is not None:
222
+ assert hasattr(
223
+ self, "crossattention"
224
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
225
+
226
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
227
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
228
+ cross_attention_outputs = self.crossattention(
229
+ attention_output,
230
+ attention_mask,
231
+ head_mask,
232
+ encoder_hidden_states,
233
+ encoder_attention_mask,
234
+ cross_attn_past_key_value,
235
+ output_attentions,
236
+ )
237
+ attention_output = cross_attention_outputs[0]
238
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
239
+
240
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
241
+ cross_attn_present_key_value = cross_attention_outputs[-1]
242
+ present_key_value = present_key_value + cross_attn_present_key_value
243
+
244
+ layer_output = apply_chunking_to_forward(
245
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
246
+ )
247
+ outputs = (layer_output,) + outputs
248
+
249
+ # if decoder, return the attn key/values as the last output
250
+ if self.is_decoder:
251
+ outputs = outputs + (present_key_value,)
252
+
253
+ return outputs
254
+
255
+
256
+ def feed_forward_chunk(self, attention_output):
257
+ intermediate_output = self.intermediate(attention_output)
258
+ layer_output = self.output(intermediate_output, attention_output)
259
+ return layer_output
260
+
261
+
262
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
263
+ class RobertaAdaEncoder(nn.Module):
264
+ def __init__(self, config):
265
+ super().__init__()
266
+ self.config = config
267
+ self.layer = nn.ModuleList([RobertaAdaLayer(config) for _ in range(config.num_hidden_layers)])
268
+ self.skip = 2
269
+
270
+
271
+ def learn_init(
272
+ self,
273
+ hidden_states,
274
+ attention_mask=None,
275
+ head_mask=None,
276
+ encoder_hidden_states=None,
277
+ encoder_attention_mask=None,
278
+ past_key_values=None,
279
+ use_cache=None,
280
+ output_attentions=False,
281
+ output_hidden_states=False,
282
+ return_dict=True):
283
+
284
+ all_hidden_states = () if output_hidden_states else None
285
+ all_self_attentions = () if output_attentions else None
286
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
287
+
288
+ next_decoder_cache = () if use_cache else None
289
+ self.skip_list = []
290
+ for i, layer_module in enumerate(self.layer):
291
+ # if i+1 % self.skip
292
+ if output_hidden_states:
293
+ all_hidden_states = all_hidden_states + (hidden_states,)
294
+
295
+ layer_head_mask = head_mask[i] if head_mask is not None else None
296
+ past_key_value = past_key_values[i] if past_key_values is not None else None
297
+
298
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
299
+
300
+ if use_cache:
301
+ logger.warning(
302
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
303
+ "`use_cache=False`..."
304
+ )
305
+ use_cache = False
306
+
307
+ def create_custom_forward(module):
308
+ def custom_forward(*inputs):
309
+ return module(*inputs, past_key_value, output_attentions)
310
+
311
+ return custom_forward
312
+
313
+ layer_outputs = torch.utils.checkpoint.checkpoint(
314
+ create_custom_forward(layer_module),
315
+ hidden_states,
316
+ attention_mask,
317
+ layer_head_mask,
318
+ encoder_hidden_states,
319
+ encoder_attention_mask,
320
+ )
321
+ else:
322
+ layer_outputs = layer_module(
323
+ hidden_states,
324
+ attention_mask,
325
+ layer_head_mask,
326
+ encoder_hidden_states,
327
+ encoder_attention_mask,
328
+ past_key_value,
329
+ output_attentions,
330
+ )
331
+
332
+ hidden_states = layer_outputs[0]
333
+ if use_cache:
334
+ next_decoder_cache += (layer_outputs[-1],)
335
+ if output_attentions:
336
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
337
+ if self.config.add_cross_attention:
338
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
339
+
340
+ if output_hidden_states:
341
+ all_hidden_states = all_hidden_states + (hidden_states,)
342
+
343
+ if not return_dict:
344
+ return tuple(
345
+ v
346
+ for v in [
347
+ hidden_states,
348
+ next_decoder_cache,
349
+ all_hidden_states,
350
+ all_self_attentions,
351
+ all_cross_attentions,
352
+ ]
353
+ if v is not None
354
+ )
355
+
356
+ return BaseModelOutputWithPastAndCrossAttentions(
357
+ last_hidden_state=hidden_states,
358
+ past_key_values=next_decoder_cache,
359
+ hidden_states=all_hidden_states,
360
+ attentions=all_self_attentions,
361
+ cross_attentions=all_cross_attentions,
362
+ )
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states,
367
+ attention_mask=None,
368
+ head_mask=None,
369
+ encoder_hidden_states=None,
370
+ encoder_attention_mask=None,
371
+ past_key_values=None,
372
+ use_cache=None,
373
+ output_attentions=False,
374
+ output_hidden_states=False,
375
+ return_dict=True,
376
+ ):
377
+ all_hidden_states = () if output_hidden_states else None
378
+ all_self_attentions = () if output_attentions else None
379
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
380
+
381
+ next_decoder_cache = () if use_cache else None
382
+ for i, layer_module in enumerate(self.layer):
383
+ # if (i+1) % 3 == 0:
384
+ # continue
385
+ if output_hidden_states:
386
+ all_hidden_states = all_hidden_states + (hidden_states,)
387
+
388
+ layer_head_mask = head_mask[i] if head_mask is not None else None
389
+ past_key_value = past_key_values[i] if past_key_values is not None else None
390
+
391
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
392
+
393
+ if use_cache:
394
+ logger.warning(
395
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
396
+ "`use_cache=False`..."
397
+ )
398
+ use_cache = False
399
+
400
+ def create_custom_forward(module):
401
+ def custom_forward(*inputs):
402
+ return module(*inputs, past_key_value, output_attentions)
403
+
404
+ return custom_forward
405
+
406
+ layer_outputs = torch.utils.checkpoint.checkpoint(
407
+ create_custom_forward(layer_module),
408
+ hidden_states,
409
+ attention_mask,
410
+ layer_head_mask,
411
+ encoder_hidden_states,
412
+ encoder_attention_mask,
413
+ )
414
+ else:
415
+ layer_outputs = layer_module(
416
+ hidden_states,
417
+ attention_mask,
418
+ layer_head_mask,
419
+ encoder_hidden_states,
420
+ encoder_attention_mask,
421
+ past_key_value,
422
+ output_attentions,
423
+ )
424
+
425
+ hidden_states = layer_outputs[0]
426
+ if use_cache:
427
+ next_decoder_cache += (layer_outputs[-1],)
428
+ if output_attentions:
429
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
430
+ if self.config.add_cross_attention:
431
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
432
+
433
+ if output_hidden_states:
434
+ all_hidden_states = all_hidden_states + (hidden_states,)
435
+
436
+ if not return_dict:
437
+ return tuple(
438
+ v
439
+ for v in [
440
+ hidden_states,
441
+ next_decoder_cache,
442
+ all_hidden_states,
443
+ all_self_attentions,
444
+ all_cross_attentions,
445
+ ]
446
+ if v is not None
447
+ )
448
+ return BaseModelOutputWithPastAndCrossAttentions(
449
+ last_hidden_state=hidden_states,
450
+ past_key_values=next_decoder_cache,
451
+ hidden_states=all_hidden_states,
452
+ attentions=all_self_attentions,
453
+ cross_attentions=all_cross_attentions,
454
+ )
455
+
456
+ """RoBERTa for Adapter"""
457
+ class RobertaAdaModel(RobertaPreTrainedModel):
458
+ """
459
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
460
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
461
+ all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
462
+ Kaiser and Illia Polosukhin.
463
+ To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
464
+ set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
465
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
466
+ input to the forward pass.
467
+ .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762
468
+ """
469
+
470
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
471
+
472
+ # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
473
+ def __init__(self, config, add_pooling_layer=True):
474
+ super().__init__(config)
475
+ self.config = config
476
+ self.embeddings = RobertaEmbeddings(config)
477
+ self.encoder = RobertaAdaEncoder(config)
478
+ self.pooler = RobertaPooler(config) if add_pooling_layer else None
479
+
480
+ def get_input_embeddings(self):
481
+ return self.embeddings.word_embeddings
482
+
483
+ def set_input_embeddings(self, value):
484
+ self.embeddings.word_embeddings = value
485
+
486
+ def _prune_heads(self, heads_to_prune):
487
+ """
488
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
489
+ class PreTrainedModel
490
+ """
491
+ for layer, heads in heads_to_prune.items():
492
+ self.encoder.layer[layer].attention.prune_heads(heads)
493
+
494
+ # Copied from transformers.models.bert.modeling_bert.BertModel.forward
495
+ def forward(
496
+ self,
497
+ input_ids=None,
498
+ attention_mask=None,
499
+ token_type_ids=None,
500
+ position_ids=None,
501
+ head_mask=None,
502
+ inputs_embeds=None,
503
+ encoder_hidden_states=None,
504
+ encoder_attention_mask=None,
505
+ past_key_values=None,
506
+ use_cache=None,
507
+ output_attentions=None,
508
+ output_hidden_states=None,
509
+ return_dict=None,
510
+ ):
511
+ r"""
512
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
513
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
514
+ the model is configured as a decoder.
515
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
516
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
517
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
518
+ - 1 for tokens that are **not masked**,
519
+ - 0 for tokens that are **masked**.
520
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
521
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
522
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
523
+ (those that don"t have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
524
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
525
+ use_cache (:obj:`bool`, `optional`):
526
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
527
+ decoding (see :obj:`past_key_values`).
528
+ """
529
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
530
+ output_hidden_states = (
531
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
532
+ )
533
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
534
+
535
+ if self.config.is_decoder:
536
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
537
+ else:
538
+ use_cache = False
539
+
540
+ if input_ids is not None and inputs_embeds is not None:
541
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
542
+ elif input_ids is not None:
543
+ input_shape = input_ids.size()
544
+ batch_size, seq_length = input_shape
545
+ elif inputs_embeds is not None:
546
+ input_shape = inputs_embeds.size()[:-1]
547
+ batch_size, seq_length = input_shape
548
+ else:
549
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
550
+
551
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
552
+
553
+ # past_key_values_length
554
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
555
+
556
+ if attention_mask is None:
557
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
558
+
559
+ if token_type_ids is None:
560
+ if hasattr(self.embeddings, "token_type_ids"):
561
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
562
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
563
+ token_type_ids = buffered_token_type_ids_expanded
564
+ else:
565
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
566
+
567
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
568
+ # ourselves in which case we just need to make it broadcastable to all heads.
569
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
570
+
571
+ # If a 2D or 3D attention mask is provided for the cross-attention
572
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
573
+ if self.config.is_decoder and encoder_hidden_states is not None:
574
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
575
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
576
+ if encoder_attention_mask is None:
577
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
578
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
579
+ else:
580
+ encoder_extended_attention_mask = None
581
+
582
+ # Prepare head mask if needed
583
+ # 1.0 in head_mask indicate we keep the head
584
+ # attention_probs has shape bsz x n_heads x N x N
585
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
586
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
587
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
588
+
589
+ embedding_output = self.embeddings(
590
+ input_ids=input_ids,
591
+ position_ids=position_ids,
592
+ token_type_ids=token_type_ids,
593
+ inputs_embeds=inputs_embeds,
594
+ past_key_values_length=past_key_values_length,
595
+ )
596
+ encoder_outputs = self.encoder(
597
+ embedding_output,
598
+ attention_mask=extended_attention_mask,
599
+ head_mask=head_mask,
600
+ encoder_hidden_states=encoder_hidden_states,
601
+ encoder_attention_mask=encoder_extended_attention_mask,
602
+ past_key_values=past_key_values,
603
+ use_cache=use_cache,
604
+ output_attentions=output_attentions,
605
+ output_hidden_states=output_hidden_states,
606
+ return_dict=return_dict,
607
+ )
608
+ sequence_output = encoder_outputs[0]
609
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
610
+
611
+ if not return_dict:
612
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
613
+
614
+ return BaseModelOutputWithPoolingAndCrossAttentions(
615
+ last_hidden_state=sequence_output,
616
+ pooler_output=pooled_output,
617
+ past_key_values=encoder_outputs.past_key_values,
618
+ hidden_states=encoder_outputs.hidden_states,
619
+ attentions=encoder_outputs.attentions,
620
+ cross_attentions=encoder_outputs.cross_attentions,
621
+ )
622
+
623
+
624
+ ## ======== Adapter For BERT ========
625
+ class BertAdaOutput(nn.Module):
626
+ def __init__(self, config):
627
+ super().__init__()
628
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
629
+ self.config = config
630
+
631
+ self.adaptation_layer = AdapeterLayer(n_in=config.intermediate_size, n_out=config.hidden_size,
632
+ adapter_dim=config.adapter_dim, adapter_choice=config.adapter_choice)
633
+
634
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
635
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
636
+
637
+ def forward(self, hidden_states, input_tensor):
638
+ if self.config.adapter_choice == "lora":
639
+ hidden_states = self.dense(hidden_states) + self.adaptation_layer(hidden_states)
640
+ else:
641
+ hidden_states = self.dense(hidden_states)
642
+ hidden_states = self.adaptation_layer(hidden_states)
643
+ hidden_states = self.dropout(hidden_states)
644
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
645
+ return hidden_states
646
+
647
+ class BertAdaSelfOutput(nn.Module):
648
+ def __init__(self, config):
649
+ super().__init__()
650
+ self.config = config
651
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
652
+ self.adaptation_layer = AdapeterLayer(n_in=config.intermediate_size, n_out=config.hidden_size,
653
+ adapter_dim=config.adapter_dim, adapter_choice=config.adapter_choice)
654
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
655
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
656
+
657
+ def forward(self, hidden_states, input_tensor):
658
+ if self.config.adapter_choice == "lora":
659
+ hidden_states = self.dense(hidden_states) + self.adaptation_layer(hidden_states)
660
+ else:
661
+ hidden_states = self.dense(hidden_states)
662
+ hidden_states = self.adaptation_layer(hidden_states)
663
+ hidden_states = self.dropout(hidden_states)
664
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
665
+ return hidden_states
666
+
667
+
668
+ class BertAdaAttention(nn.Module):
669
+ def __init__(self, config):
670
+ super().__init__()
671
+ self.self = BertSelfAttention(config)
672
+ self.output = BertAdaSelfOutput(config)
673
+ self.pruned_heads = set()
674
+
675
+ def prune_heads(self, heads):
676
+ if len(heads) == 0:
677
+ return
678
+ heads, index = find_pruneable_heads_and_indices(
679
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
680
+ )
681
+
682
+ # Prune linear layers
683
+ self.self.query = prune_linear_layer(self.self.query, index)
684
+ self.self.key = prune_linear_layer(self.self.key, index)
685
+ self.self.value = prune_linear_layer(self.self.value, index)
686
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
687
+
688
+ # Update hyper params and store pruned heads
689
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
690
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
691
+ self.pruned_heads = self.pruned_heads.union(heads)
692
+
693
+ def forward(
694
+ self,
695
+ hidden_states,
696
+ attention_mask=None,
697
+ head_mask=None,
698
+ encoder_hidden_states=None,
699
+ encoder_attention_mask=None,
700
+ past_key_value=None,
701
+ output_attentions=False,
702
+ ):
703
+ self_outputs = self.self(
704
+ hidden_states,
705
+ attention_mask,
706
+ head_mask,
707
+ encoder_hidden_states,
708
+ encoder_attention_mask,
709
+ past_key_value,
710
+ output_attentions,
711
+ )
712
+ attention_output = self.output(self_outputs[0], hidden_states)
713
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
714
+ return outputs
715
+
716
+
717
+ class BertAdaLayer(nn.Module):
718
+ def __init__(self, config):
719
+ super().__init__()
720
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
721
+ self.seq_len_dim = 1
722
+ self.attention = BertAdaAttention(config)
723
+ self.is_decoder = config.is_decoder
724
+ self.add_cross_attention = config.add_cross_attention
725
+ if self.add_cross_attention:
726
+ assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
727
+ self.crossattention = BertAttention(config)
728
+ self.intermediate = BertIntermediate(config)
729
+ self.output = BertAdaOutput(config)
730
+
731
+ def forward(
732
+ self,
733
+ hidden_states,
734
+ attention_mask=None,
735
+ head_mask=None,
736
+ encoder_hidden_states=None,
737
+ encoder_attention_mask=None,
738
+ past_key_value=None,
739
+ output_attentions=False,
740
+ ):
741
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
742
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
743
+ self_attention_outputs = self.attention(
744
+ hidden_states,
745
+ attention_mask,
746
+ head_mask,
747
+ output_attentions=output_attentions,
748
+ past_key_value=self_attn_past_key_value,
749
+ )
750
+ attention_output = self_attention_outputs[0]
751
+
752
+ # if decoder, the last output is tuple of self-attn cache
753
+ if self.is_decoder:
754
+ outputs = self_attention_outputs[1:-1]
755
+ present_key_value = self_attention_outputs[-1]
756
+ else:
757
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
758
+
759
+ cross_attn_present_key_value = None
760
+ if self.is_decoder and encoder_hidden_states is not None:
761
+ assert hasattr(
762
+ self, "crossattention"
763
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
764
+
765
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
766
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
767
+ cross_attention_outputs = self.crossattention(
768
+ attention_output,
769
+ attention_mask,
770
+ head_mask,
771
+ encoder_hidden_states,
772
+ encoder_attention_mask,
773
+ cross_attn_past_key_value,
774
+ output_attentions,
775
+ )
776
+ attention_output = cross_attention_outputs[0]
777
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
778
+
779
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
780
+ cross_attn_present_key_value = cross_attention_outputs[-1]
781
+ present_key_value = present_key_value + cross_attn_present_key_value
782
+
783
+ layer_output = apply_chunking_to_forward(
784
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
785
+ )
786
+ outputs = (layer_output,) + outputs
787
+
788
+ # if decoder, return the attn key/values as the last output
789
+ if self.is_decoder:
790
+ outputs = outputs + (present_key_value,)
791
+
792
+ return outputs
793
+
794
+ def feed_forward_chunk(self, attention_output):
795
+ intermediate_output = self.intermediate(attention_output)
796
+ layer_output = self.output(intermediate_output, attention_output)
797
+ return layer_output
798
+
799
+
800
+ class BertAdaEncoder(nn.Module):
801
+ def __init__(self, config):
802
+ super().__init__()
803
+ self.config = config
804
+ self.layer = nn.ModuleList([BertAdaLayer(config) for _ in range(config.num_hidden_layers)])
805
+
806
+ def forward(
807
+ self,
808
+ hidden_states,
809
+ attention_mask=None,
810
+ head_mask=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=False,
816
+ output_hidden_states=False,
817
+ return_dict=True,
818
+ ):
819
+ all_hidden_states = () if output_hidden_states else None
820
+ all_self_attentions = () if output_attentions else None
821
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
822
+
823
+ next_decoder_cache = () if use_cache else None
824
+ for i, layer_module in enumerate(self.layer):
825
+ if output_hidden_states:
826
+ all_hidden_states = all_hidden_states + (hidden_states,)
827
+
828
+ layer_head_mask = head_mask[i] if head_mask is not None else None
829
+ past_key_value = past_key_values[i] if past_key_values is not None else None
830
+
831
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
832
+
833
+ if use_cache:
834
+ logger.warning(
835
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
836
+ "`use_cache=False`..."
837
+ )
838
+ use_cache = False
839
+
840
+ def create_custom_forward(module):
841
+ def custom_forward(*inputs):
842
+ return module(*inputs, past_key_value, output_attentions)
843
+
844
+ return custom_forward
845
+
846
+ layer_outputs = torch.utils.checkpoint.checkpoint(
847
+ create_custom_forward(layer_module),
848
+ hidden_states,
849
+ attention_mask,
850
+ layer_head_mask,
851
+ encoder_hidden_states,
852
+ encoder_attention_mask,
853
+ )
854
+ else:
855
+ layer_outputs = layer_module(
856
+ hidden_states,
857
+ attention_mask,
858
+ layer_head_mask,
859
+ encoder_hidden_states,
860
+ encoder_attention_mask,
861
+ past_key_value,
862
+ output_attentions,
863
+ )
864
+
865
+ hidden_states = layer_outputs[0]
866
+ if use_cache:
867
+ next_decoder_cache += (layer_outputs[-1],)
868
+ if output_attentions:
869
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
870
+ if self.config.add_cross_attention:
871
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
872
+
873
+ if output_hidden_states:
874
+ all_hidden_states = all_hidden_states + (hidden_states,)
875
+
876
+ if not return_dict:
877
+ return tuple(
878
+ v
879
+ for v in [
880
+ hidden_states,
881
+ next_decoder_cache,
882
+ all_hidden_states,
883
+ all_self_attentions,
884
+ all_cross_attentions,
885
+ ]
886
+ if v is not None
887
+ )
888
+ return BaseModelOutputWithPastAndCrossAttentions(
889
+ last_hidden_state=hidden_states,
890
+ past_key_values=next_decoder_cache,
891
+ hidden_states=all_hidden_states,
892
+ attentions=all_self_attentions,
893
+ cross_attentions=all_cross_attentions,
894
+ )
895
+
896
+
897
+ class BertAdaModel(BertPreTrainedModel):
898
+ """
899
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
900
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
901
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
902
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
903
+ To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
904
+ set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
905
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
906
+ input to the forward pass.
907
+ """
908
+
909
+ def __init__(self, config, add_pooling_layer=True):
910
+ super().__init__(config)
911
+ self.config = config
912
+
913
+ self.embeddings = BertEmbeddings(config)
914
+ self.encoder = BertAdaEncoder(config)
915
+
916
+ self.pooler = BertPooler(config) if add_pooling_layer else None
917
+
918
+ self.init_weights()
919
+
920
+ def get_input_embeddings(self):
921
+ return self.embeddings.word_embeddings
922
+
923
+ def set_input_embeddings(self, value):
924
+ self.embeddings.word_embeddings = value
925
+
926
+ def _prune_heads(self, heads_to_prune):
927
+ """
928
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
929
+ class PreTrainedModel
930
+ """
931
+ for layer, heads in heads_to_prune.items():
932
+ self.encoder.layer[layer].attention.prune_heads(heads)
933
+
934
+
935
+ def forward(
936
+ self,
937
+ input_ids=None,
938
+ attention_mask=None,
939
+ token_type_ids=None,
940
+ position_ids=None,
941
+ head_mask=None,
942
+ inputs_embeds=None,
943
+ encoder_hidden_states=None,
944
+ encoder_attention_mask=None,
945
+ past_key_values=None,
946
+ use_cache=None,
947
+ output_attentions=None,
948
+ output_hidden_states=None,
949
+ return_dict=None,
950
+ ):
951
+ r"""
952
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
953
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
954
+ the model is configured as a decoder.
955
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
956
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
957
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
958
+ - 1 for tokens that are **not masked**,
959
+ - 0 for tokens that are **masked**.
960
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
961
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
962
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
963
+ (those that don"t have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
964
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
965
+ use_cache (:obj:`bool`, `optional`):
966
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
967
+ decoding (see :obj:`past_key_values`).
968
+ """
969
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
970
+ output_hidden_states = (
971
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
972
+ )
973
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
974
+
975
+ if self.config.is_decoder:
976
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
977
+ else:
978
+ use_cache = False
979
+
980
+ if input_ids is not None and inputs_embeds is not None:
981
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
982
+ elif input_ids is not None:
983
+ input_shape = input_ids.size()
984
+ elif inputs_embeds is not None:
985
+ input_shape = inputs_embeds.size()[:-1]
986
+ else:
987
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
988
+
989
+ batch_size, seq_length = input_shape
990
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
991
+
992
+ # past_key_values_length
993
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
994
+
995
+ if attention_mask is None:
996
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
997
+
998
+ if token_type_ids is None:
999
+ if hasattr(self.embeddings, "token_type_ids"):
1000
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1001
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1002
+ token_type_ids = buffered_token_type_ids_expanded
1003
+ else:
1004
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1005
+
1006
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1007
+ # ourselves in which case we just need to make it broadcastable to all heads.
1008
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
1009
+
1010
+ # If a 2D or 3D attention mask is provided for the cross-attention
1011
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1012
+ if self.config.is_decoder and encoder_hidden_states is not None:
1013
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1014
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1015
+ if encoder_attention_mask is None:
1016
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1017
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1018
+ else:
1019
+ encoder_extended_attention_mask = None
1020
+
1021
+ # Prepare head mask if needed
1022
+ # 1.0 in head_mask indicate we keep the head
1023
+ # attention_probs has shape bsz x n_heads x N x N
1024
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1025
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1026
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1027
+
1028
+ embedding_output = self.embeddings(
1029
+ input_ids=input_ids,
1030
+ position_ids=position_ids,
1031
+ token_type_ids=token_type_ids,
1032
+ inputs_embeds=inputs_embeds,
1033
+ past_key_values_length=past_key_values_length,
1034
+ )
1035
+ encoder_outputs = self.encoder(
1036
+ embedding_output,
1037
+ attention_mask=extended_attention_mask,
1038
+ head_mask=head_mask,
1039
+ encoder_hidden_states=encoder_hidden_states,
1040
+ encoder_attention_mask=encoder_extended_attention_mask,
1041
+ past_key_values=past_key_values,
1042
+ use_cache=use_cache,
1043
+ output_attentions=output_attentions,
1044
+ output_hidden_states=output_hidden_states,
1045
+ return_dict=return_dict,
1046
+ )
1047
+ sequence_output = encoder_outputs[0]
1048
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1049
+
1050
+ if not return_dict:
1051
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1052
+
1053
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1054
+ last_hidden_state=sequence_output,
1055
+ pooler_output=pooled_output,
1056
+ past_key_values=encoder_outputs.past_key_values,
1057
+ hidden_states=encoder_outputs.hidden_states,
1058
+ attentions=encoder_outputs.attentions,
1059
+ cross_attentions=encoder_outputs.cross_attentions,
1060
+ )
models/basic_modules/crf.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import List, Optional
4
+
5
+ class CRF(nn.Module):
6
+ """Conditional random field.
7
+ This module implements a conditional random field [LMP01]_. The forward computation
8
+ of this class computes the log likelihood of the given sequence of tags and
9
+ emission score tensor. This class also has `~CRF.decode` method which finds
10
+ the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
11
+ Args:
12
+ num_tags: Number of tags.
13
+ batch_first: Whether the first dimension corresponds to the size of a minibatch.
14
+ Attributes:
15
+ start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
16
+ ``(num_tags,)``.
17
+ end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
18
+ ``(num_tags,)``.
19
+ transitions (`~torch.nn.Parameter`): Transition score tensor of size
20
+ ``(num_tags, num_tags)``.
21
+ .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
22
+ "Conditional random fields: Probabilistic models for segmenting and
23
+ labeling sequence data". *Proc. 18th International Conf. on Machine
24
+ Learning*. Morgan Kaufmann. pp. 282–289.
25
+ .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
26
+ """
27
+
28
+ def __init__(self, num_tags: int, batch_first: bool = False) -> None:
29
+ if num_tags <= 0:
30
+ raise ValueError(f"invalid number of tags: {num_tags}")
31
+ super().__init__()
32
+ self.num_tags = num_tags
33
+ self.batch_first = batch_first
34
+ self.start_transitions = nn.Parameter(torch.empty(num_tags))
35
+ self.end_transitions = nn.Parameter(torch.empty(num_tags))
36
+ self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
37
+
38
+ self.reset_parameters()
39
+
40
+ def reset_parameters(self) -> None:
41
+ """Initialize the transition parameters.
42
+ The parameters will be initialized randomly from a uniform distribution
43
+ between -0.1 and 0.1.
44
+ """
45
+ nn.init.uniform_(self.start_transitions, -0.1, 0.1)
46
+ nn.init.uniform_(self.end_transitions, -0.1, 0.1)
47
+ nn.init.uniform_(self.transitions, -0.1, 0.1)
48
+
49
+ def __repr__(self) -> str:
50
+ return f"{self.__class__.__name__}(num_tags={self.num_tags})"
51
+
52
+ def forward(self, emissions: torch.Tensor,
53
+ tags: torch.LongTensor,
54
+ mask: Optional[torch.ByteTensor] = None,
55
+ reduction: str = "mean") -> torch.Tensor:
56
+ """Compute the conditional log likelihood of a sequence of tags given emission scores.
57
+ Args:
58
+ emissions (`~torch.Tensor`): Emission score tensor of size
59
+ ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
60
+ ``(batch_size, seq_length, num_tags)`` otherwise.
61
+ tags (`~torch.LongTensor`): Sequence of tags tensor of size
62
+ ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
63
+ ``(batch_size, seq_length)`` otherwise.
64
+ mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
65
+ if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
66
+ reduction: Specifies the reduction to apply to the output:
67
+ ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
68
+ ``sum``: the output will be summed over batches. ``mean``: the output will be
69
+ averaged over batches. ``token_mean``: the output will be averaged over tokens.
70
+ Returns:
71
+ `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
72
+ reduction is ``none``, ``()`` otherwise.
73
+ """
74
+ if reduction not in ("none", "sum", "mean", "token_mean"):
75
+ raise ValueError(f"invalid reduction: {reduction}")
76
+ if mask is None:
77
+ mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
78
+ if mask.dtype != torch.uint8:
79
+ mask = mask.byte()
80
+ self._validate(emissions, tags=tags, mask=mask)
81
+
82
+ if self.batch_first:
83
+ emissions = emissions.transpose(0, 1)
84
+ tags = tags.transpose(0, 1)
85
+ mask = mask.transpose(0, 1)
86
+
87
+ # shape: (batch_size,)
88
+ numerator = self._compute_score(emissions, tags, mask)
89
+ # shape: (batch_size,)
90
+ denominator = self._compute_normalizer(emissions, mask)
91
+ # shape: (batch_size,)
92
+ llh = numerator - denominator
93
+
94
+ if reduction == "none":
95
+ return llh
96
+ if reduction == "sum":
97
+ return llh.sum()
98
+ if reduction == "mean":
99
+ return llh.mean()
100
+ return llh.sum() / mask.float().sum()
101
+
102
+ def decode(self, emissions: torch.Tensor,
103
+ mask: Optional[torch.ByteTensor] = None,
104
+ nbest: Optional[int] = None,
105
+ pad_tag: Optional[int] = None) -> List[List[List[int]]]:
106
+ """Find the most likely tag sequence using Viterbi algorithm.
107
+ Args:
108
+ emissions (`~torch.Tensor`): Emission score tensor of size
109
+ ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
110
+ ``(batch_size, seq_length, num_tags)`` otherwise.
111
+ mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
112
+ if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
113
+ nbest (`int`): Number of most probable paths for each sequence
114
+ pad_tag (`int`): Tag at padded positions. Often input varies in length and
115
+ the length will be padded to the maximum length in the batch. Tags at
116
+ the padded positions will be assigned with a padding tag, i.e. `pad_tag`
117
+ Returns:
118
+ A PyTorch tensor of the best tag sequence for each batch of shape
119
+ (nbest, batch_size, seq_length)
120
+ """
121
+ if nbest is None:
122
+ nbest = 1
123
+ if mask is None:
124
+ mask = torch.ones(emissions.shape[:2], dtype=torch.uint8,
125
+ device=emissions.device)
126
+ if mask.dtype != torch.uint8:
127
+ mask = mask.byte()
128
+ self._validate(emissions, mask=mask)
129
+
130
+ if self.batch_first:
131
+ emissions = emissions.transpose(0, 1)
132
+ mask = mask.transpose(0, 1)
133
+
134
+ if nbest == 1:
135
+ return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)
136
+ return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)
137
+
138
+ def _validate(self, emissions: torch.Tensor,
139
+ tags: Optional[torch.LongTensor] = None,
140
+ mask: Optional[torch.ByteTensor] = None) -> None:
141
+ if emissions.dim() != 3:
142
+ raise ValueError(f"emissions must have dimension of 3, got {emissions.dim()}")
143
+ if emissions.size(2) != self.num_tags:
144
+ raise ValueError(
145
+ f"expected last dimension of emissions is {self.num_tags}, "
146
+ f"got {emissions.size(2)}")
147
+
148
+ if tags is not None:
149
+ if emissions.shape[:2] != tags.shape:
150
+ raise ValueError(
151
+ "the first two dimensions of emissions and tags must match, "
152
+ f"got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}")
153
+
154
+ if mask is not None:
155
+ if emissions.shape[:2] != mask.shape:
156
+ raise ValueError(
157
+ "the first two dimensions of emissions and mask must match, "
158
+ f"got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}")
159
+ no_empty_seq = not self.batch_first and mask[0].all()
160
+ no_empty_seq_bf = self.batch_first and mask[:, 0].all()
161
+ if not no_empty_seq and not no_empty_seq_bf:
162
+ raise ValueError("mask of the first timestep must all be on")
163
+
164
+     def _compute_score(self, emissions: torch.Tensor,
+                        tags: torch.LongTensor,
+                        mask: torch.ByteTensor) -> torch.Tensor:
+         # emissions: (seq_length, batch_size, num_tags)
+         # tags: (seq_length, batch_size)
+         # mask: (seq_length, batch_size)
+         seq_length, batch_size = tags.shape
+         mask = mask.float()
+
+         # Start transition score and first emission
+         # shape: (batch_size,)
+         score = self.start_transitions[tags[0]]
+         score += emissions[0, torch.arange(batch_size), tags[0]]
+
+         for i in range(1, seq_length):
+             # Transition score to next tag, only added if next timestep is valid (mask == 1)
+             # shape: (batch_size,)
+             score += self.transitions[tags[i - 1], tags[i]] * mask[i]
+
+             # Emission score for next tag, only added if next timestep is valid (mask == 1)
+             # shape: (batch_size,)
+             score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
+
+         # End transition score
+         # shape: (batch_size,)
+         seq_ends = mask.long().sum(dim=0) - 1
+         # shape: (batch_size,)
+         last_tags = tags[seq_ends, torch.arange(batch_size)]
+         # shape: (batch_size,)
+         score += self.end_transitions[last_tags]
+
+         return score
+
+     def _compute_normalizer(self, emissions: torch.Tensor,
+                             mask: torch.ByteTensor) -> torch.Tensor:
+         # emissions: (seq_length, batch_size, num_tags)
+         # mask: (seq_length, batch_size)
+         seq_length = emissions.size(0)
+
+         # Start transition score and first emission; score has size of
+         # (batch_size, num_tags) where for each batch, the j-th column stores
+         # the score that the first timestep has tag j
+         # shape: (batch_size, num_tags)
+         score = self.start_transitions + emissions[0]
+
+         for i in range(1, seq_length):
+             # Broadcast score for every possible next tag
+             # shape: (batch_size, num_tags, 1)
+             broadcast_score = score.unsqueeze(2)
+
+             # Broadcast emission score for every possible current tag
+             # shape: (batch_size, 1, num_tags)
+             broadcast_emissions = emissions[i].unsqueeze(1)
+
+             # Compute the score tensor of size (batch_size, num_tags, num_tags) where
+             # for each sample, entry at row i and column j stores the sum of scores of all
+             # possible tag sequences so far that end with transitioning from tag i to tag j
+             # and emitting
+             # shape: (batch_size, num_tags, num_tags)
+             next_score = broadcast_score + self.transitions + broadcast_emissions
+
+             # Sum over all possible current tags, but we're in score space, so a sum
+             # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
+             # all possible tag sequences so far that end in tag i
+             # shape: (batch_size, num_tags)
+             next_score = torch.logsumexp(next_score, dim=1)
+
+             # Set score to the next score if this timestep is valid (mask == 1)
+             # shape: (batch_size, num_tags)
+             score = torch.where(mask[i].unsqueeze(1), next_score, score)
+
+         # End transition score
+         # shape: (batch_size, num_tags)
+         score += self.end_transitions
+
+         # Sum (log-sum-exp) over all possible tags
+         # shape: (batch_size,)
+         return torch.logsumexp(score, dim=1)
+
+     def _viterbi_decode(self, emissions: torch.FloatTensor,
+                         mask: torch.ByteTensor,
+                         pad_tag: Optional[int] = None) -> List[List[int]]:
+         # emissions: (seq_length, batch_size, num_tags)
+         # mask: (seq_length, batch_size)
+         # return: (batch_size, seq_length)
+         if pad_tag is None:
+             pad_tag = 0
+
+         device = emissions.device
+         seq_length, batch_size = mask.shape
+
+         # Start transition and first emission
+         # shape: (batch_size, num_tags)
+         score = self.start_transitions + emissions[0]
+         history_idx = torch.zeros((seq_length, batch_size, self.num_tags),
+                                   dtype=torch.long, device=device)
+         oor_idx = torch.zeros((batch_size, self.num_tags),
+                               dtype=torch.long, device=device)
+         oor_tag = torch.full((seq_length, batch_size), pad_tag,
+                              dtype=torch.long, device=device)
+
+         # - score is a tensor of size (batch_size, num_tags) where for every batch,
+         #   value at column j stores the score of the best tag sequence so far that ends
+         #   with tag j
+         # - history_idx saves where the best tags candidate transitioned from; this is used
+         #   when we trace back the best tag sequence
+         # - oor_idx saves the best tags candidate transitioned from at the positions
+         #   where mask is 0, i.e. out of range (oor)
+
+         # Viterbi algorithm recursive case: we compute the score of the best tag sequence
+         # for every possible next tag
+         for i in range(1, seq_length):
+             # Broadcast viterbi score for every possible next tag
+             # shape: (batch_size, num_tags, 1)
+             broadcast_score = score.unsqueeze(2)
+
+             # Broadcast emission score for every possible current tag
+             # shape: (batch_size, 1, num_tags)
+             broadcast_emission = emissions[i].unsqueeze(1)
+
+             # Compute the score tensor of size (batch_size, num_tags, num_tags) where
+             # for each sample, entry at row i and column j stores the score of the best
+             # tag sequence so far that ends with transitioning from tag i to tag j and emitting
+             # shape: (batch_size, num_tags, num_tags)
+             next_score = broadcast_score + self.transitions + broadcast_emission
+
+             # Find the maximum score over all possible current tags
+             # shape: (batch_size, num_tags)
+             next_score, indices = next_score.max(dim=1)
+
+             # Set score to the next score if this timestep is valid (mask == 1)
+             # and save the index that produces the next score
+             # shape: (batch_size, num_tags)
+             score = torch.where(mask[i].unsqueeze(-1), next_score, score)
+             indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx)
+             history_idx[i - 1] = indices
+
+         # End transition score
+         # shape: (batch_size, num_tags)
+         end_score = score + self.end_transitions
+         _, end_tag = end_score.max(dim=1)
+
+         # shape: (batch_size,)
+         seq_ends = mask.long().sum(dim=0) - 1
+
+         # insert the best tag at each sequence end (last position with mask == 1)
+         history_idx = history_idx.transpose(1, 0).contiguous()
+         history_idx.scatter_(1, seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags),
+                              end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags))
+         history_idx = history_idx.transpose(1, 0).contiguous()
+
+         # The most probable path for each sequence
+         best_tags_arr = torch.zeros((seq_length, batch_size),
+                                     dtype=torch.long, device=device)
+         best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
+         for idx in range(seq_length - 1, -1, -1):
+             best_tags = torch.gather(history_idx[idx], 1, best_tags)
+             best_tags_arr[idx] = best_tags.data.view(batch_size)
+
+         return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1)
+
+     def _viterbi_decode_nbest(self, emissions: torch.FloatTensor,
+                               mask: torch.ByteTensor,
+                               nbest: int,
+                               pad_tag: Optional[int] = None) -> List[List[List[int]]]:
+         # emissions: (seq_length, batch_size, num_tags)
+         # mask: (seq_length, batch_size)
+         # return: (nbest, batch_size, seq_length)
+         if pad_tag is None:
+             pad_tag = 0
+
+         device = emissions.device
+         seq_length, batch_size = mask.shape
+
+         # Start transition and first emission
+         # shape: (batch_size, num_tags)
+         score = self.start_transitions + emissions[0]
+         history_idx = torch.zeros((seq_length, batch_size, self.num_tags, nbest),
+                                   dtype=torch.long, device=device)
+         oor_idx = torch.zeros((batch_size, self.num_tags, nbest),
+                               dtype=torch.long, device=device)
+         oor_tag = torch.full((seq_length, batch_size, nbest), pad_tag,
+                              dtype=torch.long, device=device)
+
+         # - score is a tensor of size (batch_size, num_tags) where for every batch,
+         #   value at column j stores the score of the best tag sequence so far that ends
+         #   with tag j
+         # - history_idx saves where the best tags candidate transitioned from; this is used
+         #   when we trace back the best tag sequence
+         # - oor_idx saves the best tags candidate transitioned from at the positions
+         #   where mask is 0, i.e. out of range (oor)
+
+         # Viterbi algorithm recursive case: we compute the score of the best tag sequence
+         # for every possible next tag
+         for i in range(1, seq_length):
+             if i == 1:
+                 broadcast_score = score.unsqueeze(-1)
+                 broadcast_emission = emissions[i].unsqueeze(1)
+                 # shape: (batch_size, num_tags, num_tags)
+                 next_score = broadcast_score + self.transitions + broadcast_emission
+             else:
+                 broadcast_score = score.unsqueeze(-1)
+                 broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2)
+                 # shape: (batch_size, num_tags, nbest, num_tags)
+                 next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission
+
+             # Find the top `nbest` maximum scores over all possible current tags
+             # shape: (batch_size, nbest, num_tags)
+             next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1)
+
+             if i == 1:
+                 score = score.unsqueeze(-1).expand(-1, -1, nbest)
+                 indices = indices * nbest
+
+             # convert to shape: (batch_size, num_tags, nbest)
+             next_score = next_score.transpose(2, 1)
+             indices = indices.transpose(2, 1)
+
+             # Set score to the next score if this timestep is valid (mask == 1)
+             # and save the index that produces the next score
+             # shape: (batch_size, num_tags, nbest)
+             score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), next_score, score)
+             indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, oor_idx)
+             history_idx[i - 1] = indices
+
+         # End transition score shape: (batch_size, num_tags, nbest)
+         end_score = score + self.end_transitions.unsqueeze(-1)
+         _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)
+
+         # shape: (batch_size,)
+         seq_ends = mask.long().sum(dim=0) - 1
+
+         # insert the best tag at each sequence end (last position with mask == 1)
+         history_idx = history_idx.transpose(1, 0).contiguous()
+         history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
+                              end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
+         history_idx = history_idx.transpose(1, 0).contiguous()
+
+         # The most probable path for each sequence
+         best_tags_arr = torch.zeros((seq_length, batch_size, nbest),
+                                     dtype=torch.long, device=device)
+         best_tags = torch.arange(nbest, dtype=torch.long, device=device) \
+             .view(1, -1).expand(batch_size, -1)
+         for idx in range(seq_length - 1, -1, -1):
+             best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, best_tags)
+             best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest
+
+         return torch.where(mask.unsqueeze(-1), best_tags_arr, oor_tag).permute(2, 1, 0)
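The four private methods above implement the standard linear-chain CRF computations: gold-path scoring, the forward-algorithm normalizer, and (n-best) Viterbi decoding. Below is a minimal sketch, not part of the commit, of how they combine, assuming `crf` is an instance of the enclosing CRF module (defined earlier in the file) that owns `start_transitions`, `transitions`, and `end_transitions`.

import torch

seq_length, batch_size, num_tags = 7, 2, 5
emissions = torch.randn(seq_length, batch_size, num_tags)     # (seq_length, batch_size, num_tags)
tags = torch.randint(num_tags, (seq_length, batch_size))      # (seq_length, batch_size)
mask = torch.ones(seq_length, batch_size, dtype=torch.uint8)  # 1 = real token, 0 = padding

# Training: the CRF log-likelihood is the gold-path score minus the log partition function.
# log_likelihood = crf._compute_score(emissions, tags, mask) - crf._compute_normalizer(emissions, mask)

# Inference: Viterbi decoding returns the best tag sequence, shape (batch_size, seq_length);
# the n-best variant returns (nbest, batch_size, seq_length).
# best_paths = crf._viterbi_decode(emissions, mask)
# top3_paths = crf._viterbi_decode_nbest(emissions, mask, nbest=3)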
models/basic_modules/generation.py ADDED
@@ -0,0 +1,146 @@
+ from typing import Any, Callable, Optional
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+
+ try:
+     from transformers.generation_logits_process import (
+         LogitsProcessorList,
+         TemperatureLogitsWarper,
+         TopKLogitsWarper,
+         TopPLogitsWarper,
+     )
+ except ImportError:
+     from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
+
+
+ def prepare_logits_processor(top_k: Optional[int] = None,
+                              top_p: Optional[float] = None,
+                              temperature: Optional[float] = None) -> LogitsProcessorList:
+     processor_list = LogitsProcessorList()
+     if temperature is not None and temperature != 1.0:
+         processor_list.append(TemperatureLogitsWarper(temperature))
+     if top_k is not None and top_k != 0:
+         processor_list.append(TopKLogitsWarper(top_k))
+     if top_p is not None and top_p < 1.0:
+         processor_list.append(TopPLogitsWarper(top_p))
+     return processor_list
+
+
+ def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
+     if dist.is_initialized() and dist.get_world_size() > 1:
+         # consider data parallelism (DP)
+         unfinished_sequences = unfinished_sequences.clone()
+         dist.all_reduce(unfinished_sequences)
+     return unfinished_sequences.max() == 0
+
+
+ def sample(model: nn.Module,
+            input_ids: torch.Tensor,
+            max_length: int,
+            early_stopping: bool = False,
+            eos_token_id: Optional[int] = None,
+            pad_token_id: Optional[int] = None,
+            top_k: Optional[int] = None,
+            top_p: Optional[float] = None,
+            temperature: Optional[float] = None,
+            prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+            update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+            **model_kwargs) -> torch.Tensor:
+     if input_ids.size(1) >= max_length:
+         return input_ids
+
+     logits_processor = prepare_logits_processor(top_k, top_p, temperature)
+     unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+
+     for _ in range(input_ids.size(1), max_length):
+         model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
+             'input_ids': input_ids
+         }
+         outputs = model(**model_inputs)
+
+         next_token_logits = outputs['logits'][:, -1, :]
+         # pre-process the distribution
+         next_token_logits = logits_processor(input_ids, next_token_logits)
+         # sample the next token
+         probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
+         next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+         # finished sentences should have their next token be a padding token
+         if eos_token_id is not None:
+             if pad_token_id is None:
+                 raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+         # update generated ids and model inputs for the next step
+         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+         if update_model_kwargs_fn is not None:
+             model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)
+
+         # if an eos_token was produced for a sentence, mark that sentence as finished
+         if eos_token_id is not None:
+             unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+         # stop once every sentence is finished if early_stopping=True
+         if early_stopping and _is_sequence_finished(unfinished_sequences):
+             break
+
+     return input_ids
+
+
+ def generate(model: nn.Module,
+              input_ids: torch.Tensor,
+              max_length: int,
+              num_beams: int = 1,
+              do_sample: bool = True,
+              early_stopping: bool = False,
+              eos_token_id: Optional[int] = None,
+              pad_token_id: Optional[int] = None,
+              top_k: Optional[int] = None,
+              top_p: Optional[float] = None,
+              temperature: Optional[float] = None,
+              prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+              update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+              **model_kwargs) -> torch.Tensor:
+     """Generate a token sequence. The returned sequence is input_ids + generated_tokens.
+
+     Args:
+         model (nn.Module): model
+         input_ids (torch.Tensor): input sequence
+         max_length (int): max length of the returned sequence
+         num_beams (int, optional): number of beams. Defaults to 1.
+         do_sample (bool, optional): whether to sample. Defaults to True.
+         early_stopping (bool, optional): if True, the sequence length may be smaller than max_length because generation stops once eos is found. Defaults to False.
+         eos_token_id (Optional[int], optional): end-of-sequence token id. Defaults to None.
+         pad_token_id (Optional[int], optional): pad token id. Defaults to None.
+         top_k (Optional[int], optional): the number of highest-probability vocabulary tokens to keep for top-k filtering. Defaults to None.
+         top_p (Optional[float], optional): if set to a float < 1, only the smallest set of most probable tokens whose probabilities add up to top_p or higher is kept for generation. Defaults to None.
+         temperature (Optional[float], optional): the value used to modulate the next-token probabilities. Defaults to None.
+         prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): function to preprocess model inputs. Its arguments should be input_ids and model_kwargs. Defaults to None.
+         update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): function to update model_kwargs based on outputs. Its arguments should be outputs and model_kwargs. Defaults to None.
+     """
+     is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
+     is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
+     is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
+     if is_greedy_gen_mode:
+         # greedy search is not implemented yet
+         raise NotImplementedError
+     elif is_sample_gen_mode:
+         # run sampling
+         return sample(model,
+                       input_ids,
+                       max_length,
+                       early_stopping=early_stopping,
+                       eos_token_id=eos_token_id,
+                       pad_token_id=pad_token_id,
+                       top_k=top_k,
+                       top_p=top_p,
+                       temperature=temperature,
+                       prepare_inputs_fn=prepare_inputs_fn,
+                       update_model_kwargs_fn=update_model_kwargs_fn,
+                       **model_kwargs)
+     elif is_beam_gen_mode:
+         # beam search is not implemented yet
+         raise NotImplementedError
+     else:
+         raise ValueError("Unsupported generation mode")
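A hedged usage sketch for `generate` (not part of the commit): driving the sampling path with a Hugging Face causal LM. The model name and decoding settings are illustrative assumptions; only the `num_beams == 1, do_sample=True` branch is implemented, so greedy and beam modes would raise `NotImplementedError`.

from transformers import AutoModelForCausalLM, AutoTokenizer
from models.basic_modules.generation import generate  # assumes the repo root is on PYTHONPATH

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Few-shot learning is", return_tensors="pt").input_ids
output_ids = generate(model,
                      input_ids,
                      max_length=32,
                      do_sample=True,                       # num_beams=1 + do_sample=True -> sampling branch
                      early_stopping=True,
                      eos_token_id=tokenizer.eos_token_id,
                      pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no dedicated pad token
                      top_k=50,
                      top_p=0.95,
                      temperature=0.8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))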
models/basic_modules/linears.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # A simple MLP layer
+ class FeedForwardNetwork(nn.Module):
+     def __init__(self, input_size, hidden_size, output_size, dropout_rate=0):
+         super(FeedForwardNetwork, self).__init__()
+         self.dropout_rate = dropout_rate
+         self.linear1 = nn.Linear(input_size, hidden_size)
+         self.linear2 = nn.Linear(hidden_size, output_size)
+
+     def forward(self, x):
+         x_proj = F.dropout(F.relu(self.linear1(x)), p=self.dropout_rate, training=self.training)
+         x_proj = self.linear2(x_proj)
+         return x_proj
+
+
+ # Span prediction head for start positions
+ class PoolerStartLogits(nn.Module):
+     def __init__(self, hidden_size, num_classes):
+         super(PoolerStartLogits, self).__init__()
+         self.dense = nn.Linear(hidden_size, num_classes)
+
+     def forward(self, hidden_states, p_mask=None):
+         x = self.dense(hidden_states)
+         return x
+
+
+ # Span prediction head for end positions
+ class PoolerEndLogits(nn.Module):
+     def __init__(self, hidden_size, num_classes):
+         super(PoolerEndLogits, self).__init__()
+         # NOTE: `hidden_size` is expected to already include the extra start-position
+         # features that `forward` concatenates to the hidden states.
+         self.dense_0 = nn.Linear(hidden_size, hidden_size)
+         self.activation = nn.Tanh()
+         self.LayerNorm = nn.LayerNorm(hidden_size)
+         self.dense_1 = nn.Linear(hidden_size, num_classes)
+
+     def forward(self, hidden_states, start_positions=None, p_mask=None):
+         # `start_positions` is a per-token start representation (e.g. a label embedding)
+         # that is concatenated to the hidden states before projection.
+         x = self.dense_0(torch.cat([hidden_states, start_positions], dim=-1))
+         x = self.activation(x)
+         x = self.LayerNorm(x)
+         x = self.dense_1(x)
+         return x
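A minimal usage sketch for the span-prediction heads (not part of the commit). It assumes the common soft-label setup in which start probabilities are concatenated to the encoder output before end prediction, which is why `PoolerEndLogits` is constructed with the concatenated feature size.

import torch
from models.basic_modules.linears import PoolerStartLogits, PoolerEndLogits

batch_size, seq_len, hidden_size, num_classes = 2, 16, 768, 5
sequence_output = torch.randn(batch_size, seq_len, hidden_size)      # encoder output

start_head = PoolerStartLogits(hidden_size, num_classes)
start_logits = start_head(sequence_output)                           # (batch, seq_len, num_classes)

# The end head sees [hidden_states; start representation], so its input size is
# hidden_size + num_classes in this soft-label setup.
end_head = PoolerEndLogits(hidden_size + num_classes, num_classes)
start_probs = torch.softmax(start_logits, dim=-1)
end_logits = end_head(sequence_output, start_positions=start_probs)  # (batch, seq_len, num_classes)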
models/basic_modules/lora.py ADDED
@@ -0,0 +1,141 @@
+ # Copyright (c) Microsoft Corporation.
+ # SPDX-License-Identifier: Apache-2.0
+
+ # DeepSpeed Team
+ import math
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from deepspeed.compression.helper import recursive_getattr, recursive_setattr
+ import deepspeed
+
+
+ class LinearLayer_LoRA(nn.Module):
+     # a simple implementation of LoRA
+     # for now it only supports nn.Linear layers
+     def __init__(self,
+                  weight,
+                  lora_dim=0,
+                  lora_scaling=1,
+                  lora_droppout=0,
+                  bias=None):
+         super(LinearLayer_LoRA, self).__init__()
+         self.weight = weight
+         self.bias = bias
+
+         if lora_dim <= 0:
+             raise ValueError(
+                 "You are trying to use LoRA, whose reduced dim must be a positive integer"
+             )
+
+         try:
+             # for ZeRO stage 3
+             rows, columns = weight.ds_shape
+         except AttributeError:
+             rows, columns = weight.shape
+         self.lora_right_weight = nn.Parameter(torch.zeros(
+             columns,
+             lora_dim))  # stored transposed so forward does not need to transpose again
+         self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
+         self.lora_scaling = lora_scaling / lora_dim
+
+         if lora_droppout > 0:
+             self.lora_dropout = nn.Dropout(lora_droppout)
+         else:
+             self.lora_dropout = nn.Identity()
+
+         self.reset_parameters()
+         # disable the original weight gradient
+         self.weight.requires_grad = False
+         # whether LoRA is fused into the original weight
+         self.fuse_lora = False
+
+     def eval(self):
+         self.lora_dropout.eval()
+         # self.fuse_lora_weight()
+
+     def train(self, mode=True):
+         self.lora_dropout.train(mode)
+         # self.unfuse_lora_weight()
+
+     def reset_parameters(self):
+         nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
+         nn.init.zeros_(self.lora_left_weight)
+
+     def fuse_lora_weight(self):
+         if not self.fuse_lora:
+             self.weight.data += self.lora_scaling * torch.matmul(
+                 self.lora_left_weight.t(), self.lora_right_weight.t())
+         self.fuse_lora = True
+
+     def unfuse_lora_weight(self):
+         if self.fuse_lora:
+             self.weight.data -= self.lora_scaling * torch.matmul(
+                 self.lora_left_weight.t(), self.lora_right_weight.t())
+         self.fuse_lora = False
+
+     def forward(self, input):
+         if self.fuse_lora:
+             return F.linear(input, self.weight, self.bias)
+         else:
+             return F.linear(
+                 input, self.weight,
+                 self.bias) + (self.lora_dropout(input) @ self.lora_right_weight
+                               @ self.lora_left_weight) * self.lora_scaling
+
+
+ # convert matching nn.Linear layers to LoRA layers
+ def convert_linear_layer_to_lora(model,
+                                  part_module_name,
+                                  lora_dim=0,
+                                  lora_scaling=1,
+                                  lora_droppout=0):
+     replace_name = []
+     for name, module in model.named_modules():
+         if isinstance(module, nn.Linear) and part_module_name in name:
+             replace_name.append(name)
+     for name in replace_name:
+         module = recursive_getattr(model, name)
+         tmp = LinearLayer_LoRA(
+             module.weight, lora_dim, lora_scaling, lora_droppout,
+             module.bias).to(module.weight.device).to(module.weight.dtype)
+         recursive_setattr(model, name, tmp)
+     return model
+
+
+ def _z3_params_to_fetch(param_list):
+     return [
+         p for p in param_list
+         if hasattr(p, 'ds_id') and p.ds_status == deepspeed.runtime.zero.
+         partition_parameters.ZeroParamStatus.NOT_AVAILABLE
+     ]
+
+
+ # fuse the LoRA weights back so each LoRA layer behaves like a plain linear layer
+ def convert_lora_to_linear_layer(model):
+     replace_name = []
+     for name, module in model.named_modules():
+         if isinstance(module, LinearLayer_LoRA):
+             replace_name.append(name)
+     for name in replace_name:
+         module = recursive_getattr(model, name)
+         zero_stage_3 = hasattr(module.weight, 'ds_id')
+         with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([
+                 module.weight, module.bias, module.lora_left_weight,
+                 module.lora_right_weight
+         ]),
+                                                modifier_rank=0,
+                                                enabled=zero_stage_3):
+             module.fuse_lora_weight()
+     return model
+
+
+ def only_optimize_lora_parameters(model):
+     # turn off the gradient of all the parameters except the LoRA parameters
+     for name, param in model.named_parameters():
+         if "lora_right_weight" in name or "lora_left_weight" in name:
+             param.requires_grad = True
+         else:
+             param.requires_grad = False
+     return model
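A hedged usage sketch (not part of the commit): wrap the linear layers of a Hugging Face model with LoRA, restrict training to the LoRA parameters, and fuse the updates back afterwards. The model name and the `part_module_name` filter are illustrative assumptions.

from transformers import AutoModelForCausalLM
from models.basic_modules.lora import (convert_linear_layer_to_lora,
                                       convert_lora_to_linear_layer,
                                       only_optimize_lora_parameters)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Replace every nn.Linear whose qualified name contains "decoder.layers." with LinearLayer_LoRA.
model = convert_linear_layer_to_lora(model, "decoder.layers.", lora_dim=8)
model = only_optimize_lora_parameters(model)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable}")

# After training, fold the low-rank updates back into the base weights for export/inference:
# model = convert_lora_to_linear_layer(model)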
models/basic_modules/prefix_encoder.py ADDED
@@ -0,0 +1,38 @@
+ import torch
+
+ # from transformers.models.bart.modeling_bart import BartForConditionalGeneration
+ # from transformers.models.bert.modeling_bert import BertForSequenceClassification
+ # model = BartForConditionalGeneration(None)
+
+
+ class PrefixEncoder(torch.nn.Module):
+     r"""
+     The torch.nn module that encodes the prefix.
+
+     Input shape: (batch_size, prefix_length)
+
+     Output shape: (batch_size, prefix_length, 2 * num_hidden_layers * hidden_size)
+     """
+     def __init__(self, config):
+         super().__init__()
+         self.prefix_projection = config.prefix_projection
+         if self.prefix_projection:
+             # Use a two-layer MLP to encode the prefix
+             self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
+             self.trans = torch.nn.Sequential(
+                 torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
+                 torch.nn.Tanh(),
+                 torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
+             )
+         else:
+             self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)
+
+     def forward(self, prefix: torch.Tensor):
+         if self.prefix_projection:
+             prefix_tokens = self.embedding(prefix)  # (batch_size, pre_seq_len, hidden_size)
+             past_key_values = self.trans(prefix_tokens)
+         else:
+             past_key_values = self.embedding(prefix)
+         return past_key_values
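A minimal sketch of `PrefixEncoder` in isolation (not part of the commit); the config object here is a hypothetical namespace whose attribute names match the ones read in `__init__`.

import torch
from types import SimpleNamespace
from models.basic_modules.prefix_encoder import PrefixEncoder

config = SimpleNamespace(prefix_projection=True,
                         pre_seq_len=16,
                         hidden_size=768,
                         prefix_hidden_size=512,
                         num_hidden_layers=12)
encoder = PrefixEncoder(config)

batch_size = 4
prefix_ids = torch.arange(config.pre_seq_len).unsqueeze(0).expand(batch_size, -1)
past_key_values = encoder(prefix_ids)
# (batch_size, pre_seq_len, 2 * num_hidden_layers * hidden_size); downstream code reshapes this
# into per-layer key/value tensors before feeding it to the transformer as past_key_values.
print(past_key_values.shape)  # torch.Size([4, 16, 18432])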