shunxing1234 commited on
Commit
f3802e8
1 Parent(s): f65b4f1

Upload ZEN/modeling.py

Browse files
Files changed (1) hide show
  1. ZEN/modeling.py +1357 -0
ZEN/modeling.py ADDED
@@ -0,0 +1,1357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright 2019 Sinovation Ventures AI Institute
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # This file is partially derived from the code at
17
+ # https://github.com/huggingface/transformers/tree/master/transformers
18
+ #
19
+ # Original copyright notice:
20
+ #
21
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
22
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
23
+ #
24
+ # Licensed under the Apache License, Version 2.0 (the "License");
25
+ # you may not use this file except in compliance with the License.
26
+ # You may obtain a copy of the License at
27
+ #
28
+ # http://www.apache.org/licenses/LICENSE-2.0
29
+ #
30
+ # Unless required by applicable law or agreed to in writing, software
31
+ # distributed under the License is distributed on an "AS IS" BASIS,
32
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33
+ # See the License for the specific language governing permissions and
34
+ # limitations under the License.
35
+ """PyTorch ZEN model classes."""
36
+
37
+ from __future__ import absolute_import, division, print_function, unicode_literals
38
+
39
+ import copy
40
+ import json
41
+ import logging
42
+ import math
43
+ import os
44
+ import sys
45
+ from io import open
46
+
47
+ import torch
48
+ from torch import nn
49
+ from torch.nn import CrossEntropyLoss
50
+
51
+ from .file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME
52
+
53
+ logger = logging.getLogger(__name__)
54
+
55
+ PRETRAINED_MODEL_ARCHIVE_MAP = {
56
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
57
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
58
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
59
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
60
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
61
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
62
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
63
+ 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
64
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
65
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
66
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
67
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
68
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
69
+ }
70
+ PRETRAINED_CONFIG_ARCHIVE_MAP = {
71
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
72
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
73
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
74
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
75
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
76
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
77
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
78
+ 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
79
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
80
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
81
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
82
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
83
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
84
+ }
85
+ BERT_CONFIG_NAME = 'bert_config.json'
86
+ TF_WEIGHTS_NAME = 'model.ckpt'
87
+
88
+
89
+ def prune_linear_layer(layer, index, dim=0):
90
+ """ Prune a linear layer (a model parameters) to keep only entries in index.
91
+ Return the pruned layer as a new layer with requires_grad=True.
92
+ Used to remove heads.
93
+ """
94
+ index = index.to(layer.weight.device)
95
+ W = layer.weight.index_select(dim, index).clone().detach()
96
+ if layer.bias is not None:
97
+ if dim == 1:
98
+ b = layer.bias.clone().detach()
99
+ else:
100
+ b = layer.bias[index].clone().detach()
101
+ new_size = list(layer.weight.size())
102
+ new_size[dim] = len(index)
103
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
104
+ new_layer.weight.requires_grad = False
105
+ new_layer.weight.copy_(W.contiguous())
106
+ new_layer.weight.requires_grad = True
107
+ if layer.bias is not None:
108
+ new_layer.bias.requires_grad = False
109
+ new_layer.bias.copy_(b.contiguous())
110
+ new_layer.bias.requires_grad = True
111
+ return new_layer
112
+
113
+
114
+ def load_tf_weights_in_bert(model, tf_checkpoint_path):
115
+ """ Load tf checkpoints in a pytorch model
116
+ """
117
+ try:
118
+ import re
119
+ import numpy as np
120
+ import tensorflow as tf
121
+ except ImportError:
122
+ print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
123
+ "https://www.tensorflow.org/install/ for installation instructions.")
124
+ raise
125
+ tf_path = os.path.abspath(tf_checkpoint_path)
126
+ print("Converting TensorFlow checkpoint from {}".format(tf_path))
127
+ # Load weights from TF model
128
+ init_vars = tf.train.list_variables(tf_path)
129
+ names = []
130
+ arrays = []
131
+ for name, shape in init_vars:
132
+ print("Loading TF weight {} with shape {}".format(name, shape))
133
+ array = tf.train.load_variable(tf_path, name)
134
+ names.append(name)
135
+ arrays.append(array)
136
+
137
+ for name, array in zip(names, arrays):
138
+ name = name.split('/')
139
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
140
+ # which are not required for using pretrained model
141
+ if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
142
+ print("Skipping {}".format("/".join(name)))
143
+ continue
144
+ pointer = model
145
+ for m_name in name:
146
+ if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
147
+ l = re.split(r'_(\d+)', m_name)
148
+ else:
149
+ l = [m_name]
150
+ if l[0] == 'kernel' or l[0] == 'gamma':
151
+ pointer = getattr(pointer, 'weight')
152
+ elif l[0] == 'output_bias' or l[0] == 'beta':
153
+ pointer = getattr(pointer, 'bias')
154
+ elif l[0] == 'output_weights':
155
+ pointer = getattr(pointer, 'weight')
156
+ elif l[0] == 'squad':
157
+ pointer = getattr(pointer, 'classifier')
158
+ else:
159
+ try:
160
+ pointer = getattr(pointer, l[0])
161
+ except AttributeError:
162
+ print("Skipping {}".format("/".join(name)))
163
+ continue
164
+ if len(l) >= 2:
165
+ num = int(l[1])
166
+ pointer = pointer[num]
167
+ if m_name[-11:] == '_embeddings':
168
+ pointer = getattr(pointer, 'weight')
169
+ elif m_name == 'kernel':
170
+ array = np.transpose(array)
171
+ try:
172
+ assert pointer.shape == array.shape
173
+ except AssertionError as e:
174
+ e.args += (pointer.shape, array.shape)
175
+ raise
176
+ print("Initialize PyTorch weight {}".format(name))
177
+ pointer.data = torch.from_numpy(array)
178
+ return model
179
+
180
+
181
+ def gelu(x):
182
+ """Implementation of the gelu activation function.
183
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
184
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
185
+ Also see https://arxiv.org/abs/1606.08415
186
+ """
187
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
188
+
189
+
190
+ def swish(x):
191
+ return x * torch.sigmoid(x)
192
+
193
+
194
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
195
+
196
+
197
+ class ZenConfig(object):
198
+
199
+ """Configuration class to store the configuration of a `ZenModel`.
200
+ """
201
+
202
+ def __init__(self,
203
+ vocab_size_or_config_json_file,
204
+ word_vocab_size,
205
+ hidden_size=768,
206
+ num_hidden_layers=12,
207
+ num_attention_heads=12,
208
+ intermediate_size=3072,
209
+ hidden_act="gelu",
210
+ hidden_dropout_prob=0.1,
211
+ attention_probs_dropout_prob=0.1,
212
+ max_position_embeddings=512,
213
+ type_vocab_size=2,
214
+ initializer_range=0.02,
215
+ layer_norm_eps=1e-12,
216
+ num_hidden_word_layers=6):
217
+ """Constructs ZenConfig.
218
+
219
+ Args:
220
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
221
+ hidden_size: Size of the encoder layers and the pooler layer.
222
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
223
+ num_attention_heads: Number of attention heads for each attention layer in
224
+ the Transformer encoder.
225
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
226
+ layer in the Transformer encoder.
227
+ hidden_act: The non-linear activation function (function or string) in the
228
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
229
+ hidden_dropout_prob: The dropout probabilitiy for all fully connected
230
+ layers in the embeddings, encoder, and pooler.
231
+ attention_probs_dropout_prob: The dropout ratio for the attention
232
+ probabilities.
233
+ max_position_embeddings: The maximum sequence length that this model might
234
+ ever be used with. Typically set this to something large just in case
235
+ (e.g., 512 or 1024 or 2048).
236
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
237
+ `BertModel`.
238
+ initializer_range: The sttdev of the truncated_normal_initializer for
239
+ initializing all weight matrices.
240
+ layer_norm_eps: The epsilon used by LayerNorm.
241
+ """
242
+ if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
243
+ and isinstance(vocab_size_or_config_json_file, unicode)):
244
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
245
+ json_config = json.loads(reader.read())
246
+ for key, value in json_config.items():
247
+ self.__dict__[key] = value
248
+ self.word_size = word_vocab_size
249
+ elif isinstance(vocab_size_or_config_json_file, int):
250
+ self.vocab_size = vocab_size_or_config_json_file
251
+ self.word_size = word_vocab_size
252
+ self.hidden_size = hidden_size
253
+ self.num_hidden_layers = num_hidden_layers
254
+ self.num_attention_heads = num_attention_heads
255
+ self.hidden_act = hidden_act
256
+ self.intermediate_size = intermediate_size
257
+ self.hidden_dropout_prob = hidden_dropout_prob
258
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
259
+ self.max_position_embeddings = max_position_embeddings
260
+ self.type_vocab_size = type_vocab_size
261
+ self.initializer_range = initializer_range
262
+ self.layer_norm_eps = layer_norm_eps
263
+ self.num_hidden_word_layers = num_hidden_word_layers
264
+ else:
265
+ raise ValueError("First argument must be either a vocabulary size (int)"
266
+ "or the path to a pretrained model config file (str)")
267
+
268
+ @classmethod
269
+ def from_dict(cls, json_object):
270
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
271
+ config = ZenConfig(vocab_size_or_config_json_file=-1, word_vocab_size=104089)
272
+ for key, value in json_object.items():
273
+ config.__dict__[key] = value
274
+ return config
275
+
276
+ @classmethod
277
+ def from_json_file(cls, json_file):
278
+ """Constructs a `BertConfig` from a json file of parameters."""
279
+ with open(json_file, "r", encoding='utf-8') as reader:
280
+ text = reader.read()
281
+ return cls.from_dict(json.loads(text))
282
+
283
+ def __repr__(self):
284
+ return str(self.to_json_string())
285
+
286
+ def to_dict(self):
287
+ """Serializes this instance to a Python dictionary."""
288
+ output = copy.deepcopy(self.__dict__)
289
+ return output
290
+
291
+ def to_json_string(self):
292
+ """Serializes this instance to a JSON string."""
293
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
294
+
295
+ def to_json_file(self, json_file_path):
296
+ """ Save this instance to a json file."""
297
+ with open(json_file_path, "w", encoding='utf-8') as writer:
298
+ writer.write(self.to_json_string())
299
+
300
+
301
+ try:
302
+ from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
303
+ except ImportError:
304
+ logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
305
+
306
+
307
+ class BertLayerNorm(nn.Module):
308
+ def __init__(self, hidden_size, eps=1e-12):
309
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
310
+ """
311
+ super(BertLayerNorm, self).__init__()
312
+ self.weight = nn.Parameter(torch.ones(hidden_size))
313
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
314
+ self.variance_epsilon = eps
315
+
316
+ def forward(self, x):
317
+ u = x.mean(-1, keepdim=True)
318
+ s = (x - u).pow(2).mean(-1, keepdim=True)
319
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
320
+ return self.weight * x + self.bias
321
+
322
+
323
+ class BertEmbeddings(nn.Module):
324
+ """Construct the embeddings from word, position and token_type embeddings.
325
+ """
326
+
327
+ def __init__(self, config):
328
+ super(BertEmbeddings, self).__init__()
329
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
330
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
331
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
332
+
333
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
334
+ # any TensorFlow checkpoint file
335
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
336
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
337
+
338
+ def forward(self, input_ids, token_type_ids=None):
339
+ seq_length = input_ids.size(1)
340
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
341
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
342
+ if token_type_ids is None:
343
+ token_type_ids = torch.zeros_like(input_ids)
344
+
345
+ words_embeddings = self.word_embeddings(input_ids)
346
+ position_embeddings = self.position_embeddings(position_ids)
347
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
348
+
349
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
350
+ embeddings = self.LayerNorm(embeddings)
351
+ embeddings = self.dropout(embeddings)
352
+ return embeddings
353
+
354
+
355
+ class BertWordEmbeddings(nn.Module):
356
+ """Construct the embeddings from ngram, position and token_type embeddings.
357
+ """
358
+
359
+ def __init__(self, config):
360
+ super(BertWordEmbeddings, self).__init__()
361
+ self.word_embeddings = nn.Embedding(config.word_size, config.hidden_size, padding_idx=0)
362
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
363
+
364
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
365
+ # any TensorFlow checkpoint file
366
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
367
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
368
+
369
+ def forward(self, input_ids, token_type_ids=None):
370
+ if token_type_ids is None:
371
+ token_type_ids = torch.zeros_like(input_ids)
372
+
373
+ words_embeddings = self.word_embeddings(input_ids)
374
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
375
+
376
+ embeddings = words_embeddings + token_type_embeddings
377
+ embeddings = self.LayerNorm(embeddings)
378
+ embeddings = self.dropout(embeddings)
379
+ return embeddings
380
+
381
+
382
+ class BertSelfAttention(nn.Module):
383
+ def __init__(self, config, output_attentions=False, keep_multihead_output=False):
384
+ super(BertSelfAttention, self).__init__()
385
+ if config.hidden_size % config.num_attention_heads != 0:
386
+ raise ValueError(
387
+ "The hidden size (%d) is not a multiple of the number of attention "
388
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
389
+ self.output_attentions = output_attentions
390
+ self.keep_multihead_output = keep_multihead_output
391
+ self.multihead_output = None
392
+
393
+ self.num_attention_heads = config.num_attention_heads
394
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
395
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
396
+
397
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
398
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
399
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
400
+
401
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
402
+
403
+ def transpose_for_scores(self, x):
404
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
405
+ x = x.view(*new_x_shape)
406
+ return x.permute(0, 2, 1, 3)
407
+
408
+ def forward(self, hidden_states, attention_mask, head_mask=None):
409
+ mixed_query_layer = self.query(hidden_states)
410
+ mixed_key_layer = self.key(hidden_states)
411
+ mixed_value_layer = self.value(hidden_states)
412
+
413
+ query_layer = self.transpose_for_scores(mixed_query_layer)
414
+ key_layer = self.transpose_for_scores(mixed_key_layer)
415
+ value_layer = self.transpose_for_scores(mixed_value_layer)
416
+
417
+ # Take the dot product between "query" and "key" to get the raw attention scores.
418
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
419
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
420
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
421
+ attention_scores = attention_scores + attention_mask
422
+
423
+ # Normalize the attention scores to probabilities.
424
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
425
+
426
+ # This is actually dropping out entire tokens to attend to, which might
427
+ # seem a bit unusual, but is taken from the original Transformer paper.
428
+ attention_probs = self.dropout(attention_probs)
429
+
430
+ # Mask heads if we want to
431
+ if head_mask is not None:
432
+ attention_probs = attention_probs * head_mask
433
+
434
+ context_layer = torch.matmul(attention_probs, value_layer)
435
+ if self.keep_multihead_output:
436
+ self.multihead_output = context_layer
437
+ self.multihead_output.retain_grad()
438
+
439
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
440
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
441
+ context_layer = context_layer.view(*new_context_layer_shape)
442
+ if self.output_attentions:
443
+ return attention_probs, context_layer
444
+ return context_layer
445
+
446
+
447
+ class BertSelfOutput(nn.Module):
448
+ def __init__(self, config):
449
+ super(BertSelfOutput, self).__init__()
450
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
451
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
452
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
453
+
454
+ def forward(self, hidden_states, input_tensor):
455
+ hidden_states = self.dense(hidden_states)
456
+ hidden_states = self.dropout(hidden_states)
457
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
458
+ return hidden_states
459
+
460
+
461
+ class BertAttention(nn.Module):
462
+ def __init__(self, config, output_attentions=False, keep_multihead_output=False):
463
+ super(BertAttention, self).__init__()
464
+ self.output_attentions = output_attentions
465
+ self.self = BertSelfAttention(config, output_attentions=output_attentions,
466
+ keep_multihead_output=keep_multihead_output)
467
+ self.output = BertSelfOutput(config)
468
+
469
+ def prune_heads(self, heads):
470
+ if len(heads) == 0:
471
+ return
472
+ mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
473
+ for head in heads:
474
+ mask[head] = 0
475
+ mask = mask.view(-1).contiguous().eq(1)
476
+ index = torch.arange(len(mask))[mask].long()
477
+ # Prune linear layers
478
+ self.self.query = prune_linear_layer(self.self.query, index)
479
+ self.self.key = prune_linear_layer(self.self.key, index)
480
+ self.self.value = prune_linear_layer(self.self.value, index)
481
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
482
+ # Update hyper params
483
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
484
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
485
+
486
+ def forward(self, input_tensor, attention_mask, head_mask=None):
487
+ self_output = self.self(input_tensor, attention_mask, head_mask)
488
+ if self.output_attentions:
489
+ attentions, self_output = self_output
490
+ attention_output = self.output(self_output, input_tensor)
491
+ if self.output_attentions:
492
+ return attentions, attention_output
493
+ return attention_output
494
+
495
+
496
+ class BertIntermediate(nn.Module):
497
+ def __init__(self, config):
498
+ super(BertIntermediate, self).__init__()
499
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
500
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
501
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
502
+ else:
503
+ self.intermediate_act_fn = config.hidden_act
504
+
505
+ def forward(self, hidden_states):
506
+ hidden_states = self.dense(hidden_states)
507
+ hidden_states = self.intermediate_act_fn(hidden_states)
508
+ return hidden_states
509
+
510
+
511
+ class BertOutput(nn.Module):
512
+ def __init__(self, config):
513
+ super(BertOutput, self).__init__()
514
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
515
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
516
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
517
+
518
+ def forward(self, hidden_states, input_tensor):
519
+ hidden_states = self.dense(hidden_states)
520
+ hidden_states = self.dropout(hidden_states)
521
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
522
+ return hidden_states
523
+
524
+
525
+ class BertLayer(nn.Module):
526
+ def __init__(self, config, output_attentions=False, keep_multihead_output=False):
527
+ super(BertLayer, self).__init__()
528
+ self.output_attentions = output_attentions
529
+ self.attention = BertAttention(config, output_attentions=output_attentions,
530
+ keep_multihead_output=keep_multihead_output)
531
+ self.intermediate = BertIntermediate(config)
532
+ self.output = BertOutput(config)
533
+
534
+ def forward(self, hidden_states, attention_mask, head_mask=None):
535
+ attention_output = self.attention(hidden_states, attention_mask, head_mask)
536
+ if self.output_attentions:
537
+ attentions, attention_output = attention_output
538
+ intermediate_output = self.intermediate(attention_output)
539
+ layer_output = self.output(intermediate_output, attention_output)
540
+ if self.output_attentions:
541
+ return attentions, layer_output
542
+ return layer_output
543
+
544
+
545
+ class ZenEncoder(nn.Module):
546
+ def __init__(self, config, output_attentions=False, keep_multihead_output=False):
547
+ super(ZenEncoder, self).__init__()
548
+ self.output_attentions = output_attentions
549
+ layer = BertLayer(config, output_attentions=output_attentions,
550
+ keep_multihead_output=keep_multihead_output)
551
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
552
+ self.word_layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_word_layers)])
553
+ self.num_hidden_word_layers = config.num_hidden_word_layers
554
+
555
+ def forward(self, hidden_states, ngram_hidden_states, ngram_position_matrix, attention_mask,
556
+ ngram_attention_mask,
557
+ output_all_encoded_layers=True, head_mask=None):
558
+ # Need to check what is the attention masking doing here
559
+ all_encoder_layers = []
560
+ all_attentions = []
561
+ num_hidden_ngram_layers = self.num_hidden_word_layers
562
+ for i, layer_module in enumerate(self.layer):
563
+ hidden_states = layer_module(hidden_states, attention_mask, head_mask[i])
564
+ if i < num_hidden_ngram_layers:
565
+ ngram_hidden_states = self.word_layers[i](ngram_hidden_states, ngram_attention_mask, head_mask[i])
566
+ if self.output_attentions:
567
+ ngram_attentions, ngram_hidden_states = ngram_hidden_states
568
+ if self.output_attentions:
569
+ attentions, hidden_states = hidden_states
570
+ all_attentions.append(attentions)
571
+ hidden_states += torch.bmm(ngram_position_matrix.float(), ngram_hidden_states.float())
572
+ if output_all_encoded_layers:
573
+ all_encoder_layers.append(hidden_states)
574
+ if not output_all_encoded_layers:
575
+ all_encoder_layers.append(hidden_states)
576
+ if self.output_attentions:
577
+ return all_attentions, all_encoder_layers
578
+ return all_encoder_layers
579
+
580
+
581
+ class BertPooler(nn.Module):
582
+ def __init__(self, config):
583
+ super(BertPooler, self).__init__()
584
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
585
+ self.activation = nn.Tanh()
586
+
587
+ def forward(self, hidden_states):
588
+ # We "pool" the model by simply taking the hidden state corresponding
589
+ # to the first token.
590
+ first_token_tensor = hidden_states[:, 0]
591
+ pooled_output = self.dense(first_token_tensor)
592
+ pooled_output = self.activation(pooled_output)
593
+ return pooled_output
594
+
595
+
596
+ class BertPredictionHeadTransform(nn.Module):
597
+ def __init__(self, config):
598
+ super(BertPredictionHeadTransform, self).__init__()
599
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
600
+ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
601
+ self.transform_act_fn = ACT2FN[config.hidden_act]
602
+ else:
603
+ self.transform_act_fn = config.hidden_act
604
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
605
+
606
+ def forward(self, hidden_states):
607
+ hidden_states = self.dense(hidden_states)
608
+ hidden_states = self.transform_act_fn(hidden_states)
609
+ hidden_states = self.LayerNorm(hidden_states)
610
+ return hidden_states
611
+
612
+
613
+ class BertLMPredictionHead(nn.Module):
614
+ def __init__(self, config, bert_model_embedding_weights):
615
+ super(BertLMPredictionHead, self).__init__()
616
+ self.transform = BertPredictionHeadTransform(config)
617
+
618
+ # The output weights are the same as the input embeddings, but there is
619
+ # an output-only bias for each token.
620
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
621
+ bert_model_embedding_weights.size(0),
622
+ bias=False)
623
+ self.decoder.weight = bert_model_embedding_weights
624
+ self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
625
+
626
+ def forward(self, hidden_states):
627
+ hidden_states = self.transform(hidden_states)
628
+ hidden_states = self.decoder(hidden_states) + self.bias
629
+ return hidden_states
630
+
631
+
632
+ class ZenOnlyMLMHead(nn.Module):
633
+ def __init__(self, config, bert_model_embedding_weights):
634
+ super(ZenOnlyMLMHead, self).__init__()
635
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
636
+
637
+ def forward(self, sequence_output):
638
+ prediction_scores = self.predictions(sequence_output)
639
+ return prediction_scores
640
+
641
+
642
+ class ZenOnlyNSPHead(nn.Module):
643
+ def __init__(self, config):
644
+ super(ZenOnlyNSPHead, self).__init__()
645
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
646
+
647
+ def forward(self, pooled_output):
648
+ seq_relationship_score = self.seq_relationship(pooled_output)
649
+ return seq_relationship_score
650
+
651
+
652
+ class ZenPreTrainingHeads(nn.Module):
653
+ def __init__(self, config, bert_model_embedding_weights):
654
+ super(ZenPreTrainingHeads, self).__init__()
655
+ self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
656
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
657
+
658
+ def forward(self, sequence_output, pooled_output):
659
+ prediction_scores = self.predictions(sequence_output)
660
+ seq_relationship_score = self.seq_relationship(pooled_output)
661
+ return prediction_scores, seq_relationship_score
662
+
663
+
664
+ class ZenPreTrainedModel(nn.Module):
665
+ """ An abstract class to handle weights initialization and
666
+ a simple interface for dowloading and loading pretrained models.
667
+ """
668
+
669
+ def __init__(self, config, *inputs, **kwargs):
670
+ super(ZenPreTrainedModel, self).__init__()
671
+ if not isinstance(config, ZenConfig):
672
+ raise ValueError(
673
+ "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
674
+ "To create a model from a Google pretrained model use "
675
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
676
+ self.__class__.__name__, self.__class__.__name__
677
+ ))
678
+ self.config = config
679
+
680
+ def init_bert_weights(self, module):
681
+ """ Initialize the weights.
682
+ """
683
+ if isinstance(module, (nn.Linear, nn.Embedding)):
684
+ # Slightly different from the TF version which uses truncated_normal for initialization
685
+ # cf https://github.com/pytorch/pytorch/pull/5617
686
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
687
+ elif isinstance(module, BertLayerNorm):
688
+ module.bias.data.zero_()
689
+ module.weight.data.fill_(1.0)
690
+ if isinstance(module, nn.Linear) and module.bias is not None:
691
+ module.bias.data.zero_()
692
+
693
+ @classmethod
694
+ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
695
+ """
696
+ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
697
+ Download and cache the pre-trained model file if needed.
698
+
699
+ Params:
700
+ pretrained_model_name_or_path: either:
701
+ - a str with the name of a pre-trained model to load selected in the list of:
702
+ . `bert-base-uncased`
703
+ . `bert-large-uncased`
704
+ . `bert-base-cased`
705
+ . `bert-large-cased`
706
+ . `bert-base-multilingual-uncased`
707
+ . `bert-base-multilingual-cased`
708
+ . `bert-base-chinese`
709
+ . `bert-base-german-cased`
710
+ . `bert-large-uncased-whole-word-masking`
711
+ . `bert-large-cased-whole-word-masking`
712
+ - a path or url to a pretrained model archive containing:
713
+ . `bert_config.json` a configuration file for the model
714
+ . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
715
+ - a path or url to a pretrained model archive containing:
716
+ . `bert_config.json` a configuration file for the model
717
+ . `model.chkpt` a TensorFlow checkpoint
718
+ from_tf: should we load the weights from a locally saved TensorFlow checkpoint
719
+ cache_dir: an optional path to a folder in which the pre-trained models will be cached.
720
+ state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
721
+ *inputs, **kwargs: additional input for the specific Bert class
722
+ (ex: num_labels for BertForSequenceClassification)
723
+ """
724
+ state_dict = kwargs.get('state_dict', None)
725
+ kwargs.pop('state_dict', None)
726
+ cache_dir = kwargs.get('cache_dir', None)
727
+ kwargs.pop('cache_dir', None)
728
+ from_tf = kwargs.get('from_tf', False)
729
+ kwargs.pop('from_tf', None)
730
+ multift = kwargs.get("multift", False)
731
+ kwargs.pop('multift', None)
732
+
733
+ if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
734
+ archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
735
+ config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
736
+ else:
737
+ if from_tf:
738
+ # Directly load from a TensorFlow checkpoint
739
+ archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME)
740
+ config_file = os.path.join(pretrained_model_name_or_path, BERT_CONFIG_NAME)
741
+ else:
742
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
743
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
744
+ # redirect to the cache, if necessary
745
+ try:
746
+ resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
747
+ except EnvironmentError:
748
+ if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
749
+ logger.error(
750
+ "Couldn't reach server at '{}' to download pretrained weights.".format(
751
+ archive_file))
752
+ else:
753
+ logger.error(
754
+ "Model name '{}' was not found in model name list ({}). "
755
+ "We assumed '{}' was a path or url but couldn't find any file "
756
+ "associated to this path or url.".format(
757
+ pretrained_model_name_or_path,
758
+ ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
759
+ archive_file))
760
+ return None
761
+ try:
762
+ resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
763
+ except EnvironmentError:
764
+ if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
765
+ logger.error(
766
+ "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
767
+ config_file))
768
+ else:
769
+ logger.error(
770
+ "Model name '{}' was not found in model name list ({}). "
771
+ "We assumed '{}' was a path or url but couldn't find any file "
772
+ "associated to this path or url.".format(
773
+ pretrained_model_name_or_path,
774
+ ', '.join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
775
+ config_file))
776
+ return None
777
+ if resolved_archive_file == archive_file and resolved_config_file == config_file:
778
+ logger.info("loading weights file {}".format(archive_file))
779
+ logger.info("loading configuration file {}".format(config_file))
780
+ else:
781
+ logger.info("loading weights file {} from cache at {}".format(
782
+ archive_file, resolved_archive_file))
783
+ logger.info("loading configuration file {} from cache at {}".format(
784
+ config_file, resolved_config_file))
785
+ # Load config
786
+ config = ZenConfig.from_json_file(resolved_config_file)
787
+ logger.info("Model config {}".format(config))
788
+ # Instantiate model.
789
+ model = cls(config, *inputs, **kwargs)
790
+ if state_dict is None and not from_tf:
791
+ state_dict = torch.load(resolved_archive_file, map_location='cpu')
792
+ # Load from a PyTorch state_dict
793
+ old_keys = []
794
+ new_keys = []
795
+ for key in state_dict.keys():
796
+ new_key = None
797
+ if 'gamma' in key:
798
+ new_key = key.replace('gamma', 'weight')
799
+ if 'beta' in key:
800
+ new_key = key.replace('beta', 'bias')
801
+ if new_key:
802
+ old_keys.append(key)
803
+ new_keys.append(new_key)
804
+ if multift:
805
+ state_dict.pop("classifier.weight")
806
+ state_dict.pop("classifier.bias")
807
+ for old_key, new_key in zip(old_keys, new_keys):
808
+ state_dict[new_key] = state_dict.pop(old_key)
809
+
810
+ missing_keys = []
811
+ unexpected_keys = []
812
+ error_msgs = []
813
+ # copy state_dict so _load_from_state_dict can modify it
814
+ metadata = getattr(state_dict, '_metadata', None)
815
+ state_dict = state_dict.copy()
816
+ if metadata is not None:
817
+ state_dict._metadata = metadata
818
+
819
+ def load(module, prefix=''):
820
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
821
+ module._load_from_state_dict(
822
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
823
+ for name, child in module._modules.items():
824
+ if child is not None:
825
+ load(child, prefix + name + '.')
826
+
827
+ start_prefix = ''
828
+ if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
829
+ start_prefix = 'bert.'
830
+ load(model, prefix=start_prefix)
831
+ if len(missing_keys) > 0:
832
+ logger.info("Weights of {} not initialized from pretrained model: {}".format(
833
+ model.__class__.__name__, missing_keys))
834
+ if len(unexpected_keys) > 0:
835
+ logger.info("Weights from pretrained model not used in {}: {}".format(
836
+ model.__class__.__name__, unexpected_keys))
837
+ if len(error_msgs) > 0:
838
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
839
+ model.__class__.__name__, "\n\t".join(error_msgs)))
840
+ return model
841
+
842
+
843
+ class ZenModel(ZenPreTrainedModel):
844
+ """ZEN model ("BERT-based Chinese (Z) text encoder Enhanced by N-gram representations").
845
+
846
+ Params:
847
+ `config`: a BertConfig class instance with the configuration to build a new model
848
+ `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
849
+ `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
850
+ This can be used to compute head importance metrics. Default: False
851
+
852
+ Inputs:
853
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
854
+ with the word token indices in the vocabulary
855
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
856
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
857
+ a `sentence B` token (see BERT paper for more details).
858
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
859
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
860
+ input sequence length in the current batch. It's the mask that we typically use for attention when
861
+ a batch has varying length sentences.
862
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
863
+ `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
864
+ It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
865
+ `input_ngram_ids`: input_ids of ngrams.
866
+ `ngram_token_type_ids`: token_type_ids of ngrams.
867
+ `ngram_attention_mask`: attention_mask of ngrams.
868
+ `ngram_position_matrix`: position matrix of ngrams.
869
+
870
+
871
+ Outputs: Tuple of (encoded_layers, pooled_output)
872
+ `encoded_layers`: controled by `output_all_encoded_layers` argument:
873
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
874
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
875
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
876
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
877
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
878
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
879
+ classifier pretrained on top of the hidden state associated to the first character of the
880
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
881
+
882
+ """
883
+
884
+ def __init__(self, config, output_attentions=False, keep_multihead_output=False):
885
+ super(ZenModel, self).__init__(config)
886
+ self.output_attentions = output_attentions
887
+ self.embeddings = BertEmbeddings(config)
888
+ self.word_embeddings = BertWordEmbeddings(config)
889
+ self.encoder = ZenEncoder(config, output_attentions=output_attentions,
890
+ keep_multihead_output=keep_multihead_output)
891
+ self.pooler = BertPooler(config)
892
+ self.apply(self.init_bert_weights)
893
+
894
+ def prune_heads(self, heads_to_prune):
895
+ """ Prunes heads of the model.
896
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
897
+ """
898
+ for layer, heads in heads_to_prune.items():
899
+ self.encoder.layer[layer].attention.prune_heads(heads)
900
+
901
+ def get_multihead_outputs(self):
902
+ """ Gather all multi-head outputs.
903
+ Return: list (layers) of multihead module outputs with gradients
904
+ """
905
+ return [layer.attention.self.multihead_output for layer in self.encoder.layer]
906
+
907
+ def forward(self, input_ids,
908
+ input_ngram_ids,
909
+ ngram_position_matrix,
910
+ token_type_ids=None,
911
+ ngram_token_type_ids=None,
912
+ attention_mask=None,
913
+ ngram_attention_mask=None,
914
+ output_all_encoded_layers=True,
915
+ head_mask=None):
916
+ if attention_mask is None:
917
+ attention_mask = torch.ones_like(input_ids)
918
+ if token_type_ids is None:
919
+ token_type_ids = torch.zeros_like(input_ids)
920
+
921
+ if ngram_attention_mask is None:
922
+ ngram_attention_mask = torch.ones_like(input_ngram_ids)
923
+ if ngram_token_type_ids is None:
924
+ ngram_token_type_ids = torch.zeros_like(input_ngram_ids)
925
+
926
+ # We create a 3D attention mask from a 2D tensor mask.
927
+ # Sizes are [batch_size, 1, 1, to_seq_length]
928
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
929
+ # this attention mask is more simple than the triangular masking of causal attention
930
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
931
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
932
+ extended_ngram_attention_mask = ngram_attention_mask.unsqueeze(1).unsqueeze(2)
933
+
934
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
935
+ # masked positions, this operation will create a tensor which is 0.0 for
936
+ # positions we want to attend and -10000.0 for masked positions.
937
+ # Since we are adding it to the raw scores before the softmax, this is
938
+ # effectively the same as removing these entirely.
939
+ extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) # fp16 compatibility
940
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
941
+
942
+ extended_ngram_attention_mask = extended_ngram_attention_mask.to(dtype=torch.float32)
943
+ extended_ngram_attention_mask = (1.0 - extended_ngram_attention_mask) * -10000.0
944
+
945
+ # Prepare head mask if needed
946
+ # 1.0 in head_mask indicate we keep the head
947
+ # attention_probs has shape bsz x n_heads x N x N
948
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
949
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
950
+ if head_mask is not None:
951
+ if head_mask.dim() == 1:
952
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
953
+ head_mask = head_mask.expand_as(self.config.num_hidden_layers, -1, -1, -1, -1)
954
+ elif head_mask.dim() == 2:
955
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
956
+ -1) # We can specify head_mask for each layer
957
+ head_mask = head_mask.to(
958
+ dtype=torch.float32) # switch to fload if need + fp16 compatibility
959
+ else:
960
+ head_mask = [None] * self.config.num_hidden_layers
961
+
962
+ embedding_output = self.embeddings(input_ids, token_type_ids)
963
+ ngram_embedding_output = self.word_embeddings(input_ngram_ids, ngram_token_type_ids)
964
+
965
+ encoded_layers = self.encoder(embedding_output,
966
+ ngram_embedding_output,
967
+ ngram_position_matrix,
968
+ extended_attention_mask,
969
+ extended_ngram_attention_mask,
970
+ output_all_encoded_layers=output_all_encoded_layers,
971
+ head_mask=head_mask)
972
+ if self.output_attentions:
973
+ all_attentions, encoded_layers = encoded_layers
974
+ sequence_output = encoded_layers[-1]
975
+ pooled_output = self.pooler(sequence_output)
976
+ if not output_all_encoded_layers:
977
+ encoded_layers = encoded_layers[-1]
978
+ if self.output_attentions:
979
+ return all_attentions, encoded_layers, pooled_output
980
+ return encoded_layers, pooled_output
981
+
982
+
983
+ class ZenForPreTraining(ZenPreTrainedModel):
984
+ """ZEN model with pre-training heads.
985
+ This module comprises the ZEN model followed by the two pre-training heads:
986
+ - the masked language modeling head, and
987
+ - the next sentence classification head.
988
+
989
+ Params:
990
+ `config`: a BertConfig class instance with the configuration to build a new model
991
+ `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
992
+ `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
993
+ This can be used to compute head importance metrics. Default: False
994
+
995
+ Inputs:
996
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
997
+ with the word token indices in the vocabulary
998
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
999
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1000
+ a `sentence B` token (see BERT paper for more details).
1001
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1002
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1003
+ input sequence length in the current batch. It's the mask that we typically use for attention when
1004
+ a batch has varying length sentences.
1005
+ `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
1006
+ with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
1007
+ is only computed for the labels set in [0, ..., vocab_size]
1008
+ `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
1009
+ with indices selected in [0, 1].
1010
+ 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
1011
+ `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
1012
+ It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
1013
+ `input_ngram_ids`: input_ids of ngrams.
1014
+ `ngram_token_type_ids`: token_type_ids of ngrams.
1015
+ `ngram_attention_mask`: attention_mask of ngrams.
1016
+ `ngram_position_matrix`: position matrix of ngrams.
1017
+
1018
+ Outputs:
1019
+ if `masked_lm_labels` and `next_sentence_label` are not `None`:
1020
+ Outputs the total_loss which is the sum of the masked language modeling loss and the next
1021
+ sentence classification loss.
1022
+ if `masked_lm_labels` or `next_sentence_label` is `None`:
1023
+ Outputs a tuple comprising
1024
+ - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
1025
+ - the next sentence classification logits of shape [batch_size, 2].
1026
+
1027
+ """
1028
+
1029
+ def __init__(self, config, output_attentions=False, keep_multihead_output=False):
1030
+ super(ZenForPreTraining, self).__init__(config)
1031
+ self.output_attentions = output_attentions
1032
+ self.bert = ZenModel(config, output_attentions=output_attentions,
1033
+ keep_multihead_output=keep_multihead_output)
1034
+ self.cls = ZenPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
1035
+ self.apply(self.init_bert_weights)
1036
+
1037
+ def forward(self, input_ids, input_ngram_ids, ngram_position_matrix, token_type_ids=None,
1038
+ ngram_token_type_ids=None,
1039
+ attention_mask=None,
1040
+ ngram_attention_mask=None,
1041
+ masked_lm_labels=None,
1042
+ next_sentence_label=None, head_mask=None):
1043
+ outputs = self.bert(input_ids,
1044
+ input_ngram_ids,
1045
+ ngram_position_matrix,
1046
+ token_type_ids,
1047
+ ngram_token_type_ids,
1048
+ attention_mask,
1049
+ ngram_attention_mask,
1050
+ output_all_encoded_layers=False, head_mask=head_mask)
1051
+ if self.output_attentions:
1052
+ all_attentions, sequence_output, pooled_output = outputs
1053
+ else:
1054
+ sequence_output, pooled_output = outputs
1055
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1056
+
1057
+ if masked_lm_labels is not None and next_sentence_label is not None:
1058
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
1059
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
1060
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1061
+ total_loss = masked_lm_loss + next_sentence_loss
1062
+ return total_loss
1063
+ elif self.output_attentions:
1064
+ return all_attentions, prediction_scores, seq_relationship_score
1065
+ return prediction_scores, seq_relationship_score
1066
+
1067
+
1068
+ class ZenForMaskedLM(ZenPreTrainedModel):
1069
+ """ZEN model with the masked language modeling head.
1070
+ This module comprises the ZEN model followed by the masked language modeling head.
1071
+
1072
+ Params:
1073
+ `config`: a BertConfig class instance with the configuration to build a new model
1074
+ `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
1075
+ `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
1076
+ This can be used to compute head importance metrics. Default: False
1077
+
1078
+ Inputs:
1079
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1080
+ with the word token indices in the vocabulary
1081
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1082
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1083
+ a `sentence B` token (see BERT paper for more details).
1084
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1085
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1086
+ input sequence length in the current batch. It's the mask that we typically use for attention when
1087
+ a batch has varying length sentences.
1088
+ `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
1089
+ with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
1090
+ is only computed for the labels set in [0, ..., vocab_size]
1091
+ `head_mask`: an optional torch.LongTensor of shape [num_heads] with indices
1092
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1093
+ input sequence length in the current batch. It's the mask that we typically use for attention when
1094
+ a batch has varying length sentences.
1095
+ `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
1096
+ It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
1097
+ `input_ngram_ids`: input_ids of ngrams.
1098
+ `ngram_token_type_ids`: token_type_ids of ngrams.
1099
+ `ngram_attention_mask`: attention_mask of ngrams.
1100
+ `ngram_position_matrix`: position matrix of ngrams.
1101
+
1102
+ Outputs:
1103
+ if `masked_lm_labels` is not `None`:
1104
+ Outputs the masked language modeling loss.
1105
+ if `masked_lm_labels` is `None`:
1106
+ Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
1107
+
1108
+ """
1109
+
1110
+ def __init__(self, config, output_attentions=False, keep_multihead_output=False):
1111
+ super(ZenForMaskedLM, self).__init__(config)
1112
+ self.output_attentions = output_attentions
1113
+ self.bert = ZenModel(config, output_attentions=output_attentions,
1114
+ keep_multihead_output=keep_multihead_output)
1115
+ self.cls = ZenOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
1116
+ self.apply(self.init_bert_weights)
1117
+
1118
+ def forward(self, input_ids, input_ngram_ids, ngram_position_matrix, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
1119
+ outputs = self.bert(input_ids, input_ngram_ids, ngram_position_matrix, token_type_ids, attention_mask,
1120
+ output_all_encoded_layers=False,
1121
+ head_mask=head_mask)
1122
+ if self.output_attentions:
1123
+ all_attentions, sequence_output, _ = outputs
1124
+ else:
1125
+ sequence_output, _ = outputs
1126
+ prediction_scores = self.cls(sequence_output)
1127
+
1128
+ if masked_lm_labels is not None:
1129
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
1130
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
1131
+ return masked_lm_loss
1132
+ elif self.output_attentions:
1133
+ return all_attentions, prediction_scores
1134
+ return prediction_scores
1135
+
1136
+
1137
+ class ZenForNextSentencePrediction(ZenPreTrainedModel):
1138
+ """ZEN model with next sentence prediction head.
1139
+ This module comprises the ZEN model followed by the next sentence classification head.
1140
+
1141
+ Params:
1142
+ `config`: a BertConfig class instance with the configuration to build a new model
1143
+ `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
1144
+ `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
1145
+ This can be used to compute head importance metrics. Default: False
1146
+
1147
+ Inputs:
1148
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1149
+ with the word token indices in the vocabulary
1150
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1151
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1152
+ a `sentence B` token (see BERT paper for more details).
1153
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1154
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1155
+ input sequence length in the current batch. It's the mask that we typically use for attention when
1156
+ a batch has varying length sentences.
1157
+ `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
1158
+ with indices selected in [0, 1].
1159
+ 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
1160
+ `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
1161
+ It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
1162
+ `input_ngram_ids`: a torch.LongTensor with the ngram token indices in the ngram vocabulary.
+ `ngram_position_matrix`: the matrix aligning each ngram with the token positions it covers.
1166
+
1167
+ Outputs:
1168
+ if `next_sentence_label` is not `None`:
1169
+ Outputs the next sentence classification loss.
1171
+ if `next_sentence_label` is `None`:
1172
+ Outputs the next sentence classification logits of shape [batch_size, 2].
1173
+
1174
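+ Example usage (an illustrative sketch only; `config` and the toy tensors below are
+ assumptions for demonstration, and the [batch_size, sequence_length, ngram_sequence_length]
+ shape of `ngram_position_matrix` is assumed rather than taken from the original code):
+
+ ```python
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+ input_ngram_ids = torch.LongTensor([[2, 7], [3, 0]])
+ ngram_position_matrix = torch.zeros(2, 3, 2)
+
+ model = ZenForNextSentencePrediction(config)
+ seq_relationship_logits = model(input_ids, input_ngram_ids, ngram_position_matrix,
+                                 token_type_ids, input_mask)
+ ```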
+ """
1175
+
1176
+ def __init__(self, config, output_attentions=False, keep_multihead_output=False):
1177
+ super(ZenForNextSentencePrediction, self).__init__(config)
1178
+ self.output_attentions = output_attentions
1179
+ self.bert = ZenModel(config, output_attentions=output_attentions,
1180
+ keep_multihead_output=keep_multihead_output)
1181
+ self.cls = ZenOnlyNSPHead(config)
1182
+ self.apply(self.init_bert_weights)
1183
+
1184
+ def forward(self, input_ids, input_ngram_ids, ngram_position_matrix, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
1185
+ outputs = self.bert(input_ids, input_ngram_ids, ngram_position_matrix, token_type_ids, attention_mask,
1186
+ output_all_encoded_layers=False,
1187
+ head_mask=head_mask)
1188
+ if self.output_attentions:
1189
+ all_attentions, _, pooled_output = outputs
1190
+ else:
1191
+ _, pooled_output = outputs
1192
+ seq_relationship_score = self.cls(pooled_output)
1193
+
1194
+ if next_sentence_label is not None:
1195
+ loss_fct = CrossEntropyLoss(ignore_index=-1)
1196
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1197
+ return next_sentence_loss
1198
+ elif self.output_attentions:
1199
+ return all_attentions, seq_relationship_score
1200
+ return seq_relationship_score
1201
+
1202
+
1203
+ class ZenForSequenceClassification(ZenPreTrainedModel):
1204
+ """ZEN model for classification.
1205
+ This module is composed of the ZEN model with a linear layer on top of
1206
+ the pooled output.
1207
+
1208
+ Params:
1209
+ `config`: a BertConfig class instance with the configuration to build a new model
1210
+ `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
1211
+ `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
1212
+ This can be used to compute head importance metrics. Default: False
1213
+ `num_labels`: the number of classes for the classifier. Default = 2.
1214
+
1215
+ Inputs:
1216
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1217
+ with the word token indices in the vocabulary. Items in the batch should begin with the special `[CLS]` token (see the tokens preprocessing logic in the scripts
1218
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
1219
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1220
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1221
+ a `sentence B` token (see BERT paper for more details).
1222
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1223
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1224
+ input sequence length in the current batch. It's the mask that we typically use for attention when
1225
+ a batch has varying length sentences.
1226
+ `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
1227
+ with indices selected in [0, ..., num_labels - 1].
1228
+ `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
1229
+ It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
1230
+ `input_ngram_ids`: a torch.LongTensor with the ngram token indices in the ngram vocabulary.
+ `ngram_position_matrix`: the matrix aligning each ngram with the token positions it covers.
1234
+
1235
+ Outputs:
1236
+ if `labels` is not `None`:
1237
+ Outputs the CrossEntropy classification loss of the output with the labels.
1238
+ if `labels` is `None`:
1239
+ Outputs the classification logits of shape [batch_size, num_labels].
1240
+
1241
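+ Example usage (an illustrative sketch only; `config`, `num_labels=2` and the toy
+ tensors are assumptions for demonstration):
+
+ ```python
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+ input_ngram_ids = torch.LongTensor([[2, 7], [3, 0]])
+ ngram_position_matrix = torch.zeros(2, 3, 2)
+
+ model = ZenForSequenceClassification(config, num_labels=2)
+ logits = model(input_ids, input_ngram_ids, ngram_position_matrix,
+                token_type_ids, input_mask)
+ ```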
+ """
1242
+
1243
+ def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
1244
+ super(ZenForSequenceClassification, self).__init__(config)
1245
+ self.output_attentions = output_attentions
1246
+ self.num_labels = num_labels
1247
+ self.bert = ZenModel(config, output_attentions=output_attentions,
1248
+ keep_multihead_output=keep_multihead_output)
1249
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1250
+ self.classifier = nn.Linear(config.hidden_size, num_labels)
1251
+ self.apply(self.init_bert_weights)
1252
+
1253
+ def forward(self, input_ids, input_ngram_ids, ngram_position_matrix, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
1254
+ outputs = self.bert(input_ids, input_ngram_ids, ngram_position_matrix, token_type_ids, attention_mask,
1255
+ output_all_encoded_layers=False,
1256
+ head_mask=head_mask)
1257
+ if self.output_attentions:
1258
+ all_attentions, _, pooled_output = outputs
1259
+ else:
1260
+ _, pooled_output = outputs
1261
+ pooled_output = self.dropout(pooled_output)
1262
+ logits = self.classifier(pooled_output)
1263
+
1264
+ if labels is not None:
1265
+ loss_fct = CrossEntropyLoss()
1266
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1267
+ return loss
1268
+ elif self.output_attentions:
1269
+ return all_attentions, logits
1270
+ return logits
1271
+
+
1272
+ class ZenForTokenClassification(ZenPreTrainedModel):
1273
+ """ZEN model for token-level classification.
1274
+ This module is composed of the ZEN model with a linear layer on top of
1275
+ the full hidden state of the last layer.
1276
+
1277
+ Params:
1278
+ `config`: a BertConfig class instance with the configuration to build a new model
1279
+ `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
1280
+ `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
1281
+ This can be used to compute head importance metrics. Default: False
1282
+ `num_labels`: the number of classes for the classifier. Default = 2.
1283
+
1284
+ Inputs:
1285
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
1286
+ with the word token indices in the vocabulary
1287
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
1288
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
1289
+ a `sentence B` token (see BERT paper for more details).
1290
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
1291
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
1292
+ input sequence length in the current batch. It's the mask that we typically use for attention when
1293
+ a batch has varying length sentences.
1294
+ `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
1295
+ with indices selected in [0, ..., num_labels - 1].
1296
+ `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
1297
+ It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
1298
+ `valid_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with values in [0, 1],
+ marking the positions whose hidden states are kept for tagging.
+ `attention_mask_label`: an optional torch.LongTensor for masking label positions (currently unused in `forward`).
+ `ngram_ids`: a torch.LongTensor with the ngram token indices in the ngram vocabulary.
+ `ngram_positions`: the matrix aligning each ngram with the token positions it covers.
1302
+
1303
+ Outputs:
1304
+ if `labels` is not `None`:
1305
+ Outputs the CrossEntropy classification loss of the output with the labels.
1306
+ if `labels` is `None`:
1307
+ Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
1308
+
1309
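+ Example usage (an illustrative sketch only; `config`, `num_labels=5`, the toy tensors
+ and the shape assumed for `ngram_positions` are assumptions for demonstration; note
+ that this head takes the ngram inputs as the keyword arguments `ngram_ids` and
+ `ngram_positions`):
+
+ ```python
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+ valid_ids = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+ ngram_ids = torch.LongTensor([[2, 7], [3, 0]])
+ ngram_positions = torch.zeros(2, 3, 2)
+
+ model = ZenForTokenClassification(config, num_labels=5)
+ logits = model(input_ids, token_type_ids, input_mask,
+                valid_ids=valid_ids, ngram_ids=ngram_ids, ngram_positions=ngram_positions)
+ ```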
+ """
1310
+
1311
+ def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
1312
+ super(ZenForTokenClassification, self).__init__(config)
1313
+ self.output_attentions = output_attentions
1314
+ self.num_labels = num_labels
1315
+ self.bert = ZenModel(config, output_attentions=output_attentions,
1316
+ keep_multihead_output=keep_multihead_output)
1317
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1318
+ self.classifier = nn.Linear(config.hidden_size, num_labels)
1319
+ self.apply(self.init_bert_weights)
1320
+
1321
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, valid_ids=None,
1322
+ attention_mask_label=None, ngram_ids=None, ngram_positions=None, head_mask=None):
1323
+ outputs = self.bert(input_ids, ngram_ids, ngram_positions, token_type_ids, attention_mask,
1324
+ output_all_encoded_layers=False, head_mask=head_mask)
1325
+ if self.output_attentions:
1326
+ all_attentions, sequence_output, _ = outputs
1327
+ else:
1328
+ sequence_output, _ = outputs
1329
+
1330
+ batch_size, max_len, feat_dim = sequence_output.shape
1331
+ valid_output = torch.zeros(batch_size, max_len, feat_dim, dtype=torch.float32, device=input_ids.device)
1332
+
1333
+ if self.num_labels == 38:
1334
+ # Hard-coded for the POS tagging setup (38 labels): keep only the positions marked by valid_ids (e.g. drop padding and non-initial subword positions).
1335
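+ # For instance (illustrative only), if one sequence has valid_ids = [1, 1, 0, 1],
+ # the hidden states at positions 0, 1 and 3 are packed to the front of valid_output
+ # and the remaining rows stay zero.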
+ for i in range(batch_size):
1336
+ temp = sequence_output[i][valid_ids[i] == 1]
1337
+ valid_output[i][:temp.size(0)] = temp
1338
+ else:
1339
+ valid_output = sequence_output
1340
+
1341
+ sequence_output = self.dropout(valid_output)
1342
+ logits = self.classifier(sequence_output)
1343
+
1344
+ if labels is not None:
1345
+ loss_fct = CrossEntropyLoss(ignore_index=0)
1346
+ # Only keep active parts of the loss. Note that `attention_mask_label` is overridden
+ # with None here, so the masked branch below is skipped and padding labels are
+ # handled solely through `ignore_index=0`.
+ attention_mask_label = None
1348
+ if attention_mask_label is not None:
1349
+ active_loss = attention_mask_label.view(-1) == 1
1350
+ active_logits = logits.view(-1, self.num_labels)[active_loss]
1351
+ active_labels = labels.view(-1)[active_loss]
1352
+ loss = loss_fct(active_logits, active_labels)
1353
+ else:
1354
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1355
+ return loss
1356
+ else:
1357
+ return logits