NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
6.94 kB
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for data_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import tensorflow as tf
from data import data_utils
# Short alias: the tests below refer to the module under test as `data.<name>`.
data = data_utils
class SequenceWrapperTest(tf.test.TestCase):
  """Tests for `data_utils.SequenceWrapper` and the timesteps it hands out."""

  def testDefaultTimesteps(self):
    """A freshly added timestep reports zero token, label and weight."""
    seq = data.SequenceWrapper()
    t1 = seq.add_timestep()
    # Second timestep: only its effect on len(seq) is checked.
    seq.add_timestep()
    self.assertEqual(len(seq), 2)
    self.assertEqual(t1.weight, 0.0)
    self.assertEqual(t1.label, 0)
    self.assertEqual(t1.token, 0)

  def testSettersAndGetters(self):
    """Values written through set_* are read back through the attributes."""
    ts = data.SequenceWrapper().add_timestep()
    ts.set_token(3)
    ts.set_label(4)
    ts.set_weight(2.0)
    self.assertEqual(ts.token, 3)
    self.assertEqual(ts.label, 4)
    self.assertEqual(ts.weight, 2.0)

  def testTimestepIteration(self):
    """Iterating a sequence yields its timesteps in insertion order."""
    seq = data.SequenceWrapper()
    seq.add_timestep().set_token(0)
    seq.add_timestep().set_token(1)
    seq.add_timestep().set_token(2)
    for i, ts in enumerate(seq):
      self.assertEqual(ts.token, i)

  def testFillsSequenceExampleCorrectly(self):
    """Timestep values land in the matching feature lists of `seq.seq`.

    NOTE(review): `seq.seq` is presumably a tf.train.SequenceExample (per the
    test name) -- confirm against data_utils.
    """
    seq = data.SequenceWrapper()
    # The setters are chained, so each one must return the timestep itself.
    seq.add_timestep().set_token(1).set_label(2).set_weight(3.0)
    seq.add_timestep().set_token(10).set_label(20).set_weight(30.0)

    seq_ex = seq.seq
    fl = seq_ex.feature_lists.feature_list
    fl_token = fl[data.SequenceWrapper.F_TOKEN_ID].feature
    fl_label = fl[data.SequenceWrapper.F_LABEL].feature
    fl_weight = fl[data.SequenceWrapper.F_WEIGHT].feature

    # One feature per timestep in every feature list.
    for feature in (fl_token, fl_label, fl_weight):
      self.assertEqual(len(feature), 2)

    # Tokens and labels are read from int64 lists, weights from a float list.
    self.assertAllEqual([f.int64_list.value[0] for f in fl_token], [1, 10])
    self.assertAllEqual([f.int64_list.value[0] for f in fl_label], [2, 20])
    self.assertAllEqual([f.float_list.value[0] for f in fl_weight], [3.0, 30.0])
