# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import math

import tensorflow as tf

from official.modeling import tf_utils


class PositionEmbedding(tf.keras.layers.Layer):
  """Creates a positional embedding.

  This layer creates a positional embedding as described in "BERT: Pre-training
  of Deep Bidirectional Transformers for Language Understanding"
  (https://arxiv.org/abs/1810.04805).

  This layer can be set up to either create a statically shaped slice or a
  dynamically shaped slice. If `use_dynamic_slicing` is True, the input tensor
  can have a dynamic 1st dimension, while if `use_dynamic_slicing` is False the
  input size must be fixed.
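
  Example (an illustrative sketch; the input shape and `max_sequence_length`
  value below are assumptions, not requirements):

    word_embeddings = tf.keras.Input(shape=(128, 768), dtype=tf.float32)
    position_embedding = PositionEmbedding(
        use_dynamic_slicing=True, max_sequence_length=512)
    # The output broadcasts to the input shape and is typically added to the
    # word embeddings.
    outputs = word_embeddings + position_embedding(word_embeddings)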

  Arguments:
    use_dynamic_slicing: Whether to use the dynamic slicing path.
    max_sequence_length: The maximum size of the dynamic sequence. Only
      applicable if `use_dynamic_slicing` is True.
    initializer: The initializer to use for the embedding weights. Defaults to
      "glorot_uniform".
  """

  def __init__(self,
               initializer="glorot_uniform",
               use_dynamic_slicing=False,
               max_sequence_length=None,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which Keras
    # usually uses to infer the dtype) will always be int32.
    if "dtype" not in kwargs:
      kwargs["dtype"] = "float32"

    super(PositionEmbedding, self).__init__(**kwargs)
    if use_dynamic_slicing and max_sequence_length is None:
      raise ValueError(
          "If `use_dynamic_slicing` is True, `max_sequence_length` must be set."
      )
    self._max_sequence_length = max_sequence_length
    self._initializer = tf.keras.initializers.get(initializer)
    self._use_dynamic_slicing = use_dynamic_slicing

  def get_config(self):
    config = {
        "max_sequence_length": self._max_sequence_length,
        "initializer": tf.keras.initializers.serialize(self._initializer),
        "use_dynamic_slicing": self._use_dynamic_slicing,
    }
    base_config = super(PositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Implements build() for the layer."""
    dimension_list = input_shape.as_list()

    if len(dimension_list) != 3:
      raise ValueError("PositionEmbedding expects a 3-dimensional input tensor "
                       "of shape [batch, sequence, width]")
    seq_length = dimension_list[1]
    width = dimension_list[2]

    # If we are not using dynamic slicing, we must assume that the sequence
    # length is fixed and max_sequence_length should not be specified.
    if not self._use_dynamic_slicing:
      if seq_length is None:
        raise ValueError(
            "PositionEmbedding must have `use_dynamic_slicing` set "
            "to True (and max_sequence_length set) when the "
            "sequence (1st) dimension of the input is None.")
      if self._max_sequence_length is not None:
        raise ValueError(
            "When `use_dynamic_slicing` is False, max_sequence_length should "
            "not be specified; the sequence length of the input is used to "
            "shape the embedding variable instead.")

    if self._max_sequence_length is not None:
      weight_sequence_length = self._max_sequence_length
    else:
      weight_sequence_length = seq_length

    self._position_embeddings = self.add_weight(
        "embeddings",
        shape=[weight_sequence_length, width],
        initializer=self._initializer)

    super(PositionEmbedding, self).build(input_shape)

  def call(self, inputs):
    """Implements call() for the layer."""
    input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
    if self._use_dynamic_slicing:
      position_embeddings = self._position_embeddings[:input_shape[1], :]
    else:
      position_embeddings = self._position_embeddings

    return tf.broadcast_to(position_embeddings, input_shape)


class RelativePositionEmbedding(tf.keras.layers.Layer):
  """Creates a positional embedding.

  This layer calculates the position encoding as a mix of sine and cosine
  functions with geometrically increasing wavelengths. Defined and formulated
  in "Attention is All You Need", section 3.5.
  (https://arxiv.org/abs/1706.03762).
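
  Example (an illustrative sketch; `hidden_size` and the tensor shapes are
  arbitrary values chosen for the example):

    pos_encoding_layer = RelativePositionEmbedding(hidden_size=64)
    token_embeddings = tf.ones([8, 20, 64])   # [batch, length, hidden_size]
    # Returns a [20, 64] tensor that broadcasts against the batch dimension.
    encodings = pos_encoding_layer(token_embeddings)
    combined = token_embeddings + encodings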

  Arguments:
    hidden_size: Size of the hidden layer.
    min_timescale: Minimum scale that will be applied at each position.
    max_timescale: Maximum scale that will be applied at each position.
  """

  def __init__(self,
               hidden_size,
               min_timescale=1.0,
               max_timescale=1.0e4,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which Keras
    # usually uses to infer the dtype) will always be int32.
    # We compute the positional encoding in float32 even if the model uses
    # float16, as many of the ops used, like log and exp, are numerically
    # unstable in float16.
    if "dtype" not in kwargs:
      kwargs["dtype"] = "float32"

    super(RelativePositionEmbedding, self).__init__(**kwargs)
    self._hidden_size = hidden_size
    self._min_timescale = min_timescale
    self._max_timescale = max_timescale

  def get_config(self):
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
    base_config = super(RelativePositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, length=None):
    """Implements call() for the layer.

    Args:
      inputs: A tensor whose second dimension will be used as `length`. If
        `None`, the `length` argument must be specified.
      length: An optional integer specifying the number of positions. If both
        `inputs` and `length` are specified, `length` must be equal to the
        second dimension of `inputs`.

    Returns:
      A tensor of shape [length, hidden_size].
    """
    if inputs is None and length is None:
      raise ValueError(
          "If inputs is None, `length` must be set in "
          "RelativePositionEmbedding().")
    if inputs is not None:
      input_shape = tf_utils.get_shape_list(inputs)
      if length is not None and length != input_shape[1]:
        raise ValueError(
            "If inputs is not None, `length` must equal input_shape[1].")
      length = input_shape[1]
    position = tf.cast(tf.range(length), tf.float32)
    num_timescales = self._hidden_size // 2
    min_timescale, max_timescale = self._min_timescale, self._max_timescale
    # Angular frequencies (inverse timescales) decrease geometrically across
    # the sine/cosine channel pairs.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (tf.cast(num_timescales, tf.float32) - 1))
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), tf.float32) *
        -log_timescale_increment)
    # Outer product of positions and inverse timescales gives a
    # [length, num_timescales] matrix of angles; their sines fill the first
    # half of the hidden dimension and their cosines fill the second half.
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
        inv_timescales, 0)
    position_embeddings = tf.concat(
        [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    return position_embeddings
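

# An illustrative usage note (comment only, not part of the library API): as
# described in the `call()` docstring above, `RelativePositionEmbedding` can
# also be invoked without an input tensor by passing `length` directly; the
# sizes below are arbitrary example values.
#
#   pos_encoding_layer = RelativePositionEmbedding(hidden_size=64)
#   encodings = pos_encoding_layer(None, length=20)  # shape: [20, 64]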