# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based gated feedforward layer."""
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from absl.testing import parameterized | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import | |
from official.nlp.modeling.layers import gated_feedforward | |
# Parameter combinations shared by the creation and invocation tests. Each
# tuple is (test name suffix, use_gate, num_blocks, dropout_position, dtype),
# matching the positional parameters of the parameterized test methods below.
_GATED_FFN_TEST_CASES = (
    ("gated_ffn_float32", True, 1, "after_residual", "float32"),
    ("gated_ffn_mixed_float16", True, 1, "after_residual", "mixed_float16"),
    ("gated_ffn_multi_block_float32", True, 4, "after_residual", "float32"),
    ("ffn_no_gate_before_residual_float32", False, 1, "before_residual",
     "float32"),
    ("ffn_no_gate_multi_block_mixed_float16", False, 4, "before_residual",
     "mixed_float16"),
)


# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class GatedFeedforwardTest(keras_parameterized.TestCase):
  """Tests for `gated_feedforward.GatedFeedforward`."""

  def tearDown(self):
    """Resets the global mixed-precision policy to float32.

    Some parameterized cases switch the global Keras policy to
    "mixed_float16"; resetting it here keeps that state from leaking into
    later tests.
    """
    super(GatedFeedforwardTest, self).tearDown()
    tf.keras.mixed_precision.experimental.set_policy("float32")

  @parameterized.named_parameters(*_GATED_FFN_TEST_CASES)
  def test_layer_creation(self, use_gate, num_blocks, dropout_position, dtype):
    """Checks that applying the layer to a symbolic input preserves its shape.

    Args:
      use_gate: Forwarded to the layer's `use_gate` argument.
      num_blocks: Forwarded to the layer's `num_blocks` argument.
      dropout_position: Forwarded to the layer's `dropout_position` argument.
      dtype: Name of the global Keras mixed-precision policy to activate.
    """
    tf.keras.mixed_precision.experimental.set_policy(dtype)
    kwargs = dict(
        intermediate_size=128,
        intermediate_activation="relu",
        dropout=0.1,
        use_gate=use_gate,
        num_blocks=num_blocks,
        dropout_position=dropout_position,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = gated_feedforward.GatedFeedforward(**kwargs)

    sequence_length = 64
    width = 128
    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)
    # The layer's output shape is expected to match its input shape.
    self.assertEqual(data_tensor.shape.as_list(), output_tensor.shape.as_list())

  @parameterized.named_parameters(*_GATED_FFN_TEST_CASES)
  def test_layer_invocation(self, use_gate, num_blocks, dropout_position,
                            dtype):
    """Checks that a model wrapping the layer runs on concrete data.

    Args:
      use_gate: Forwarded to the layer's `use_gate` argument.
      num_blocks: Forwarded to the layer's `num_blocks` argument.
      dropout_position: Forwarded to the layer's `dropout_position` argument.
      dtype: Name of the global Keras mixed-precision policy to activate.
    """
    tf.keras.mixed_precision.experimental.set_policy(dtype)
    kwargs = dict(
        intermediate_size=16,
        intermediate_activation="relu",
        dropout=0.1,
        use_gate=use_gate,
        num_blocks=num_blocks,
        dropout_position=dropout_position,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = gated_feedforward.GatedFeedforward(**kwargs)

    sequence_length = 16
    width = 32
    # Create a 3-dimensional input (the first dimension is implicit).
    data_tensor = tf.keras.Input(shape=(sequence_length, width))
    output_tensor = test_layer(data_tensor)

    # Create a model from the test layer.
    model = tf.keras.Model(data_tensor, output_tensor)

    # Invoke the model on random test data; only the output shape is
    # asserted, so the values themselves do not need to be reproducible.
    batch_size = 6
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, width))
    output_data = model.predict(input_data)
    self.assertEqual(output_data.shape, (batch_size, sequence_length, width))

  def test_serialize_deserialize(self):
    """Checks that `get_config`/`from_config` round-trips the layer config."""
    kwargs = dict(
        intermediate_size=16,
        intermediate_activation="relu",
        dropout=0.1,
        use_gate=False,
        num_blocks=4,
        dropout_position="after_residual",
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")
    test_layer = gated_feedforward.GatedFeedforward(**kwargs)
    new_layer = gated_feedforward.GatedFeedforward.from_config(
        test_layer.get_config())

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
# Run the test cases in this module through TensorFlow's test runner when the
# file is executed directly as a script.
if __name__ == "__main__":
  tf.test.main()