Spaces:
Running
Running
# Copyright 2017 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Tests for sequence_layers.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.contrib import slim | |
import model | |
import sequence_layers | |
def fake_net(batch_size, num_features, feature_size): | |
return tf.convert_to_tensor( | |
np.random.uniform(size=(batch_size, num_features, feature_size)), | |
dtype=tf.float32) | |
def fake_labels(batch_size, seq_length, num_char_classes): | |
labels_np = tf.convert_to_tensor( | |
np.random.randint( | |
low=0, high=num_char_classes, size=(batch_size, seq_length))) | |
return slim.one_hot_encoding(labels_np, num_classes=num_char_classes) | |
def create_layer(layer_class, batch_size, seq_length, num_char_classes): | |
model_params = model.ModelParams( | |
num_char_classes=num_char_classes, | |
seq_length=seq_length, | |
num_views=1, | |
null_code=num_char_classes) | |
net = fake_net( | |
batch_size=batch_size, num_features=seq_length * 5, feature_size=6) | |
labels_one_hot = fake_labels(batch_size, seq_length, num_char_classes) | |
layer_params = sequence_layers.SequenceLayerParams( | |
num_lstm_units=10, weight_decay=0.00004, lstm_state_clip_value=10.0) | |
return layer_class(net, labels_one_hot, model_params, layer_params) | |
class SequenceLayersTest(tf.test.TestCase): | |
def test_net_slice_char_logits_with_correct_shape(self): | |
batch_size = 2 | |
seq_length = 4 | |
num_char_classes = 3 | |
layer = create_layer(sequence_layers.NetSlice, batch_size, seq_length, | |
num_char_classes) | |
char_logits = layer.create_logits() | |
self.assertEqual( | |
tf.TensorShape([batch_size, seq_length, num_char_classes]), | |
char_logits.get_shape()) | |
def test_net_slice_with_autoregression_char_logits_with_correct_shape(self): | |
batch_size = 2 | |
seq_length = 4 | |
num_char_classes = 3 | |
layer = create_layer(sequence_layers.NetSliceWithAutoregression, | |
batch_size, seq_length, num_char_classes) | |
char_logits = layer.create_logits() | |
self.assertEqual( | |
tf.TensorShape([batch_size, seq_length, num_char_classes]), | |
char_logits.get_shape()) | |
def test_attention_char_logits_with_correct_shape(self): | |
batch_size = 2 | |
seq_length = 4 | |
num_char_classes = 3 | |
layer = create_layer(sequence_layers.Attention, batch_size, seq_length, | |
num_char_classes) | |
char_logits = layer.create_logits() | |
self.assertEqual( | |
tf.TensorShape([batch_size, seq_length, num_char_classes]), | |
char_logits.get_shape()) | |
def test_attention_with_autoregression_char_logits_with_correct_shape(self): | |
batch_size = 2 | |
seq_length = 4 | |
num_char_classes = 3 | |
layer = create_layer(sequence_layers.AttentionWithAutoregression, | |
batch_size, seq_length, num_char_classes) | |
char_logits = layer.create_logits() | |
self.assertEqual( | |
tf.TensorShape([batch_size, seq_length, num_char_classes]), | |
char_logits.get_shape()) | |
if __name__ == '__main__': | |
tf.test.main() | |