# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Transformer decoder that mimics a BERT encoder, to load BERT checkpoints.""" from __future__ import absolute_import from __future__ import division # from __future__ import google_type_annotations from __future__ import print_function import tensorflow as tf from official.modeling import tf_utils from official.nlp.modeling import layers from official.nlp.modeling.layers import transformer from official.nlp.transformer import model_utils as transformer_utils class TransformerDecoder(tf.keras.layers.Layer): """Transformer decoder stack.""" def __init__(self, num_hidden_layers=12, hidden_size=768, num_attention_heads=12, intermediate_size=3072, intermediate_activation="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, initializer_range=0.02, attend_to_last_layer=True, multi_channel_cross_attention=False, **kwargs): super(TransformerDecoder, self).__init__(**kwargs) self.num_hidden_layers = num_hidden_layers self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size self.intermediate_activation = tf_utils.get_activation( intermediate_activation) self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.initializer_range = initializer_range self.attend_to_last_layer = attend_to_last_layer self.multi_channel_cross_attention = multi_channel_cross_attention def build(self, unused_input_shapes): """Implements build() for the layer.""" self.layers = [] for i in range(self.num_hidden_layers): self.layers.append( transformer.TransformerDecoderLayer( num_attention_heads=self.num_attention_heads, intermediate_size=self.intermediate_size, intermediate_activation=self.intermediate_activation, dropout_rate=self.hidden_dropout_prob, attention_dropout_rate=self.attention_probs_dropout_prob, kernel_initializer=tf.keras.initializers.TruncatedNormal( stddev=self.initializer_range), multi_channel_cross_attention=self.multi_channel_cross_attention, name=("layer_%d" % i))) super(TransformerDecoder, self).build(unused_input_shapes) def call(self, inputs, cache=None, decode_loop_step=None): """Return the output of the decoder layer stacks. Args: inputs: A dictionary of inputs. `decoder_inputs` is a tf.int32 tensor for input ids. `encoder_outputs` is a list of tensors with shape [batch_size, input_length, hidden_size]. `self_attention_mask` is the bias for decoder self-attention layer. [1, 1, target_length, target_length]. `attention_mask` is the bias for encoder-decoder attention layer, [batch_size, 1, 1, input_length]. cache: A dictionary of cache tensors, including key & value attentions. decode_loop_step: an integer to indicate the step inside a decoding loop. Returns: Output of decoder layer stack. float32 tensor with shape [batch_size, target_length, hidden_size] """ decoder_inputs = inputs["decoder_inputs"] encoder_outputs = inputs["encoder_outputs"] self_attention_mask = inputs["self_attention_mask"] attention_mask = inputs["attention_mask"] decoder_shape = tf_utils.get_shape_list(decoder_inputs, expected_rank=3) batch_size = decoder_shape[0] decoder_length = decoder_shape[1] def _to_bert_self_attention_mask(matrix): """[1, 1, target_len, target_len] -> [bs, target_len, target_len].""" matrix = tf.squeeze(matrix, axis=[1]) matrix = tf.tile(matrix, [batch_size, 1, 1]) return matrix def _to_bert_encdec_attention_mask(matrix): """[bs, 1, 1, input_len] -> [bs, target_len, input_len].""" if self.multi_channel_cross_attention: matrix = tf.expand_dims(matrix, axis=2) matrix = tf.tile(matrix, [1, 1, decoder_length, 1]) else: matrix = tf.squeeze(matrix, axis=[1]) matrix = tf.tile(matrix, [1, decoder_length, 1]) return matrix attention_mask = _to_bert_encdec_attention_mask(attention_mask) self_attention_mask = _to_bert_self_attention_mask(self_attention_mask) output_tensor = decoder_inputs for layer_idx in range(self.num_hidden_layers): if self.attend_to_last_layer: memory = encoder_outputs[-1] else: memory = encoder_outputs[layer_idx] if self.multi_channel_cross_attention: transformer_inputs = [ output_tensor, memory, attention_mask, self_attention_mask, inputs["doc_attention_probs"] ] else: transformer_inputs = [ output_tensor, memory, attention_mask, self_attention_mask ] # Gets the cache for decoding. if cache is None: output_tensor, _ = self.layers[layer_idx](transformer_inputs) else: cache_layer_idx = str(layer_idx) output_tensor, cache[cache_layer_idx] = self.layers[layer_idx]( transformer_inputs, cache=cache[cache_layer_idx], decode_loop_step=decode_loop_step) return output_tensor, cache def get_attention_bias(input_tensor, bias_type, padding_value=0, max_length=None): """A helper function to get various attention bias tensors.""" if bias_type not in ("single_cross", "multi_cross", "decoder_self"): raise ValueError("Invalid attention bias type: %s" % bias_type) if bias_type == "single_cross": length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1] bias = transformer_utils.get_padding_bias( input_tensor, padding_value=padding_value) elif bias_type == "multi_cross": length = tf_utils.get_shape_list(input_tensor, expected_rank=3)[2] padding = transformer_utils.get_padding( input_tensor, padding_value=padding_value) bias = padding * -1e9 else: if max_length is not None: length = max_length else: length = tf_utils.get_shape_list(input_tensor, expected_rank=2)[1] bias = transformer_utils.get_decoder_self_attention_bias(length) return tf.where(bias < 0, tf.zeros_like(bias), tf.ones_like(bias)) class AttentionBias(tf.keras.layers.Layer): def __init__(self, bias_type, **kwargs): super(AttentionBias, self).__init__(**kwargs) self.bias_type = bias_type def call(self, inputs): return get_attention_bias(inputs, self.bias_type) class EmbeddingPostprocessor(tf.keras.layers.Layer): """Performs various post-processing on a word embedding tensor.""" def __init__(self, use_type_embeddings=False, token_type_vocab_size=None, use_position_embeddings=True, max_position_embeddings=512, dropout_prob=0.0, initializer_range=0.02, initializer=None, **kwargs): super(EmbeddingPostprocessor, self).__init__(**kwargs) self.use_type_embeddings = use_type_embeddings self.token_type_vocab_size = token_type_vocab_size self.use_position_embeddings = use_position_embeddings self.max_position_embeddings = max_position_embeddings self.dropout_prob = dropout_prob self.initializer_range = initializer_range if not initializer: self.initializer = tf.keras.initializers.TruncatedNormal( stddev=initializer_range) else: self.initializer = initializer if self.use_type_embeddings and not self.token_type_vocab_size: raise ValueError("If `use_type_embeddings` is True, then " "`token_type_vocab_size` must be specified.") def build(self, input_shapes): """Implements build() for the layer.""" (word_embeddings_shape, _) = input_shapes width = word_embeddings_shape.as_list()[-1] self.type_embeddings = None if self.use_type_embeddings: self.type_embeddings = self.add_weight( "type_embeddings", shape=[self.token_type_vocab_size, width], initializer=tf.keras.initializers.TruncatedNormal( stddev=self.initializer_range), dtype=self.dtype) self.position_embeddings = None if self.use_position_embeddings: self.position_embeddings = self.add_weight( "position_embeddings", shape=[self.max_position_embeddings, width], initializer=tf.keras.initializers.TruncatedNormal( stddev=self.initializer_range), dtype=self.dtype) self.output_layer_norm = tf.keras.layers.LayerNormalization( name="layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32) self.output_dropout = tf.keras.layers.Dropout( rate=self.dropout_prob, dtype=tf.float32) super(EmbeddingPostprocessor, self).build(input_shapes) def __call__(self, word_embeddings, token_type_ids=None, **kwargs): inputs = tf_utils.pack_inputs([word_embeddings, token_type_ids]) return super(EmbeddingPostprocessor, self).__call__(inputs, **kwargs) def call(self, inputs): """Implements call() for the layer.""" unpacked_inputs = tf_utils.unpack_inputs(inputs) word_embeddings = unpacked_inputs[0] token_type_ids = unpacked_inputs[1] input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3) batch_size = input_shape[0] seq_length = input_shape[1] width = input_shape[2] output = word_embeddings if self.use_type_embeddings: flat_token_type_ids = tf.reshape(token_type_ids, [-1]) token_type_embeddings = tf.gather(self.type_embeddings, flat_token_type_ids) token_type_embeddings = tf.reshape(token_type_embeddings, [batch_size, seq_length, width]) output += token_type_embeddings if self.use_position_embeddings: position_embeddings = tf.expand_dims( tf.slice(self.position_embeddings, [0, 0], [seq_length, width]), axis=0) output += position_embeddings output = self.output_layer_norm(output) output = self.output_dropout(output) return output class Decoder(tf.keras.layers.Layer): """The decoder network which can reuse encoder embeddings for target.""" def __init__(self, config, embedding_lookup=None, **kwargs): super(Decoder, self).__init__(**kwargs) self.config = config # Shares vocabulary embedding. self.embedding_lookup = None if embedding_lookup: self.embedding_lookup = embedding_lookup def build(self, unused_input_shapes): """Implements build() for the layer.""" if self.embedding_lookup is None: self.embedding_lookup = layers.OnDeviceEmbedding( vocab_size=self.config.vocab_size, embedding_width=self.config.hidden_size, initializer=tf.keras.initializers.TruncatedNormal( stddev=self.config.initializer_range), name="target_embeddings") self.embedding_postprocessor = EmbeddingPostprocessor( use_type_embeddings=False, use_position_embeddings=True, max_position_embeddings=self.config.max_position_embeddings, dropout_prob=self.config.hidden_dropout_prob, initializer=tf.keras.initializers.VarianceScaling( scale=self.config.initializer_gain, mode="fan_avg", distribution="uniform"), name="embedding_postprocessor") # Decoder can use a different intermediate size. self.multi_channel_cross_attention = self.config.get( "multi_channel_cross_attention", False) self.decoder = TransformerDecoder( num_hidden_layers=self.config.num_decoder_layers, hidden_size=self.config.hidden_size, num_attention_heads=self.config.num_decoder_attn_heads, intermediate_size=self.config.decoder_intermediate_size, intermediate_activation=self.config.hidden_act, hidden_dropout_prob=self.config.hidden_dropout_prob, attention_probs_dropout_prob=self.config.attention_probs_dropout_prob, initializer_range=self.config.initializer_range, multi_channel_cross_attention=self.multi_channel_cross_attention, name="decoder") super(Decoder, self).build(unused_input_shapes) def _decoding_step_time_signal(self, target_embeds, decode_loop_step): """Applies time signal (positional embeddings) for decoded embeddings.""" # TODO(hongkuny): migrate to keras bert and design a module to handle this. output = target_embeds if self.embedding_postprocessor.use_position_embeddings: position_embeddings = tf.gather( self.embedding_postprocessor.position_embeddings, [decode_loop_step]) # Broadcasts to all sequences inside a batch. output += position_embeddings output = self.embedding_postprocessor.output_layer_norm(output) output = self.embedding_postprocessor.output_dropout(output) return output def call(self, inputs, cache=None, decode_loop_step=None, padded_decode=False): """Implements call() for the layer. Args: inputs: a list of input tensors. cache: A dictionary of cache tensors, including key & value attentions. Due to the limit of keras, we uses the side effect to update cache and states of tensors will be mutated. decode_loop_step: an integer to indicate the step inside a decoding loop. padded_decode: a boolean indicates if the pass is for padded decoding. Returns: Decoder output tensors. """ attention_bias = inputs["attention_bias"] target_ids = inputs["target_ids"] all_encoder_outputs = inputs["all_encoder_outputs"] self_attention_bias = inputs["self_attention_bias"] if not isinstance(all_encoder_outputs, list): all_encoder_outputs = [all_encoder_outputs] target_embeds = self.embedding_lookup(target_ids) if decode_loop_step is None: target_embeds = self.embedding_postprocessor(target_embeds) else: target_embeds = self._decoding_step_time_signal(target_embeds, decode_loop_step) decoder_inputs = dict( decoder_inputs=target_embeds, encoder_outputs=all_encoder_outputs, self_attention_mask=self_attention_bias, attention_mask=attention_bias) if self.multi_channel_cross_attention: decoder_inputs["doc_attention_probs"] = inputs["doc_attention_probs"] decode_outputs, cache = self.decoder( decoder_inputs, cache, decode_loop_step if padded_decode else None) return decode_outputs