# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the base task abstraction.""" | |
import abc | |
import functools | |
from typing import Any, Callable, Optional | |
import six | |
import tensorflow as tf | |
from official.modeling.hyperparams import config_definitions as cfg | |
from official.utils import registry | |


@six.add_metaclass(abc.ABCMeta)
class Task(tf.Module):
  """A single-replica view of the training procedure.

  Tasks provide artifacts for training/evaluation procedures, including
  loading/iterating over Datasets, initializing the model, calculating the
  loss, and computing customized metrics with reduction.
  """

  # Special keys in train/validate step returned logs.
  loss = "loss"

  def __init__(self, params: cfg.TaskConfig):
    super().__init__()
    self._task_config = params

  @property
  def task_config(self) -> cfg.TaskConfig:
    return self._task_config

  def initialize(self, model: tf.keras.Model):
    """A callback function used as the CheckpointManager's `init_fn`.

    This function is called when no checkpoint is found for the model. If a
    checkpoint exists, it is loaded instead and this function is not called.
    You can use this callback to load a pretrained checkpoint saved under a
    directory other than the model_dir.

    Args:
      model: The keras.Model built or used by this task.
    """
    pass
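
  # A minimal sketch of overriding `initialize` in a subclass to warm-start
  # from a pretrained checkpoint. `pretrained_ckpt_dir` is a hypothetical
  # config field, not part of the base TaskConfig:
  #
  #   def initialize(self, model: tf.keras.Model):
  #     ckpt = tf.train.Checkpoint(model=model)
  #     ckpt.restore(
  #         tf.train.latest_checkpoint(self.task_config.pretrained_ckpt_dir)
  #     ).expect_partial()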

  @abc.abstractmethod
  def build_model(self) -> tf.keras.Model:
    """Creates the model architecture.

    Returns:
      A model instance.
    """

  def compile_model(self,
                    model: tf.keras.Model,
                    optimizer: tf.keras.optimizers.Optimizer,
                    loss=None,
                    train_step: Optional[Callable[..., Any]] = None,
                    validation_step: Optional[Callable[..., Any]] = None,
                    **kwargs) -> tf.keras.Model:
    """Compiles the model with objects created by the task.

    This method should not be used in any customized training implementation.

    Args:
      model: a keras.Model.
      optimizer: the keras optimizer.
      loss: a callable/list of losses.
      train_step: optional train step function defined by the task.
      validation_step: optional validation step function defined by the task.
      **kwargs: other kwargs consumed by keras.Model compile().

    Returns:
      a compiled keras.Model.
    """
    if bool(loss is None) == bool(train_step is None):
      raise ValueError("Exactly one of `loss` and `train_step` must be "
                       "provided; they are mutually exclusive.")
    model.compile(optimizer=optimizer, loss=loss, **kwargs)

    if train_step:
      model.train_step = functools.partial(
          train_step, model=model, optimizer=model.optimizer)
    if validation_step:
      model.test_step = functools.partial(validation_step, model=model)
    return model
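
  # A minimal usage sketch, assuming `task` is a concrete Task subclass
  # instance and the task config defines a `train_data` field. Passing
  # `train_step` (and no `loss`) makes fit() run the task-defined steps:
  #
  #   model = task.build_model()
  #   model = task.compile_model(
  #       model,
  #       optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
  #       train_step=task.train_step,
  #       validation_step=task.validation_step,
  #       metrics=task.build_metrics())
  #   model.fit(task.build_inputs(task.task_config.train_data))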

  @abc.abstractmethod
  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a dataset or a nested structure of dataset functions.

    Dataset functions define per-host datasets with the per-replica batch size.

    Args:
      params: hyperparams to create input pipelines.
      input_context: optional distribution input pipeline context.

    Returns:
      A nested structure of per-replica input functions.
    """

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    del model_outputs, labels

    if aux_losses is None:
      losses = [tf.constant(0.0, dtype=tf.float32)]
    else:
      losses = aux_losses
    total_loss = tf.add_n(losses)
    return total_loss
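
  # A minimal sketch of a subclass override that adds a task loss on top of
  # the model's auxiliary (e.g. regularization) losses:
  #
  #   def build_losses(self, labels, model_outputs, aux_losses=None):
  #     loss = tf.keras.losses.sparse_categorical_crossentropy(
  #         labels, model_outputs, from_logits=True)
  #     loss = tf.reduce_mean(loss)
  #     if aux_losses:
  #       loss += tf.add_n(aux_losses)
  #     return loss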

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation."""
    del training
    return []
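
  # A minimal sketch of a subclass override returning Keras metrics; the
  # returned objects are updated by `process_metrics` below:
  #
  #   def build_metrics(self, training=True):
  #     del training
  #     return [tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]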

  def process_metrics(self, metrics, labels, model_outputs):
    """Processes and updates metrics. Called in custom training loops.

    Args:
      metrics: a nested structure of metrics objects. The return value of
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Processes and updates compiled_metrics. Called in compile/fit API.

    Args:
      compiled_metrics: the compiled metrics (model.compiled_metrics).
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    compiled_metrics.update_state(labels, model_outputs)

  def train_step(self,
                 inputs,
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

      # For mixed precision, when a LossScaleOptimizer is used, the loss is
      # scaled to avoid numeric underflow.
      if isinstance(optimizer,
                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    if isinstance(optimizer,
                  tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs
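
  # A minimal sketch of driving `train_step` from a custom training loop under
  # a distribution strategy; `task`, `model`, `optimizer`, `metrics`, and
  # `train_iter` are assumed to be created elsewhere:
  #
  #   strategy = tf.distribute.get_strategy()
  #
  #   @tf.function
  #   def train_step_fn(iterator):
  #     return strategy.run(
  #         task.train_step, args=(next(iterator), model, optimizer, metrics))
  #
  #   logs = train_step_fn(train_iter)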

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, inputs, model: tf.keras.Model):
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(self, state, step_logs):
    """Optional aggregation over logs returned from a validation step."""
    pass

  def reduce_aggregated_logs(self, aggregated_logs):
    """Optional reduce of aggregated logs over validation steps."""
    return {}


_REGISTERED_TASK_CLS = {}


# TODO(b/158268740): Move these outside the base class file.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
  """Decorates a factory of Tasks for lookup by a subclass of TaskConfig.

  This decorator supports registration of tasks as follows:

  ```
  @dataclasses.dataclass
  class MyTaskConfig(TaskConfig):
    # Add fields here.
    pass

  @register_task_cls(MyTaskConfig)
  class MyTask(Task):
    # Inherits def __init__(self, task_config).
    pass

  my_task_config = MyTaskConfig()
  my_task = get_task(my_task_config)  # Returns MyTask(my_task_config).
  ```

  Besides a class itself, other callables that create a Task from a TaskConfig
  can be decorated by the result of this function, as long as there is at most
  one registration for each config class.

  Args:
    task_config_cls: a subclass of TaskConfig (*not* an instance of
      TaskConfig). Each task_config_cls can only be used for a single
      registration.

  Returns:
    A callable for use as a class decorator that registers the decorated class
    for creation from an instance of task_config_cls.
  """
  return registry.register(_REGISTERED_TASK_CLS, task_config_cls)


# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
  task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
  return task_cls
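
# A minimal sketch of looking up and constructing a registered task; this
# mirrors the `MyTaskConfig`/`MyTask` example in the register_task_cls
# docstring:
#
#   my_task_config = MyTaskConfig()
#   task_cls = get_task_cls(type(my_task_config))
#   my_task = task_cls(my_task_config)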