# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import math
import tensorflow as tf
from official.modeling import tf_utils


@tf.keras.utils.register_keras_serializable(package="Text")
class PositionEmbedding(tf.keras.layers.Layer):
"""Creates a positional embedding.
This layer creates a positional embedding as described in "BERT: Pre-training
of Deep Bidirectional Transformers for Language Understanding"
(https://arxiv.org/abs/1810.04805).
This layer can be set up to either create a statically shaped slice or a
dynamically shaped slice. If `use_dynamic_slicing` is True, the input tensor
can have a dynamic 1st dimension, while if `use_dynamic_slicing` is False the
input size must be fixed.
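
  Example (a minimal usage sketch; the feature width of 32 and the maximum
  length of 128 are illustrative values, not requirements of this layer):

    position_embedding = PositionEmbedding(
        use_dynamic_slicing=True, max_sequence_length=128)
    word_embeddings = tf.keras.Input(shape=(None, 32), dtype=tf.float32)
    position_encodings = position_embedding(word_embeddings)
    # position_encodings is broadcast to shape [batch, sequence, 32].
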
Arguments:
use_dynamic_slicing: Whether to use the dynamic slicing path.
max_sequence_length: The maximum size of the dynamic sequence. Only
applicable if `use_dynamic_slicing` is True.
initializer: The initializer to use for the embedding weights. Defaults to
"glorot_uniform".
"""

  def __init__(self,
initializer="glorot_uniform",
use_dynamic_slicing=False,
max_sequence_length=None,
**kwargs):
# We need to have a default dtype of float32, since the inputs (which Keras
# usually uses to infer the dtype) will always be int32.
if "dtype" not in kwargs:
kwargs["dtype"] = "float32"
super(PositionEmbedding, self).__init__(**kwargs)
if use_dynamic_slicing and max_sequence_length is None:
raise ValueError(
"If `use_dynamic_slicing` is True, `max_sequence_length` must be set."
)
self._max_sequence_length = max_sequence_length
self._initializer = tf.keras.initializers.get(initializer)
self._use_dynamic_slicing = use_dynamic_slicing

  def get_config(self):
config = {
"max_sequence_length": self._max_sequence_length,
"initializer": tf.keras.initializers.serialize(self._initializer),
"use_dynamic_slicing": self._use_dynamic_slicing,
}
base_config = super(PositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
"""Implements build() for the layer."""
dimension_list = input_shape.as_list()
if len(dimension_list) != 3:
raise ValueError("PositionEmbedding expects a 3-dimensional input tensor "
"of shape [batch, sequence, width]")
seq_length = dimension_list[1]
width = dimension_list[2]
# If we are not using dynamic slicing, we must assume that the sequence
# length is fixed and max_sequence_length should not be specified.
if not self._use_dynamic_slicing:
if seq_length is None:
raise ValueError(
"PositionEmbedding must have `use_dynamic_slicing` set "
"to True (and max_sequence_length set) when the "
"sequence (1st) dimension of the input is None.")
if self._max_sequence_length is not None:
        raise ValueError(
            "When `use_dynamic_slicing` is False, `max_sequence_length` should "
            "not be specified, since the weight shape is taken from the input's "
            "sequence (1st) dimension.")
if self._max_sequence_length is not None:
weight_sequence_length = self._max_sequence_length
else:
weight_sequence_length = seq_length
self._position_embeddings = self.add_weight(
"embeddings",
shape=[weight_sequence_length, width],
initializer=self._initializer)
super(PositionEmbedding, self).build(input_shape)

  def call(self, inputs):
"""Implements call() for the layer."""
input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
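    # With dynamic slicing, the embedding table covers `max_sequence_length`
    # positions and only the first `seq_length` rows are used; otherwise the
    # table was built to match the input's fixed sequence length.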
if self._use_dynamic_slicing:
position_embeddings = self._position_embeddings[:input_shape[1], :]
else:
position_embeddings = self._position_embeddings
return tf.broadcast_to(position_embeddings, input_shape)


@tf.keras.utils.register_keras_serializable(package="Text")
class RelativePositionEmbedding(tf.keras.layers.Layer):
  """Creates a sinusoidal positional encoding.
This layer calculates the position encoding as a mix of sine and cosine
  functions with geometrically increasing wavelengths, as defined and formalized
  in "Attention is All You Need", section 3.5
(https://arxiv.org/abs/1706.03762).
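
  Example (a minimal sketch; the hidden size of 64 is an illustrative value):

    position_encoding = RelativePositionEmbedding(hidden_size=64)
    word_embeddings = tf.keras.Input(shape=(None, 64), dtype=tf.float32)
    # encodings is a [sequence_length, 64] tensor of sinusoidal values.
    encodings = position_encoding(word_embeddings)
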
Arguments:
hidden_size: Size of the hidden layer.
    min_timescale: Minimum scale that will be applied at each position.
max_timescale: Maximum scale that will be applied at each position.
"""

  def __init__(self,
hidden_size,
min_timescale=1.0,
max_timescale=1.0e4,
**kwargs):
# We need to have a default dtype of float32, since the inputs (which Keras
# usually uses to infer the dtype) will always be int32.
# We compute the positional encoding in float32 even if the model uses
# float16, as many of the ops used, like log and exp, are numerically
# unstable in float16.
if "dtype" not in kwargs:
kwargs["dtype"] = "float32"
super(RelativePositionEmbedding, self).__init__(**kwargs)
self._hidden_size = hidden_size
self._min_timescale = min_timescale
self._max_timescale = max_timescale

  def get_config(self):
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
base_config = super(RelativePositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, length=None):
"""Implements call() for the layer.
Args:
      inputs: A tensor whose second dimension will be used as `length`. If
        `None`, the `length` argument must be specified.
      length: An optional integer specifying the number of positions. If both
        `inputs` and `length` are specified, `length` must equal the second
        dimension of `inputs`.
Returns:
      A tensor of shape [length, hidden_size].
"""
if inputs is None and length is None:
raise ValueError(
"If inputs is None, `length` must be set in "
"RelativePositionEmbedding().")
if inputs is not None:
input_shape = tf_utils.get_shape_list(inputs)
if length is not None and length != input_shape[1]:
        raise ValueError(
            "If inputs is not None, `length` must equal input_shape[1].")
length = input_shape[1]
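    # Sinusoidal encoding from "Attention is All You Need": for each position p
    # and timescale t_i (log-spaced between min_timescale and max_timescale),
    # the first num_timescales channels hold sin(p / t_i) and the remaining
    # channels hold cos(p / t_i).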
position = tf.cast(tf.range(length), tf.float32)
num_timescales = self._hidden_size // 2
min_timescale, max_timescale = self._min_timescale, self._max_timescale
log_timescale_increment = (
math.log(float(max_timescale) / float(min_timescale)) /
(tf.cast(num_timescales, tf.float32) - 1))
inv_timescales = min_timescale * tf.exp(
tf.cast(tf.range(num_timescales), tf.float32) *
-log_timescale_increment)
scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales,
0)
position_embeddings = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)],
axis=1)
return position_embeddings