# Source path: NCTC/models/official/core/base_task.py
# (uploaded by NCTCMumbai in "Upload 2571 files", commit 0b8359d; 10.6 kB)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the base task abstraction."""
import abc
import functools
from typing import Any, Callable, Optional
import six
import tensorflow as tf
from official.modeling.hyperparams import config_definitions as cfg
from official.utils import registry
@six.add_metaclass(abc.ABCMeta)
class Task(tf.Module):
  """A single-replica view of training procedure.

  Tasks provide artifacts for training/evaluation procedures, including
  loading/iterating over Datasets, initializing the model, calculating the loss
  and customized metrics with reduction.
  """

  # Special keys in train/validate step returned logs.
  loss = "loss"

  def __init__(self, params: cfg.TaskConfig):
    self._task_config = params

  @property
  def task_config(self) -> cfg.TaskConfig:
    """The task config this task was constructed with."""
    return self._task_config

  def initialize(self, model: tf.keras.Model):
    """A callback function used as CheckpointManager's init_fn.

    This function will be called when no checkpoint is found for the model.
    If there is a checkpoint, the checkpoint will be loaded and this function
    will not be called. You can use this callback function to load a pretrained
    checkpoint, saved under a directory other than the model_dir.

    The base implementation does nothing.

    Args:
      model: The keras.Model built or used by this task.
    """
    pass

  @abc.abstractmethod
  def build_model(self) -> tf.keras.Model:
    """Creates the model architecture.

    Returns:
      A model instance.
    """

  def compile_model(self,
                    model: tf.keras.Model,
                    optimizer: tf.keras.optimizers.Optimizer,
                    loss=None,
                    train_step: Optional[Callable[..., Any]] = None,
                    validation_step: Optional[Callable[..., Any]] = None,
                    **kwargs) -> tf.keras.Model:
    """Compiles the model with objects created by the task.

    The method should not be used in any customized training implementation.

    Args:
      model: a keras.Model.
      optimizer: the keras optimizer.
      loss: a callable/list of losses.
      train_step: optional train step function defined by the task.
      validation_step: optional validation_step step function defined by the
        task.
      **kwargs: other kwargs consumed by keras.Model compile().

    Returns:
      a compiled keras.Model.

    Raises:
      ValueError: unless exactly one of `loss` and `train_step` is provided.
    """
    # Exactly one of `loss` / `train_step` must be given: both-None and
    # both-set are rejected alike.
    if bool(loss is None) == bool(train_step is None):
      raise ValueError("`loss` and `train_step` should be exclusive to "
                       "each other.")
    model.compile(optimizer=optimizer, loss=loss, **kwargs)

    # Override keras' default steps with the task-defined ones, binding the
    # model (and, for training, the optimizer attached by compile()).
    if train_step:
      model.train_step = functools.partial(
          train_step, model=model, optimizer=model.optimizer)
    if validation_step:
      model.test_step = functools.partial(validation_step, model=model)
    return model

  @abc.abstractmethod
  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a dataset or a nested structure of dataset functions.

    Dataset functions define per-host datasets with the per-replica batch size.

    Args:
      params: hyperparams to create input pipelines.
      input_context: optional distribution input pipeline context.

    Returns:
      A nested structure of per-replica input functions.
    """

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    The base implementation ignores `labels` and `model_outputs` and returns
    the sum of `aux_losses`, or a float32 zero when there are none.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    del model_outputs, labels

    # `model.losses` is an empty list for models without regularizers or
    # add_loss() terms, and tf.add_n rejects an empty list, so treat "empty"
    # the same as "absent" and fall back to a zero scalar.
    if not aux_losses:
      losses = [tf.constant(0.0, dtype=tf.float32)]
    else:
      losses = aux_losses
    total_loss = tf.add_n(losses)
    return total_loss

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation.

    The base implementation returns an empty list.
    """
    del training
    return []

  def process_metrics(self, metrics, labels, model_outputs):
    """Process and update metrics. Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects.
        The return of function self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors.
        For example, output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Process and update compiled_metrics. Called when using compile/fit API.

    Args:
      compiled_metrics: the compiled metrics (model.compiled_metrics).
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors.
        For example, output of the keras model built by self.build_model.
    """
    compiled_metrics.update_state(labels, model_outputs)

  def train_step(self,
                 inputs,
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors, or a `(features, labels)` tuple.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    # A 2-tuple is split into (features, labels); anything else is used as
    # both features and labels.
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs

    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

      # For mixed precision, when a LossScaleOptimizer is used, the loss is
      # scaled to avoid numeric underflow.
      # NOTE(review): newer TF releases expose this class as
      # tf.keras.mixed_precision.LossScaleOptimizer; confirm against the TF
      # version in use.
      if isinstance(optimizer,
                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)

    # Undo the loss scaling on the gradients before applying them.
    if isinstance(optimizer,
                  tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    # Logs report the unscaled per-replica loss.
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors, or a `(features, labels)` tuple.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    # Same input convention as train_step.
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, inputs, model: tf.keras.Model):
    """Performs the forward step with `training=False`."""
    return model(inputs, training=False)

  def aggregate_logs(self, state, step_logs):
    """Optional aggregation over logs returned from a validation step.

    The base implementation does nothing and returns None.
    """
    pass

  def reduce_aggregated_logs(self, aggregated_logs):
    """Optional reduce of aggregated logs over validation steps.

    The base implementation returns an empty dict.
    """
    return {}
_REGISTERED_TASK_CLS = {}
# TODO(b/158268740): Move these outside the base class file.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
  """Returns a decorator registering a Task factory for a TaskConfig subclass.

  Typical usage:

  ```
  @dataclasses.dataclass
  class MyTaskConfig(TaskConfig):
    # Add fields here.
    pass

  @register_task_cls(MyTaskConfig)
  class MyTask(Task):
    # Inherits def __init__(self, task_config).
    pass

  my_task_config = MyTaskConfig()
  my_task = get_task(my_task_config)  # Returns MyTask(my_task_config).
  ```

  Besides a class itself, any other callable that builds a Task from a
  TaskConfig may be decorated with the result of this function, provided each
  config class is registered at most once.

  Args:
    task_config_cls: a subclass of TaskConfig (*not* an instance of
      TaskConfig). Each task_config_cls can only be used for a single
      registration.

  Returns:
    A callable for use as a class decorator that registers the decorated
    class for creation from an instance of task_config_cls.
  """
  decorator = registry.register(_REGISTERED_TASK_CLS, task_config_cls)
  return decorator
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
  """Returns the Task factory registered for `task_config_cls`.

  The lookup is delegated to `registry.lookup` on the module-level registry.
  NOTE(review): the behavior for an unregistered config (presumably an error)
  is defined by `registry.lookup`; confirm there.

  Args:
    task_config_cls: the TaskConfig subclass used at registration time.

  Returns:
    The class or callable previously registered via register_task_cls().
  """
  return registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)