Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Test beam search helper methods.""" | |
import tensorflow.compat.v1 as tf | |
from official.nlp.transformer import beam_search_v1 as beam_search | |
class BeamSearchHelperTests(tf.test.TestCase): | |
def setUp(self): | |
super(BeamSearchHelperTests, self).setUp() | |
tf.compat.v1.disable_eager_execution() | |
def test_expand_to_beam_size(self): | |
x = tf.ones([7, 4, 2, 5]) | |
x = beam_search._expand_to_beam_size(x, 3) | |
with self.session() as sess: | |
shape = sess.run(tf.shape(x)) | |
self.assertAllEqual([7, 3, 4, 2, 5], shape) | |
def test_shape_list(self): | |
y = tf.compat.v1.placeholder(dtype=tf.int32, shape=[]) | |
x = tf.ones([7, y, 2, 5]) | |
shape = beam_search._shape_list(x) | |
self.assertIsInstance(shape[0], int) | |
self.assertIsInstance(shape[1], tf.Tensor) | |
self.assertIsInstance(shape[2], int) | |
self.assertIsInstance(shape[3], int) | |
def test_get_shape_keep_last_dim(self): | |
y = tf.constant(4.0) | |
x = tf.ones([7, tf.cast(tf.sqrt(y), tf.int32), 2, 5]) | |
shape = beam_search._get_shape_keep_last_dim(x) | |
self.assertAllEqual([None, None, None, 5], | |
shape.as_list()) | |
def test_flatten_beam_dim(self): | |
x = tf.ones([7, 4, 2, 5]) | |
x = beam_search._flatten_beam_dim(x) | |
with self.session() as sess: | |
shape = sess.run(tf.shape(x)) | |
self.assertAllEqual([28, 2, 5], shape) | |
def test_unflatten_beam_dim(self): | |
x = tf.ones([28, 2, 5]) | |
x = beam_search._unflatten_beam_dim(x, 7, 4) | |
with self.session() as sess: | |
shape = sess.run(tf.shape(x)) | |
self.assertAllEqual([7, 4, 2, 5], shape) | |
def test_gather_beams(self): | |
x = tf.reshape(tf.range(24), [2, 3, 4]) | |
# x looks like: [[[ 0 1 2 3] | |
# [ 4 5 6 7] | |
# [ 8 9 10 11]] | |
# | |
# [[12 13 14 15] | |
# [16 17 18 19] | |
# [20 21 22 23]]] | |
y = beam_search._gather_beams(x, [[1, 2], [0, 2]], 2, 2) | |
with self.session() as sess: | |
y = sess.run(y) | |
self.assertAllEqual([[[4, 5, 6, 7], | |
[8, 9, 10, 11]], | |
[[12, 13, 14, 15], | |
[20, 21, 22, 23]]], | |
y) | |
def test_gather_topk_beams(self): | |
x = tf.reshape(tf.range(24), [2, 3, 4]) | |
x_scores = [[0, 1, 1], [1, 0, 1]] | |
y = beam_search._gather_topk_beams(x, x_scores, 2, 2) | |
with self.session() as sess: | |
y = sess.run(y) | |
self.assertAllEqual([[[4, 5, 6, 7], | |
[8, 9, 10, 11]], | |
[[12, 13, 14, 15], | |
[20, 21, 22, 23]]], | |
y) | |
if __name__ == "__main__": | |
tf.test.main() | |