class DataUtilsTest(tf.test.TestCase):
  """Tests for the tokenizer and the sequence builders in `data_utils`."""

  def testSplitByPunct(self):
    """Punctuation and whitespace (space, \\n, \\t, \\r) all act as separators."""
    output = data.split_by_punct(
        'hello! world, i\'ve been\nwaiting\tfor\ryou for.a long time')
    expected = [
        'hello', 'world', 'i', 've', 'been', 'waiting', 'for', 'you', 'for',
        'a', 'long', 'time'
    ]
    self.assertListEqual(output, expected)

  def _buildDummySequence(self):
    """Returns a 10-timestep SequenceWrapper whose tokens are 0, 1, ..., 9.

    The tests below treat the last token (len(seq) - 1) as the EOS id.
    """
    seq = data.SequenceWrapper()
    for i in range(10):
      seq.add_timestep().set_token(i)
    return seq

  def testBuildLMSeq(self):
    """Each label is the next token with weight 1.0, except the final step."""
    seq = self._buildDummySequence()
    lm_seq = data.build_lm_sequence(seq)
    last = len(lm_seq) - 1
    for i, ts in enumerate(lm_seq):
      self.assertEqual(ts.token, i)
      if i == last:
        # For end of sequence, the token and label should be the same, and
        # the weight should be 0.0.
        self.assertEqual(ts.label, i)
        self.assertEqual(ts.weight, 0.0)
      else:
        self.assertEqual(ts.label, i + 1)
        self.assertEqual(ts.weight, 1.0)

  def testBuildSAESeq(self):
    """Checks tokens, weights and labels of the sequence-autoencoder input."""
    seq = self._buildDummySequence()
    sa_seq = data.build_seq_ae_sequence(seq)
    seq_len = len(seq)
    # Number of leading timesteps that carry no label and no weight.
    prefix_len = seq_len - 1
    self.assertEqual(len(sa_seq), seq_len * 2 - 1)

    # Tokens should be the sequence repeated twice, minus the EOS token at the
    # end. The modulus uses len(seq) so it tracks the dummy sequence's length.
    for i, ts in enumerate(sa_seq):
      self.assertEqual(ts.token, seq[i % seq_len].token)

    # The first len-1 timesteps have weight 0.0 and label 0.
    for i in range(prefix_len):
      self.assertEqual(sa_seq[i].weight, 0.0)
      self.assertEqual(sa_seq[i].label, 0)

    # The remaining len timesteps have weight 1.0 and are labeled with the
    # original sequence's tokens, in order.
    for i in range(prefix_len, len(sa_seq)):
      self.assertEqual(sa_seq[i].weight, 1.0)
      self.assertEqual(sa_seq[i].label, seq[i - prefix_len].token)

  def testBuildLabelSeq(self):
    """Only the final (EOS) timestep carries the label and a non-zero weight."""
    seq = self._buildDummySequence()
    eos_id = len(seq) - 1
    # NOTE(review): the positional `True` is presumably the class label (the
    # final label is asserted to be 1 below) -- confirm against data_utils.
    label_seq = data.build_labeled_sequence(seq, True)
    for i, ts in enumerate(label_seq[:-1]):
      self.assertEqual(ts.token, i)
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    final_timestep = label_seq[-1]
    self.assertEqual(final_timestep.token, eos_id)
    self.assertEqual(final_timestep.label, 1)
    self.assertEqual(final_timestep.weight, 1.0)

  def testBuildBidirLabelSeq(self):
    """Labeling a bidirectional sequence keeps both token streams intact."""
    seq = self._buildDummySequence()
    reverse_seq = data.build_reverse_sequence(seq)
    bidir_seq = data.build_bidirectional_seq(seq, reverse_seq)
    label_seq = data.build_labeled_sequence(bidir_seq, True)
    eos_id = len(seq) - 1

    # Before EOS, the forward token i counts up while the reverse token j
    # counts down from len(seq) - 2.
    for (i, ts), j in zip(
        enumerate(label_seq[:-1]), reversed(range(len(seq) - 1))):
      self.assertAllEqual(ts.tokens, [i, j])
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    final_timestep = label_seq[-1]
    self.assertAllEqual(final_timestep.tokens, [eos_id, eos_id])
    self.assertEqual(final_timestep.label, 1)
    self.assertEqual(final_timestep.weight, 1.0)

  def testReverseSeq(self):
    """All tokens but EOS are reversed; EOS stays at the end."""
    seq = self._buildDummySequence()
    reverse_seq = data.build_reverse_sequence(seq)
    # Reading the non-EOS part backwards must recover 0, 1, 2, ...
    for i, ts in enumerate(reversed(reverse_seq[:-1])):
      self.assertEqual(ts.token, i)
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    final_timestep = reverse_seq[-1]
    eos_id = len(seq) - 1
    self.assertEqual(final_timestep.token, eos_id)
    self.assertEqual(final_timestep.label, 0)
    self.assertEqual(final_timestep.weight, 0.0)

  def testBidirSeq(self):
    """Each timestep pairs a forward token with the matching reverse token."""
    seq = self._buildDummySequence()
    reverse_seq = data.build_reverse_sequence(seq)
    bidir_seq = data.build_bidirectional_seq(seq, reverse_seq)
    eos_id = len(seq) - 1

    # Before EOS, the forward token i counts up while the reverse token j
    # counts down from len(seq) - 2.
    for (i, ts), j in zip(
        enumerate(bidir_seq[:-1]), reversed(range(len(seq) - 1))):
      self.assertAllEqual(ts.tokens, [i, j])
      self.assertEqual(ts.label, 0)
      self.assertEqual(ts.weight, 0.0)

    final_timestep = bidir_seq[-1]
    self.assertAllEqual(final_timestep.tokens, [eos_id, eos_id])
    self.assertEqual(final_timestep.label, 0)
    self.assertEqual(final_timestep.weight, 0.0)

  def testLabelGain(self):
    """With label_gain=True every step is labeled and weights ramp 0 -> 1."""
    seq = self._buildDummySequence()
    label_seq = data.build_labeled_sequence(seq, True, label_gain=True)
    last = len(seq) - 1
    for i, ts in enumerate(label_seq):
      self.assertEqual(ts.token, i)
      self.assertEqual(ts.label, 1)
      # Expected weight is the timestep's fractional position in the sequence.
      self.assertNear(ts.weight, float(i) / last, 1e-3)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()