# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Question answering task."""
import dataclasses
import logging

import tensorflow as tf
import tensorflow_hub as hub

from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.bert import input_pipeline
from official.nlp.configs import encoders
from official.nlp.modeling import models
from official.nlp.tasks import utils


@dataclasses.dataclass
class QuestionAnsweringConfig(cfg.TaskConfig):
"""The model config."""
# At most one of `init_checkpoint` and `hub_module_url` can be specified.
init_checkpoint: str = ''
hub_module_url: str = ''
network: encoders.TransformerEncoderConfig = (
encoders.TransformerEncoderConfig())
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
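
# A minimal construction sketch (hypothetical paths, for illustration only; in
# practice the config is usually populated from experiment flags or YAML
# rather than built inline):
#
#   config = QuestionAnsweringConfig(
#       init_checkpoint='/path/to/bert_checkpoint',
#       train_data=cfg.DataConfig(
#           input_path='/path/to/squad_train.tf_record',
#           global_batch_size=32,
#           is_training=True))
#   task = QuestionAnsweringTask(config)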


@base_task.register_task_cls(QuestionAnsweringConfig)
class QuestionAnsweringTask(base_task.Task):
"""Task object for question answering.
TODO(lehou): Add post-processing.
"""

  def __init__(self, params=cfg.TaskConfig()):
    super(QuestionAnsweringTask, self).__init__(params)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else:
self._hub_module = None

  def build_model(self):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
else:
encoder_network = encoders.instantiate_encoder_from_cfg(
self.task_config.network)
return models.BertSpanLabeler(
network=encoder_network,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=self.task_config.network.initializer_range))
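
  # Output-shape note (illustrative, not executed here): BertSpanLabeler
  # returns a pair of logits tensors, one for span starts and one for span
  # ends, each of shape [batch_size, seq_length]:
  #
  #   model = task.build_model()
  #   start_logits, end_logits = model(features)
  #
  # where `features` is the dict of 'input_word_ids', 'input_mask' and
  # 'input_type_ids' yielded by build_inputs().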

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
start_positions = labels['start_positions']
end_positions = labels['end_positions']
start_logits, end_logits = model_outputs
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
start_positions,
tf.cast(start_logits, dtype=tf.float32),
from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
end_positions,
tf.cast(end_logits, dtype=tf.float32),
from_logits=True)
loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
return loss
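
  # Worked example of the loss above (assumed toy values, illustration only):
  # with batch_size=2 and seq_length=4,
  #
  #   labels = {'start_positions': tf.constant([1, 3]),
  #             'end_positions': tf.constant([2, 3])}
  #   start_logits = end_logits = tf.zeros([2, 4])
  #
  # uniform (all-zero) logits give a per-example cross-entropy of log(4), so
  # the returned scalar is (log(4) + log(4)) / 2 = log(4) ~= 1.386.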

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for the question answering task."""
if params.input_path == 'dummy':
def dummy_data(_):
dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
x = dict(
input_word_ids=dummy_ids,
input_mask=dummy_ids,
input_type_ids=dummy_ids)
y = dict(
start_positions=tf.constant(0, dtype=tf.int32),
end_positions=tf.constant(1, dtype=tf.int32))
return (x, y)

      dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
batch_size = input_context.get_per_replica_batch_size(
params.global_batch_size) if input_context else params.global_batch_size
# TODO(chendouble): add and use nlp.data.question_answering_dataloader.
dataset = input_pipeline.create_squad_dataset(
params.input_path,
params.seq_length,
batch_size,
is_training=params.is_training,
input_pipeline_context=input_context)
return dataset
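
  # Usage sketch for the dummy branch (illustrative; `data_params` stands for
  # a data config exposing the fields read above, including a `seq_length`
  # attribute that the base cfg.DataConfig does not define, per the TODO):
  #
  #   data_params = ...  # e.g. input_path='dummy', seq_length=384
  #   dataset = task.build_inputs(data_params)
  #   features, labels = next(iter(dataset))
  #
  # In the dummy branch each element's 'input_word_ids' has shape
  # (1, seq_length); no batching by global_batch_size is applied.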

  def build_metrics(self, training=None):
del training
# TODO(lehou): a list of metrics doesn't work the same as in compile/fit.
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(
name='start_position_accuracy'),
tf.keras.metrics.SparseCategoricalAccuracy(
name='end_position_accuracy'),
]
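    # Note: these metrics measure token-level start/end position accuracy;
    # SQuAD-style EM/F1 would additionally require the post-processing noted
    # in the class TODO.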
return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    metrics = {metric.name: metric for metric in metrics}
start_logits, end_logits = model_outputs
metrics['start_position_accuracy'].update_state(
labels['start_positions'], start_logits)
metrics['end_position_accuracy'].update_state(
labels['end_positions'], end_logits)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
start_logits, end_logits = model_outputs
compiled_metrics.update_state(
y_true=labels, # labels has keys 'start_positions' and 'end_positions'.
y_pred={'start_positions': start_logits, 'end_positions': end_logits})
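    # The y_pred dict reuses the label keys ('start_positions',
    # 'end_positions') so Keras can pair each logits tensor with the label
    # entry of the same name when updating compiled metrics.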

  def initialize(self, model):
    """Loads a pretrained checkpoint (if any) and trains from step 0."""
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
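

# End-to-end lifecycle sketch (illustrative only; a real run goes through the
# official training driver, and the ellipsis stands for a user-supplied data
# config):
#
#   config = QuestionAnsweringConfig(train_data=...)
#   task = QuestionAnsweringTask(config)
#   model = task.build_model()
#   task.initialize(model)               # warm-start if init_checkpoint is set
#   dataset = task.build_inputs(config.train_data)
#   metrics = task.build_metrics(training=True)
#   for features, labels in dataset.take(10):
#     outputs = model(features, training=True)
#     loss = task.build_losses(labels, outputs)
#     task.process_metrics(metrics, labels, outputs)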