# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nlp.nhnet.models."""
import os
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
# pylint: enable=g-direct-tensorflow-import
from official.nlp.nhnet import configs
from official.nlp.nhnet import models
from official.nlp.nhnet import utils


def all_strategy_combinations():
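  """Returns distribution strategies to test against, all in eager mode."""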
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
],
mode="eager",
)


def distribution_forward_path(strategy,
                              model,
                              inputs,
                              batch_size,
                              mode="train"):
  """Runs the model forward over `inputs`, batch by batch, under `strategy`."""
  dataset = tf.data.Dataset.from_tensor_slices(inputs)
  dataset = dataset.batch(batch_size)
  dataset = strategy.experimental_distribute_dataset(dataset)

  @tf.function
  def test_step(inputs):
    """Runs a forward pass on each replica and gathers the local results."""

    def _test_step_fn(inputs):
      """Replicated forward pass."""
      return model(inputs, mode=mode, training=False)

    outputs = strategy.run(_test_step_fn, args=(inputs,))
    return tf.nest.map_structure(strategy.experimental_local_results, outputs)

  return [test_step(inputs) for inputs in dataset]


def process_decoded_ids(predictions, end_token_id):
  """Truncates each decoded id sequence at the first `end_token_id`."""
if isinstance(predictions, tf.Tensor):
predictions = predictions.numpy()
flatten_ids = predictions.reshape((-1, predictions.shape[-1]))
results = []
for ids in flatten_ids:
ids = list(ids)
if end_token_id in ids:
ids = ids[:ids.index(end_token_id)]
results.append(ids)
return results


class Bert2BertTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
super(Bert2BertTest, self).setUp()
self._config = utils.get_test_params()

  def test_model_creation(self):
model = models.create_bert2bert_model(params=self._config)
fake_ids = np.zeros((2, 10), dtype=np.int32)
fake_inputs = {
"input_ids": fake_ids,
"input_mask": fake_ids,
"segment_ids": fake_ids,
"target_ids": fake_ids,
}
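    # A single forward pass builds the model for these input shapes.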
model(fake_inputs)

  @combinations.generate(all_strategy_combinations())
def test_bert2bert_train_forward(self, distribution):
seq_length = 10
# Defines the model inside distribution strategy scope.
with distribution.scope():
# Forward path.
batch_size = 2
batches = 4
fake_ids = np.zeros((batch_size * batches, seq_length), dtype=np.int32)
fake_inputs = {
"input_ids": fake_ids,
"input_mask": fake_ids,
"segment_ids": fake_ids,
"target_ids": fake_ids,
}
model = models.create_bert2bert_model(params=self._config)
results = distribution_forward_path(distribution, model, fake_inputs,
batch_size)
logging.info("Forward path results: %s", str(results))
self.assertLen(results, batches)

  def test_bert2bert_decoding(self):
seq_length = 10
self._config.override(
{
"beam_size": 3,
"len_title": seq_length,
"alpha": 0.6,
},
is_strict=False)
batch_size = 2
fake_ids = np.zeros((batch_size, seq_length), dtype=np.int32)
fake_inputs = {
"input_ids": fake_ids,
"input_mask": fake_ids,
"segment_ids": fake_ids,
}
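    # Reference run: decode without the decoding cache.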
self._config.override({
"padded_decode": False,
"use_cache": False,
},
is_strict=False)
model = models.create_bert2bert_model(params=self._config)
ckpt = tf.train.Checkpoint(model=model)
# Initializes variables from checkpoint to keep outputs deterministic.
init_checkpoint = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
ckpt.restore(init_checkpoint).assert_existing_objects_matched()
top_ids, scores = model(fake_inputs, mode="predict")
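    # Decoding with the cache enabled should produce identical outputs.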
self._config.override({
"padded_decode": False,
"use_cache": True,
},
is_strict=False)
model = models.create_bert2bert_model(params=self._config)
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(init_checkpoint).assert_existing_objects_matched()
cached_top_ids, cached_scores = model(fake_inputs, mode="predict")
self.assertEqual(
process_decoded_ids(top_ids, self._config.end_token_id),
process_decoded_ids(cached_top_ids, self._config.end_token_id))
self.assertAllClose(scores, cached_scores)
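    # Padded decode, as used on TPU, should also match the reference.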
self._config.override({
"padded_decode": True,
"use_cache": True,
},
is_strict=False)
model = models.create_bert2bert_model(params=self._config)
ckpt = tf.train.Checkpoint(model=model)
ckpt.restore(init_checkpoint).assert_existing_objects_matched()
padded_top_ids, padded_scores = model(fake_inputs, mode="predict")
self.assertEqual(
process_decoded_ids(top_ids, self._config.end_token_id),
process_decoded_ids(padded_top_ids, self._config.end_token_id))
self.assertAllClose(scores, padded_scores)

  @combinations.generate(all_strategy_combinations())
def test_bert2bert_eval(self, distribution):
seq_length = 10
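    # TPU decoding requires statically shaped loops, i.e. padded decode.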
padded_decode = isinstance(distribution,
tf.distribute.experimental.TPUStrategy)
self._config.override(
{
"beam_size": 3,
"len_title": seq_length,
"alpha": 0.6,
"padded_decode": padded_decode,
},
is_strict=False)
# Defines the model inside distribution strategy scope.
with distribution.scope():
# Forward path.
batch_size = 2
batches = 4
fake_ids = np.zeros((batch_size * batches, seq_length), dtype=np.int32)
fake_inputs = {
"input_ids": fake_ids,
"input_mask": fake_ids,
"segment_ids": fake_ids,
}
model = models.create_bert2bert_model(params=self._config)
results = distribution_forward_path(
distribution, model, fake_inputs, batch_size, mode="predict")
self.assertLen(results, batches)
results = distribution_forward_path(
distribution, model, fake_inputs, batch_size, mode="eval")
self.assertLen(results, batches)


class NHNetTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
super(NHNetTest, self).setUp()
self._nhnet_config = configs.NHNetConfig()
self._nhnet_config.override(utils.get_test_params().as_dict())
self._bert2bert_config = configs.BERT2BERTConfig()
self._bert2bert_config.override(utils.get_test_params().as_dict())

  def _count_params(self, layer, trainable_only=True):
"""Returns the count of all model parameters, or just trainable ones."""
if not trainable_only:
return layer.count_params()
else:
return int(
np.sum([
tf.keras.backend.count_params(p) for p in layer.trainable_weights
]))

  def test_create_nhnet_layers(self):
single_doc_bert, single_doc_decoder = models.get_bert2bert_layers(
self._bert2bert_config)
multi_doc_bert, multi_doc_decoder = models.get_nhnet_layers(
self._nhnet_config)
    # Expects the multi-doc encoder/decoder to have the same number of
    # parameters as the single-doc encoder/decoder.
self.assertEqual(
self._count_params(multi_doc_bert), self._count_params(single_doc_bert))
self.assertEqual(
self._count_params(multi_doc_decoder),
self._count_params(single_doc_decoder))

  def test_checkpoint_restore(self):
bert2bert_model = models.create_bert2bert_model(self._bert2bert_config)
ckpt = tf.train.Checkpoint(model=bert2bert_model)
init_checkpoint = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
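    # The NHNet encoder/decoder weights should be initialized from the
    # Bert2Bert checkpoint, so the shared weights must match exactly.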
nhnet_model = models.create_nhnet_model(
params=self._nhnet_config, init_checkpoint=init_checkpoint)
source_weights = (
bert2bert_model.bert_layer.trainable_weights +
bert2bert_model.decoder_layer.trainable_weights)
dest_weights = (
nhnet_model.bert_layer.trainable_weights +
nhnet_model.decoder_layer.trainable_weights)
for source_weight, dest_weight in zip(source_weights, dest_weights):
self.assertAllClose(source_weight.numpy(), dest_weight.numpy())

  @combinations.generate(all_strategy_combinations())
def test_nhnet_train_forward(self, distribution):
seq_length = 10
# Defines the model inside distribution strategy scope.
with distribution.scope():
# Forward path.
batch_size = 2
num_docs = 2
batches = 4
fake_ids = np.zeros((batch_size * batches, num_docs, seq_length),
dtype=np.int32)
      fake_inputs = {
          "input_ids": fake_ids,
          "input_mask": fake_ids,
          "segment_ids": fake_ids,
          "target_ids": np.zeros((batch_size * batches, seq_length * 2),
                                 dtype=np.int32),
      }
model = models.create_nhnet_model(params=self._nhnet_config)
results = distribution_forward_path(distribution, model, fake_inputs,
batch_size)
logging.info("Forward path results: %s", str(results))
self.assertLen(results, batches)

  @combinations.generate(all_strategy_combinations())
def test_nhnet_eval(self, distribution):
seq_length = 10
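    # As in the Bert2Bert eval test, TPU requires padded decoding.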
padded_decode = isinstance(distribution,
tf.distribute.experimental.TPUStrategy)
self._nhnet_config.override(
{
"beam_size": 4,
"len_title": seq_length,
"alpha": 0.6,
"multi_channel_cross_attention": True,
"padded_decode": padded_decode,
},
is_strict=False)
# Defines the model inside distribution strategy scope.
with distribution.scope():
# Forward path.
batch_size = 2
num_docs = 2
batches = 4
fake_ids = np.zeros((batch_size * batches, num_docs, seq_length),
dtype=np.int32)
fake_inputs = {
"input_ids": fake_ids,
"input_mask": fake_ids,
"segment_ids": fake_ids,
"target_ids": np.zeros((batch_size * batches, 5), dtype=np.int32),
}
model = models.create_nhnet_model(params=self._nhnet_config)
results = distribution_forward_path(
distribution, model, fake_inputs, batch_size, mode="predict")
self.assertLen(results, batches)
results = distribution_forward_path(
distribution, model, fake_inputs, batch_size, mode="eval")
self.assertLen(results, batches)


if __name__ == "__main__":
tf.test.main()