# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tf.keras Models for NHNet."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

from absl import logging
import gin
import tensorflow as tf
from typing import Optional, Text

from official.modeling import tf_utils
from official.modeling.hyperparams import params_dict
from official.nlp.modeling import networks
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.nhnet import configs
from official.nlp.nhnet import decoder
from official.nlp.nhnet import utils
from official.nlp.transformer import beam_search


def embedding_linear(embedding_matrix, x):
  """Uses embeddings as linear transformation weights."""
  with tf.name_scope("presoftmax_linear"):
    batch_size = tf.shape(x)[0]
    length = tf.shape(x)[1]
    hidden_size = tf.shape(x)[2]
    vocab_size = tf.shape(embedding_matrix)[0]
    x = tf.reshape(x, [-1, hidden_size])
    logits = tf.matmul(x, embedding_matrix, transpose_b=True)
    return tf.reshape(logits, [batch_size, length, vocab_size])


def _add_sos_to_seq(seq, start_token_id):
  """Add a start sequence token while keeping seq length."""
  batch_size = tf.shape(seq)[0]
  seq_len = tf.shape(seq)[1]
  sos_ids = tf.ones([batch_size], tf.int32) * start_token_id
  targets = tf.concat([tf.expand_dims(sos_ids, axis=1), seq], axis=1)
  targets = targets[:, :-1]
  tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
  return targets


def remove_sos_from_seq(seq, pad_token_id):
  """Remove the start sequence token while keeping seq length."""
  batch_size, seq_len = tf_utils.get_shape_list(seq, expected_rank=2)
  # remove
  targets = seq[:, 1:]
  # pad
  pad_ids = tf.ones([batch_size], tf.int32) * pad_token_id
  targets = tf.concat([targets, tf.expand_dims(pad_ids, axis=1)], axis=1)
  tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
  return targets


class Bert2Bert(tf.keras.Model):
  """Bert2Bert encoder decoder model for training."""

  def __init__(self, params, bert_layer, decoder_layer, name=None):
    super(Bert2Bert, self).__init__(name=name)
    self.params = params
    if not bert_layer.built:
      raise ValueError("bert_layer should be built.")
    if not decoder_layer.built:
      raise ValueError("decoder_layer should be built.")
    self.bert_layer = bert_layer
    self.decoder_layer = decoder_layer

  def get_config(self):
    return {"params": self.params.as_dict()}

  def get_decode_logits(self,
                        decoder_inputs,
                        ids,
                        decoder_self_attention_bias,
                        step,
                        cache=None):
    if cache:
      if self.params.get("padded_decode", False):
        bias_shape = decoder_self_attention_bias.shape.as_list()
        self_attention_bias = tf.slice(
            decoder_self_attention_bias, [0, 0, step, 0],
            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
      else:
        self_attention_bias = decoder_self_attention_bias[:, :, step:step + 1,
                                                          :step + 1]
      # Sets decoder input to the last generated IDs.
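      # With an active cache, keys/values from earlier positions are reused, so
      # only the newest token has to be fed through the decoder at this step.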
      decoder_input = ids[:, -1:]
    else:
      self_attention_bias = decoder_self_attention_bias[:, :, :step + 1,
                                                        :step + 1]
      decoder_input = ids
    decoder_inputs["target_ids"] = decoder_input
    decoder_inputs["self_attention_bias"] = self_attention_bias
    if cache:
      decoder_outputs = self.decoder_layer(
          decoder_inputs,
          cache,
          decode_loop_step=step,
          padded_decode=self.params.get("padded_decode", False))
    else:
      decoder_outputs = self.decoder_layer(decoder_inputs)
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decoder_outputs[:, -1:, :])
    logits = tf.squeeze(logits, axis=[1])
    return logits

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""
    # Max decode length should be smaller than the positional embedding max
    # sequence length.
    decoder_self_attention_bias = decoder.get_attention_bias(
        input_tensor=None,
        bias_type="decoder_self",
        max_length=max_decode_length)

    def _symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next candidate IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape [batch_size *
          beam_size, i + 1]
        i: Loop index
        cache: dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.

      Returns:
        Tuple of
          (logits with shape [batch_size * beam_size, vocab_size],
           updated cache values)
      """
      decoder_inputs = dict(
          all_encoder_outputs=cache["all_encoder_outputs"],
          attention_bias=cache["attention_bias"])
      logits = self.get_decode_logits(
          decoder_inputs,
          ids,
          decoder_self_attention_bias,
          step=i,
          cache=cache if self.params.use_cache else None)
      return logits, cache

    return _symbols_to_logits_fn

  def train_decode(self, decode_outputs):
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decode_outputs)
    decode_output_ids = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
    output_log_probs = tf.nn.log_softmax(logits, axis=-1)
    return logits, decode_output_ids, output_log_probs

  def predict_decode(self, start_token_ids, cache):
    symbols_to_logits_fn = self._get_symbols_to_logits_fn(self.params.len_title)
    # Use beam search to find the top beam_size sequences and scores.
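    # decoded_ids is expected to be [batch_size, beam_size, len_title + 1]
    # (position 0 holds the start token) and scores [batch_size, beam_size].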
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=start_token_ids,
        initial_cache=cache,
        vocab_size=self.params.vocab_size,
        beam_size=self.params.beam_size,
        alpha=self.params.alpha,
        max_decode_length=self.params.len_title,
        padded_decode=self.params.get("padded_decode", False),
        eos_id=self.params.end_token_id)
    return decoded_ids, scores

  def _get_logits_for_decode_ids(self, decoder_inputs, top_decoded_ids):
    """Returns the log probabilities for ids."""
    target_ids = _add_sos_to_seq(top_decoded_ids, self.params.start_token_id)
    decoder_inputs["self_attention_bias"] = decoder.get_attention_bias(
        target_ids, bias_type="decoder_self")
    decoder_inputs["target_ids"] = target_ids
    decoder_outputs = self.decoder_layer(decoder_inputs)
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decoder_outputs)
    return logits

  def _init_cache(self, batch_size):
    num_heads = self.params.num_decoder_attn_heads
    dim_per_head = self.params.hidden_size // num_heads
    init_decode_length = (
        self.params.len_title if self.params.get("padded_decode", False) else 0)
    cache = {}
    for layer in range(self.params.num_decoder_layers):
      cache[str(layer)] = {
          "key":
              tf.zeros(
                  [batch_size, init_decode_length, num_heads, dim_per_head],
                  dtype=tf.float32),
          "value":
              tf.zeros(
                  [batch_size, init_decode_length, num_heads, dim_per_head],
                  dtype=tf.float32)
      }
    return cache

  def call(self, inputs, mode="train"):
    """Implements call().

    Args:
      inputs: a dictionary of tensors.
      mode: string, one of "train", "eval" or "predict".

    Returns:
      logits, decode_output_ids, output_log_probs for training.
      top_decoded_ids for eval.
    """
    input_ids = inputs["input_ids"]
    input_mask = inputs["input_mask"]
    segment_ids = inputs["segment_ids"]
    all_encoder_outputs, _ = self.bert_layer(
        [input_ids, input_mask, segment_ids])

    if mode not in ("train", "eval", "predict"):
      raise ValueError("Invalid call mode: %s" % mode)
    encoder_decoder_attention_bias = decoder.get_attention_bias(
        input_ids,
        bias_type="single_cross",
        padding_value=self.params.pad_token_id)
    if mode == "train":
      self_attention_bias = decoder.get_attention_bias(
          inputs["target_ids"], bias_type="decoder_self")
      decoder_inputs = dict(
          attention_bias=encoder_decoder_attention_bias,
          all_encoder_outputs=all_encoder_outputs,
          target_ids=inputs["target_ids"],
          self_attention_bias=self_attention_bias)
      decoder_outputs = self.decoder_layer(decoder_inputs)
      return self.train_decode(decoder_outputs)

    batch_size = tf.shape(input_ids)[0]
    start_token_ids = tf.ones([batch_size],
                              tf.int32) * self.params.start_token_id
    # Add encoder output and attention bias to the cache.
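    # When use_cache is enabled, the cache also carries zero-initialized
    # per-layer key/value tensors (see _init_cache) that the decoder updates as
    # beam search progresses.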
    if self.params.use_cache:
      cache = self._init_cache(batch_size)
    else:
      cache = {}
    cache["all_encoder_outputs"] = all_encoder_outputs
    cache["attention_bias"] = encoder_decoder_attention_bias

    decoded_ids, scores = self.predict_decode(start_token_ids, cache)
    if mode == "predict":
      return decoded_ids[:, :self.params.beam_size,
                         1:], scores[:, :self.params.beam_size]

    decoder_inputs = dict(
        attention_bias=encoder_decoder_attention_bias,
        all_encoder_outputs=all_encoder_outputs)
    top_decoded_ids = decoded_ids[:, 0, 1:]
    return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)


class NHNet(Bert2Bert):
  """NHNet model which performs multi-doc decoding."""

  def __init__(self, params, bert_layer, decoder_layer, name=None):
    super(NHNet, self).__init__(params, bert_layer, decoder_layer, name=name)
    self.doc_attention = multi_channel_attention.VotingAttention(
        num_heads=params.num_decoder_attn_heads,
        head_size=params.hidden_size // params.num_decoder_attn_heads)

  def _expand_doc_attention_probs(self, doc_attention_probs, target_length):
    """Expands doc attention probs to fit the decoding sequence length."""
    doc_attention_probs = tf.expand_dims(
        doc_attention_probs, axis=[1])  # [B, 1, A]
    doc_attention_probs = tf.expand_dims(
        doc_attention_probs, axis=[2])  # [B, 1, 1, A]
    return tf.tile(doc_attention_probs,
                   [1, self.params.num_decoder_attn_heads, target_length, 1])

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""
    # Max decode length should be smaller than the positional embedding max
    # sequence length.
    decoder_self_attention_bias = decoder.get_attention_bias(
        input_tensor=None,
        bias_type="decoder_self",
        max_length=max_decode_length)

    def _symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next candidate IDs."""
      if self.params.use_cache:
        target_length = 1
      else:
        target_length = i + 1
      decoder_inputs = dict(
          doc_attention_probs=self._expand_doc_attention_probs(
              cache["doc_attention_probs"], target_length),
          all_encoder_outputs=cache["all_encoder_outputs"],
          attention_bias=cache["attention_bias"])
      logits = self.get_decode_logits(
          decoder_inputs,
          ids,
          decoder_self_attention_bias,
          step=i,
          cache=cache if self.params.use_cache else None)
      return logits, cache

    return _symbols_to_logits_fn

  def call(self, inputs, mode="training"):
    input_shape = tf_utils.get_shape_list(inputs["input_ids"], expected_rank=3)
    batch_size, num_docs, len_passage = (input_shape[0], input_shape[1],
                                         input_shape[2])
    input_ids = tf.reshape(inputs["input_ids"], [-1, len_passage])
    input_mask = tf.reshape(inputs["input_mask"], [-1, len_passage])
    segment_ids = tf.reshape(inputs["segment_ids"], [-1, len_passage])
    all_encoder_outputs, _ = self.bert_layer(
        [input_ids, input_mask, segment_ids])
    encoder_outputs = tf.reshape(
        all_encoder_outputs[-1],
        [batch_size, num_docs, len_passage, self.params.hidden_size])
    doc_attention_mask = tf.reshape(
        tf.cast(
            tf.math.count_nonzero(input_mask, axis=1, dtype=tf.int32) > 2,
            tf.int32), [batch_size, num_docs])

    doc_attention_probs = self.doc_attention(encoder_outputs,
                                             doc_attention_mask)
    encoder_decoder_attention_bias = decoder.get_attention_bias(
        inputs["input_ids"],
        bias_type="multi_cross",
        padding_value=self.params.pad_token_id)

    if mode == "train":
      target_length = tf_utils.get_shape_list(
          inputs["target_ids"], expected_rank=2)[1]
      doc_attention_probs = self._expand_doc_attention_probs(
          doc_attention_probs, target_length)
      self_attention_bias = decoder.get_attention_bias(
          inputs["target_ids"], bias_type="decoder_self")
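      # The full target sequence is known during training, so the per-document
      # attention probabilities are tiled across all heads and target positions
      # up front.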
      decoder_inputs = dict(
          attention_bias=encoder_decoder_attention_bias,
          self_attention_bias=self_attention_bias,
          target_ids=inputs["target_ids"],
          all_encoder_outputs=encoder_outputs,
          doc_attention_probs=doc_attention_probs)
      decoder_outputs = self.decoder_layer(decoder_inputs)
      return self.train_decode(decoder_outputs)

    # Adds encoder output and attention bias to the cache.
    if self.params.use_cache:
      cache = self._init_cache(batch_size)
    else:
      cache = {}
    cache["all_encoder_outputs"] = [encoder_outputs]
    cache["attention_bias"] = encoder_decoder_attention_bias
    cache["doc_attention_probs"] = doc_attention_probs

    start_token_ids = tf.ones([batch_size],
                              tf.int32) * self.params.start_token_id
    decoded_ids, scores = self.predict_decode(start_token_ids, cache)
    if mode == "predict":
      return decoded_ids[:, :self.params.beam_size,
                         1:], scores[:, :self.params.beam_size]

    top_decoded_ids = decoded_ids[:, 0, 1:]
    target_length = tf_utils.get_shape_list(top_decoded_ids)[-1]
    decoder_inputs = dict(
        attention_bias=encoder_decoder_attention_bias,
        all_encoder_outputs=[encoder_outputs],
        doc_attention_probs=self._expand_doc_attention_probs(
            doc_attention_probs, target_length))
    return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)


def get_bert2bert_layers(params: configs.BERT2BERTConfig):
  """Creates a Bert2Bert stem model and returns Bert encoder/decoder.

  We use a functional-style stem model here because all layers need to be
  built so that variables can be restored in a customized way. The layers are
  called with placeholder inputs to make them fully built.

  Args:
    params: ParamsDict.

  Returns:
    Two keras Layers, bert_model_layer and decoder_layer.
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  all_encoder_outputs, _ = bert_model_layer(
      [input_ids, input_mask, segment_ids])
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
  cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs)
  _ = decoder_layer(decoder_inputs)
  return bert_model_layer, decoder_layer


def get_nhnet_layers(params: configs.NHNetConfig):
  """Creates a Multi-doc encoder/decoder.

  Args:
    params: ParamsDict.

  Returns:
    Two keras Layers, bert_model_layer and decoder_layer.
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  bert_model_layer([input_ids, input_mask, segment_ids])

  input_ids = tf.keras.layers.Input(
      shape=(None, None), name="input_ids", dtype=tf.int32)
  all_encoder_outputs = tf.keras.layers.Input((None, None, params.hidden_size),
                                              dtype=tf.float32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  doc_attention_probs = tf.keras.layers.Input(
      (params.num_decoder_attn_heads, None, None), dtype=tf.float32)
  # pylint: disable=protected-access
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)
  # pylint: enable=protected-access
  cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs,
      doc_attention_probs=doc_attention_probs)
  _ = decoder_layer(decoder_inputs)
  return bert_model_layer, decoder_layer


def create_transformer_model(params,
                             init_checkpoint: Optional[Text] = None
                            ) -> tf.keras.Model:
  """A helper to create Transformer model."""
  bert_layer, decoder_layer = get_bert2bert_layers(params=params)
  model = Bert2Bert(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="transformer")
  if init_checkpoint:
    logging.info(
        "Checkpoint file %s found and restoring from "
        "initial checkpoint.", init_checkpoint)
    ckpt = tf.train.Checkpoint(model=model)
    ckpt.restore(init_checkpoint).expect_partial()
  return model


def create_bert2bert_model(
    params: configs.BERT2BERTConfig,
    cls=Bert2Bert,
    init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
  """A helper to create Bert2Bert model."""
  bert_layer, decoder_layer = get_bert2bert_layers(params=params)
  if init_checkpoint:
    utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
                                                    init_checkpoint)
  return cls(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="bert2bert")


def create_nhnet_model(
    params: configs.NHNetConfig,
    cls=NHNet,
    init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
  """A helper to create NHNet model."""
  bert_layer, decoder_layer = get_nhnet_layers(params=params)
  model = cls(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="nhnet")
  if init_checkpoint:
    logging.info(
        "Checkpoint file %s found and restoring from "
        "initial checkpoint.", init_checkpoint)
    if params.init_from_bert2bert:
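      # init_from_bert2bert means the checkpoint holds a full encoder-decoder
      # model, so its variables are restored into this model directly;
      # otherwise only pretrained BERT weights are copied below.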
      ckpt = tf.train.Checkpoint(model=model)
      ckpt.restore(init_checkpoint).assert_existing_objects_matched()
    else:
      utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
                                                      init_checkpoint)
  return model


@gin.configurable
def get_model_params(model: Optional[Text] = "bert2bert",
                     config_class=None) -> params_dict.ParamsDict:
  """Helper function to convert config file to ParamsDict."""
  if model == "bert2bert":
    return configs.BERT2BERTConfig()
  elif model == "nhnet":
    return configs.NHNetConfig()
  elif config_class:
    return config_class()
  else:
    raise KeyError("The model type is not defined: %s" % model)


@gin.configurable
def create_model(model_type: Text,
                 params,
                 init_checkpoint: Optional[Text] = None):
  """A factory function to create different types of models."""
  if model_type == "bert2bert":
    return create_bert2bert_model(params, init_checkpoint=init_checkpoint)
  elif model_type == "nhnet":
    return create_nhnet_model(params, init_checkpoint=init_checkpoint)
  elif "transformer" in model_type:
    return create_transformer_model(
        params, init_checkpoint=init_checkpoint)
  else:
    raise KeyError("The model type is not defined: %s" % model_type)
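

# A minimal usage sketch (illustrative only; `features` below is a placeholder
# for the caller's input dict and is not defined in this module):
#
#   params = get_model_params("bert2bert")
#   model = create_model("bert2bert", params)
#   # `features` holds "input_ids", "input_mask", "segment_ids" and, for
#   # training, "target_ids".
#   logits, output_ids, log_probs = model(features, mode="train")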