|
"""Callbacks: utilities called at certain points during model training. |
|
|
|
# Adapted from |
|
|
|
- https://github.com/keras-team/keras |
|
- https://github.com/bstriner/keras-tqdm/blob/master/keras_tqdm/tqdm_callback.py |
|
|
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
import csv |
|
import six |
|
|
|
import numpy as np |
|
import time |
|
import json |
|
import warnings
from sys import stderr
|
from tqdm import tqdm |
|
|
|
from collections import deque |
|
from collections import OrderedDict |
|
try:
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback
    from collections import Iterable
|
|
|
try: |
|
import requests |
|
except ImportError: |
|
requests = None |
|
|
|
|
|
class CallbackList(object): |
|
"""Container abstracting a list of callbacks. |
|
|
|
# Arguments |
|
callbacks: List of `Callback` instances. |
|
queue_length: Queue length for keeping |
|
running statistics over callback execution time. |
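
    # Example

    A minimal sketch of how a training loop might drive the container
    (`model` and `epochs` are placeholders):

    ```python
    callbacks = CallbackList([BaseLogger(), History()])
    callbacks.set_model(model)
    callbacks.set_params({'metrics': ['loss'], 'epochs': epochs})
    callbacks.on_train_begin()
    for epoch in range(epochs):
        callbacks.on_epoch_begin(epoch)
        # ... run the batches, calling on_batch_begin()/on_batch_end() ...
        callbacks.on_epoch_end(epoch, logs={'loss': 0.0})
    callbacks.on_train_end()
    ```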
|
""" |
|
|
|
def __init__(self, callbacks=None, queue_length=10): |
|
callbacks = callbacks or [] |
|
self.callbacks = [c for c in callbacks] |
|
self.queue_length = queue_length |
|
|
|
def append(self, callback): |
|
self.callbacks.append(callback) |
|
|
|
def set_params(self, params): |
|
for callback in self.callbacks: |
|
callback.set_params(params) |
|
|
|
def set_model(self, model): |
|
for callback in self.callbacks: |
|
callback.set_model(model) |
|
|
|
def on_epoch_begin(self, epoch, logs=None): |
|
"""Called at the start of an epoch. |
|
|
|
# Arguments |
|
epoch: integer, index of epoch. |
|
logs: dictionary of logs. |
|
""" |
|
logs = logs or {} |
|
for callback in self.callbacks: |
|
callback.on_epoch_begin(epoch, logs) |
|
self._delta_t_batch = 0. |
|
self._delta_ts_batch_begin = deque([], maxlen=self.queue_length) |
|
self._delta_ts_batch_end = deque([], maxlen=self.queue_length) |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
"""Called at the end of an epoch. |
|
|
|
# Arguments |
|
epoch: integer, index of epoch. |
|
logs: dictionary of logs. |
|
""" |
|
logs = logs or {} |
|
for callback in self.callbacks: |
|
callback.on_epoch_end(epoch, logs) |
|
|
|
def on_batch_begin(self, batch, logs=None): |
|
"""Called right before processing a batch. |
|
|
|
# Arguments |
|
batch: integer, index of batch within the current epoch. |
|
logs: dictionary of logs. |
|
""" |
|
logs = logs or {} |
|
t_before_callbacks = time.time() |
|
for callback in self.callbacks: |
|
callback.on_batch_begin(batch, logs) |
|
self._delta_ts_batch_begin.append(time.time() - t_before_callbacks) |
|
delta_t_median = np.median(self._delta_ts_batch_begin) |
|
if (self._delta_t_batch > 0. and |
|
delta_t_median > 0.95 * self._delta_t_batch and |
|
delta_t_median > 0.1): |
|
warnings.warn('Method on_batch_begin() is slow compared ' |
|
'to the batch update (%f). Check your callbacks.' |
|
% delta_t_median) |
|
self._t_enter_batch = time.time() |
|
|
|
def on_batch_end(self, batch, logs=None): |
|
"""Called at the end of a batch. |
|
|
|
# Arguments |
|
batch: integer, index of batch within the current epoch. |
|
logs: dictionary of logs. |
|
""" |
|
logs = logs or {} |
|
if not hasattr(self, '_t_enter_batch'): |
|
self._t_enter_batch = time.time() |
|
self._delta_t_batch = time.time() - self._t_enter_batch |
|
t_before_callbacks = time.time() |
|
for callback in self.callbacks: |
|
callback.on_batch_end(batch, logs) |
|
self._delta_ts_batch_end.append(time.time() - t_before_callbacks) |
|
delta_t_median = np.median(self._delta_ts_batch_end) |
|
if (self._delta_t_batch > 0. and |
|
(delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)): |
|
warnings.warn('Method on_batch_end() is slow compared ' |
|
'to the batch update (%f). Check your callbacks.' |
|
% delta_t_median) |
|
|
|
def on_train_begin(self, logs=None): |
|
"""Called at the beginning of training. |
|
|
|
# Arguments |
|
logs: dictionary of logs. |
|
""" |
|
logs = logs or {} |
|
for callback in self.callbacks: |
|
callback.on_train_begin(logs) |
|
|
|
def on_train_end(self, logs=None): |
|
"""Called at the end of training. |
|
|
|
# Arguments |
|
logs: dictionary of logs. |
|
""" |
|
logs = logs or {} |
|
for callback in self.callbacks: |
|
callback.on_train_end(logs) |
|
|
|
def __iter__(self): |
|
return iter(self.callbacks) |
|
|
|
|
|
class Callback(object): |
|
"""Abstract base class used to build new callbacks. |
|
|
|
# Properties |
|
params: dict. Training parameters |
|
(eg. verbosity, batch size, number of epochs...). |
|
model: instance of `keras.models.Model`. |
|
Reference of the model being trained. |
|
|
|
The `logs` dictionary that callback methods |
|
take as argument will contain keys for quantities relevant to |
|
the current batch or epoch. |
|
|
|
Currently, the `.fit()` method of the `Sequential` model class |
|
will include the following quantities in the `logs` that |
|
it passes to its callbacks: |
|
|
|
on_epoch_end: logs include `acc` and `loss`, and |
|
optionally include `val_loss` |
|
(if validation is enabled in `fit`), and `val_acc` |
|
(if validation and accuracy monitoring are enabled). |
|
on_batch_begin: logs include `size`, |
|
the number of samples in the current batch. |
|
on_batch_end: logs include `loss`, and optionally `acc` |
|
(if accuracy monitoring is enabled). |
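
    # Example

    A minimal sketch of a custom callback that records per-batch losses,
    relying only on the `logs` keys described above:

    ```python
    class LossHistory(Callback):
        def on_train_begin(self, logs=None):
            self.losses = []

        def on_batch_end(self, batch, logs=None):
            logs = logs or {}
            self.losses.append(logs.get('loss'))
    ```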
|
""" |
|
|
|
def __init__(self): |
|
self.validation_data = None |
|
self.model = None |
|
|
|
def set_params(self, params): |
|
self.params = params |
|
|
|
def set_model(self, model): |
|
self.model = model |
|
|
|
def on_epoch_begin(self, epoch, logs=None): |
|
pass |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
pass |
|
|
|
def on_batch_begin(self, batch, logs=None): |
|
pass |
|
|
|
def on_batch_end(self, batch, logs=None): |
|
pass |
|
|
|
def on_train_begin(self, logs=None): |
|
pass |
|
|
|
def on_train_end(self, logs=None): |
|
pass |
|
|
|
|
|
class BaseLogger(Callback): |
|
"""Callback that accumulates epoch averages of metrics. |
|
|
|
This callback is automatically applied to every Keras model. |
|
""" |
|
|
|
def on_epoch_begin(self, epoch, logs=None): |
|
self.seen = 0 |
|
self.totals = {} |
|
|
|
def on_batch_end(self, batch, logs=None): |
|
logs = logs or {} |
|
batch_size = logs.get('size', 0) |
|
self.seen += batch_size |
|
|
|
for k, v in logs.items(): |
|
if k in self.totals: |
|
self.totals[k] += v * batch_size |
|
else: |
|
self.totals[k] = v * batch_size |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
if logs is not None: |
|
for k in self.params['metrics']: |
|
if k in self.totals: |
|
|
|
logs[k] = self.totals[k] / self.seen |
|
|
|
|
|
class TerminateOnNaN(Callback): |
|
"""Callback that terminates training when a NaN loss is encountered. |
|
""" |
|
|
|
def __init__(self): |
|
super(TerminateOnNaN, self).__init__() |
|
|
|
def on_batch_end(self, batch, logs=None): |
|
logs = logs or {} |
|
loss = logs.get('loss') |
|
if loss is not None: |
|
if np.isnan(loss) or np.isinf(loss): |
|
print('Batch %d: Invalid loss, terminating training' % (batch)) |
|
self.model.stop_training = True |
|
|
|
|
|
class History(Callback): |
|
"""Callback that records events into a `History` object. |
|
|
|
This callback is automatically applied to |
|
every Keras model. The `History` object |
|
gets returned by the `fit` method of models. |
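
    # Example

    A minimal sketch (assuming `fit` returns this callback, as in Keras):

    ```python
    history = model.fit(X_train, Y_train, validation_data=(X_val, Y_val))
    print(history.history['loss'])      # per-epoch training loss
    print(history.history['val_loss'])  # per-epoch validation loss
    ```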
|
""" |
|
|
|
def on_train_begin(self, logs=None): |
|
self.epoch = [] |
|
self.history = {} |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
logs = logs or {} |
|
self.epoch.append(epoch) |
|
for k, v in logs.items(): |
|
self.history.setdefault(k, []).append(v) |
|
|
|
|
|
class ModelCheckpoint(Callback): |
|
"""Save the model after every epoch. |
|
|
|
`filepath` can contain named formatting options, |
|
    which will be filled with the value of `epoch` and
|
keys in `logs` (passed in `on_epoch_end`). |
|
|
|
For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, |
|
then the model checkpoints will be saved with the epoch number and |
|
the validation loss in the filename. |
|
|
|
# Arguments |
|
filepath: string, path to save the model file. |
|
monitor: quantity to monitor. |
|
verbose: verbosity mode, 0 or 1. |
|
save_best_only: if `save_best_only=True`, |
|
the latest best model according to |
|
the quantity monitored will not be overwritten. |
|
mode: one of {auto, min, max}. |
|
If `save_best_only=True`, the decision |
|
to overwrite the current save file is made |
|
based on either the maximization or the |
|
minimization of the monitored quantity. For `val_acc`, |
|
this should be `max`, for `val_loss` this should |
|
be `min`, etc. In `auto` mode, the direction is |
|
automatically inferred from the name of the monitored quantity. |
|
save_weights_only: if True, then only the model's weights will be |
|
            saved (`torch.save(self.model.state_dict(), filepath)`), else the full
            model is saved (`torch.save(self.model, filepath)`).
|
period: Interval (number of epochs) between checkpoints. |
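
    # Example

    A minimal sketch (assuming `fit` reports `val_loss` in its epoch logs):

    ```python
    checkpoint = ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.pt',
                                 monitor='val_loss',
                                 save_best_only=True,
                                 save_weights_only=True,
                                 verbose=1)
    model.fit(X_train, Y_train, validation_data=(X_val, Y_val),
              callbacks=[checkpoint])
    ```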
|
""" |
|
|
|
def __init__(self, filepath, monitor='val_loss', verbose=0, |
|
save_best_only=False, save_weights_only=False, |
|
mode='auto', period=1): |
|
super(ModelCheckpoint, self).__init__() |
|
self.monitor = monitor |
|
self.verbose = verbose |
|
self.filepath = filepath |
|
self.save_best_only = save_best_only |
|
self.save_weights_only = save_weights_only |
|
self.period = period |
|
self.epochs_since_last_save = 0 |
|
|
|
if mode not in ['auto', 'min', 'max']: |
|
warnings.warn('ModelCheckpoint mode %s is unknown, ' |
|
'fallback to auto mode.' % (mode), |
|
RuntimeWarning) |
|
mode = 'auto' |
|
|
|
if mode == 'min': |
|
self.monitor_op = np.less |
|
self.best = np.Inf |
|
elif mode == 'max': |
|
self.monitor_op = np.greater |
|
self.best = -np.Inf |
|
else: |
|
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): |
|
self.monitor_op = np.greater |
|
self.best = -np.Inf |
|
else: |
|
self.monitor_op = np.less |
|
self.best = np.Inf |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
import torch |
|
logs = logs or {} |
|
self.epochs_since_last_save += 1 |
|
if self.epochs_since_last_save >= self.period: |
|
self.epochs_since_last_save = 0 |
|
filepath = self.filepath.format(epoch=epoch + 1, **logs) |
|
if self.save_best_only: |
|
current = logs.get(self.monitor) |
|
if current is None: |
|
warnings.warn('Can save best model only with %s available, ' |
|
'skipping.' % (self.monitor), RuntimeWarning) |
|
else: |
|
if self.monitor_op(current, self.best): |
|
if self.verbose > 0: |
|
print('\nEpoch %05d: %s improved from %0.5f to %0.5f,' |
|
' saving model to %s' |
|
% (epoch + 1, self.monitor, self.best, |
|
current, filepath)) |
|
self.best = current |
|
if self.save_weights_only: |
|
torch.save(self.model.state_dict(), filepath) |
|
else: |
|
                            # save the entire model object, not just the state_dict
                            torch.save(self.model, filepath)
|
else: |
|
if self.verbose > 0: |
|
print('\nEpoch %05d: %s did not improve' % |
|
(epoch + 1, self.monitor)) |
|
else: |
|
if self.verbose > 0: |
|
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath)) |
|
if self.save_weights_only: |
|
torch.save(self.model.state_dict(), filepath) |
|
else: |
|
                    # save the entire model object, not just the state_dict
                    torch.save(self.model, filepath)
|
|
|
|
|
class EarlyStopping(Callback): |
|
"""Stop training when a monitored quantity has stopped improving. |
|
|
|
# Arguments |
|
monitor: quantity to be monitored. |
|
min_delta: minimum change in the monitored quantity |
|
to qualify as an improvement, i.e. an absolute |
|
change of less than min_delta, will count as no |
|
improvement. |
|
patience: number of epochs with no improvement |
|
after which training will be stopped. |
|
verbose: verbosity mode. |
|
mode: one of {auto, min, max}. In `min` mode, |
|
training will stop when the quantity |
|
monitored has stopped decreasing; in `max` |
|
mode it will stop when the quantity |
|
monitored has stopped increasing; in `auto` |
|
mode, the direction is automatically inferred |
|
from the name of the monitored quantity. |
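
    # Example

    A minimal sketch (assuming `fit` reports `val_loss` in its epoch logs):

    ```python
    early_stopping = EarlyStopping(monitor='val_loss', min_delta=1e-4,
                                   patience=3, verbose=1, mode='min')
    model.fit(X_train, Y_train, validation_data=(X_val, Y_val),
              callbacks=[early_stopping])
    ```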
|
""" |
|
|
|
def __init__(self, monitor='val_loss', |
|
min_delta=0, patience=0, verbose=0, mode='auto'): |
|
super(EarlyStopping, self).__init__() |
|
|
|
self.monitor = monitor |
|
self.patience = patience |
|
self.verbose = verbose |
|
self.min_delta = min_delta |
|
self.wait = 0 |
|
self.stopped_epoch = 0 |
|
|
|
if mode not in ['auto', 'min', 'max']: |
|
warnings.warn('EarlyStopping mode %s is unknown, ' |
|
'fallback to auto mode.' % mode, |
|
RuntimeWarning) |
|
mode = 'auto' |
|
|
|
if mode == 'min': |
|
self.monitor_op = np.less |
|
elif mode == 'max': |
|
self.monitor_op = np.greater |
|
else: |
|
if 'acc' in self.monitor: |
|
self.monitor_op = np.greater |
|
else: |
|
self.monitor_op = np.less |
|
|
|
if self.monitor_op == np.greater: |
|
self.min_delta *= 1 |
|
else: |
|
self.min_delta *= -1 |
|
|
|
def on_train_begin(self, logs=None): |
|
|
|
self.wait = 0 |
|
self.stopped_epoch = 0 |
|
self.best = np.Inf if self.monitor_op == np.less else -np.Inf |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
current = logs.get(self.monitor) |
|
if current is None: |
|
warnings.warn( |
|
'Early stopping conditioned on metric `%s` ' |
|
'which is not available. Available metrics are: %s' % |
|
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning |
|
) |
|
return |
|
if self.monitor_op(current - self.min_delta, self.best): |
|
self.best = current |
|
self.wait = 0 |
|
else: |
|
self.wait += 1 |
|
if self.wait >= self.patience: |
|
self.stopped_epoch = epoch |
|
self.model.stop_training = True |
|
|
|
def on_train_end(self, logs=None): |
|
if self.stopped_epoch > 0 and self.verbose > 0: |
|
print('Epoch %05d: early stopping' % (self.stopped_epoch + 1)) |
|
|
|
|
|
class RemoteMonitor(Callback): |
|
"""Callback used to stream events to a server. |
|
|
|
Requires the `requests` library. |
|
    Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
    HTTP POST, with the epoch number and logs sent as a JSON-encoded
    dictionary under the configured `field` key.
|
|
|
# Arguments |
|
root: String; root url of the target server. |
|
path: String; path relative to `root` to which the events will be sent. |
|
        field: String; JSON field under which the event data will be sent.
|
headers: Dictionary; optional custom HTTP headers. |
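
    # Example

    A minimal sketch (the server URL is a placeholder; the `requests`
    library must be installed):

    ```python
    monitor = RemoteMonitor(root='http://localhost:9000')
    model.fit(X_train, Y_train, callbacks=[monitor])
    ```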
|
""" |
|
|
|
def __init__(self, |
|
root='http://localhost:9000', |
|
path='/publish/epoch/end/', |
|
field='images', |
|
headers=None): |
|
super(RemoteMonitor, self).__init__() |
|
|
|
self.root = root |
|
self.path = path |
|
self.field = field |
|
self.headers = headers |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
if requests is None: |
|
raise ImportError('RemoteMonitor requires ' |
|
'the `requests` library.') |
|
logs = logs or {} |
|
send = {} |
|
send['epoch'] = epoch |
|
for k, v in logs.items(): |
|
if isinstance(v, (np.ndarray, np.generic)): |
|
send[k] = v.item() |
|
else: |
|
send[k] = v |
|
try: |
|
requests.post(self.root + self.path, |
|
{self.field: json.dumps(send)}, |
|
headers=self.headers) |
|
except requests.exceptions.RequestException: |
|
warnings.warn('Warning: could not reach RemoteMonitor ' |
|
'root server at ' + str(self.root)) |
|
|
|
|
|
class TensorBoard(Callback): |
|
"""TensorBoard basic visualizations. |
|
|
|
[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard) |
|
is a visualization tool provided with TensorFlow. |
|
|
|
This callback writes a log for TensorBoard, which allows |
|
you to visualize dynamic graphs of your training and test |
|
metrics, as well as activation histograms for the different |
|
layers in your model. |
|
|
|
If you have installed TensorFlow with pip, you should be able |
|
to launch TensorBoard from the command line: |
|
```sh |
|
tensorboard --logdir=/full_path_to_your_logs |
|
``` |
|
|
|
When using a backend other than TensorFlow, TensorBoard will still work |
|
(if you have TensorFlow installed), but the only feature available will |
|
be the display of the losses and metrics plots. |
|
|
|
# Arguments |
|
log_dir: the path of the directory where to save the log |
|
files to be parsed by TensorBoard. |
|
histogram_freq: frequency (in epochs) at which to compute activation |
|
and weight histograms for the layers of the model. If set to 0, |
|
            histograms won't be computed. Validation data (or split) must be
|
specified for histogram visualizations. |
|
write_graph: whether to visualize the graph in TensorBoard. |
|
The log file can become quite large when |
|
write_graph is set to True. |
|
write_grads: whether to visualize gradient histograms in TensorBoard. |
|
`histogram_freq` must be greater than 0. |
|
batch_size: size of batch of inputs to feed to the network |
|
for histograms computation. |
|
write_images: whether to write model weights to visualize as |
|
image in TensorBoard. |
|
embeddings_freq: frequency (in epochs) at which selected embedding |
|
layers will be saved. |
|
        embeddings_layer_names: a list of names of layers to keep an eye on. If
            None or an empty list, all the embedding layers will be watched.
|
embeddings_metadata: a dictionary which maps layer name to a file name |
|
in which metadata for this embedding layer is saved. See the |
|
[details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional) |
|
            about the metadata file format. If the same metadata file is to be
            used for all embedding layers, a single string can be passed.
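
    # Example

    A minimal sketch (histogram and embedding options are left at their
    defaults):

    ```python
    tensorboard = TensorBoard(log_dir='./logs', write_graph=True)
    model.fit(X_train, Y_train, callbacks=[tensorboard])
    ```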
|
""" |
|
|
|
def __init__(self, log_dir='./logs', |
|
histogram_freq=0, |
|
batch_size=32, |
|
write_graph=True, |
|
write_grads=False, |
|
write_images=False, |
|
embeddings_freq=0, |
|
embeddings_layer_names=None, |
|
embeddings_metadata=None): |
|
super(TensorBoard, self).__init__() |
|
global tf, projector |
|
try: |
|
import tensorflow as tf |
|
from tensorflow.contrib.tensorboard.plugins import projector |
|
except ImportError: |
|
raise ImportError('You need the TensorFlow module installed to use TensorBoard.') |
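        # NOTE: the TensorFlow-specific paths below also rely on a Keras-style
        # backend module bound to `K` (e.g. `from keras import backend as K`),
        # which this file does not import; it must be provided by the
        # surrounding project for those features to work.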
|
|
|
if K.backend() != 'tensorflow': |
|
if histogram_freq != 0: |
|
warnings.warn('You are not using the TensorFlow backend. ' |
|
'histogram_freq was set to 0') |
|
histogram_freq = 0 |
|
if write_graph: |
|
warnings.warn('You are not using the TensorFlow backend. ' |
|
'write_graph was set to False') |
|
write_graph = False |
|
if write_images: |
|
warnings.warn('You are not using the TensorFlow backend. ' |
|
'write_images was set to False') |
|
write_images = False |
|
if embeddings_freq != 0: |
|
warnings.warn('You are not using the TensorFlow backend. ' |
|
'embeddings_freq was set to 0') |
|
embeddings_freq = 0 |
|
|
|
self.log_dir = log_dir |
|
self.histogram_freq = histogram_freq |
|
self.merged = None |
|
self.write_graph = write_graph |
|
self.write_grads = write_grads |
|
self.write_images = write_images |
|
self.embeddings_freq = embeddings_freq |
|
self.embeddings_layer_names = embeddings_layer_names |
|
self.embeddings_metadata = embeddings_metadata or {} |
|
self.batch_size = batch_size |
|
|
|
def set_model(self, model): |
|
self.model = model |
|
if K.backend() == 'tensorflow': |
|
self.sess = K.get_session() |
|
if self.histogram_freq and self.merged is None: |
|
for layer in self.model.layers: |
|
|
|
for weight in layer.weights: |
|
mapped_weight_name = weight.name.replace(':', '_') |
|
tf.summary.histogram(mapped_weight_name, weight) |
|
if self.write_grads: |
|
grads = model.optimizer.get_gradients(model.total_loss, |
|
weight) |
|
|
|
def is_indexed_slices(grad): |
|
return type(grad).__name__ == 'IndexedSlices' |
|
grads = [ |
|
grad.values if is_indexed_slices(grad) else grad |
|
for grad in grads] |
|
tf.summary.histogram('{}_grad'.format(mapped_weight_name), grads) |
|
if self.write_images: |
|
w_img = tf.squeeze(weight) |
|
shape = K.int_shape(w_img) |
|
if len(shape) == 2: |
|
if shape[0] > shape[1]: |
|
w_img = tf.transpose(w_img) |
|
shape = K.int_shape(w_img) |
|
w_img = tf.reshape(w_img, [1, |
|
shape[0], |
|
shape[1], |
|
1]) |
|
elif len(shape) == 3: |
|
if K.image_data_format() == 'channels_last': |
|
|
|
|
|
w_img = tf.transpose(w_img, perm=[2, 0, 1]) |
|
shape = K.int_shape(w_img) |
|
w_img = tf.reshape(w_img, [shape[0], |
|
shape[1], |
|
shape[2], |
|
1]) |
|
elif len(shape) == 1: |
|
w_img = tf.reshape(w_img, [1, |
|
shape[0], |
|
1, |
|
1]) |
|
else: |
|
|
|
continue |
|
|
|
shape = K.int_shape(w_img) |
|
assert len(shape) == 4 and shape[-1] in [1, 3, 4] |
|
tf.summary.image(mapped_weight_name, w_img) |
|
|
|
if hasattr(layer, 'output'): |
|
tf.summary.histogram('{}_out'.format(layer.name), |
|
layer.output) |
|
self.merged = tf.summary.merge_all() |
|
|
|
if self.write_graph: |
|
self.writer = tf.summary.FileWriter(self.log_dir, |
|
self.sess.graph) |
|
else: |
|
self.writer = tf.summary.FileWriter(self.log_dir) |
|
|
|
if self.embeddings_freq: |
|
embeddings_layer_names = self.embeddings_layer_names |
|
|
|
if not embeddings_layer_names: |
|
embeddings_layer_names = [layer.name for layer in self.model.layers |
|
if type(layer).__name__ == 'Embedding'] |
|
|
|
embeddings = {layer.name: layer.weights[0] |
|
for layer in self.model.layers |
|
if layer.name in embeddings_layer_names} |
|
|
|
self.saver = tf.train.Saver(list(embeddings.values())) |
|
|
|
embeddings_metadata = {} |
|
|
|
if not isinstance(self.embeddings_metadata, str): |
|
embeddings_metadata = self.embeddings_metadata |
|
else: |
|
embeddings_metadata = {layer_name: self.embeddings_metadata |
|
for layer_name in embeddings.keys()} |
|
|
|
config = projector.ProjectorConfig() |
|
self.embeddings_ckpt_path = os.path.join(self.log_dir, |
|
'keras_embedding.ckpt') |
|
|
|
for layer_name, tensor in embeddings.items(): |
|
embedding = config.embeddings.add() |
|
embedding.tensor_name = tensor.name |
|
|
|
if layer_name in embeddings_metadata: |
|
embedding.metadata_path = embeddings_metadata[layer_name] |
|
|
|
projector.visualize_embeddings(self.writer, config) |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
logs = logs or {} |
|
|
|
if not self.validation_data and self.histogram_freq: |
|
raise ValueError('If printing histograms, validation_data must be ' |
|
'provided, and cannot be a generator.') |
|
if self.validation_data and self.histogram_freq: |
|
if epoch % self.histogram_freq == 0: |
|
|
|
val_data = self.validation_data |
|
tensors = (self.model.inputs + |
|
self.model.targets + |
|
self.model.sample_weights) |
|
|
|
if self.model.uses_learning_phase: |
|
tensors += [K.learning_phase()] |
|
|
|
assert len(val_data) == len(tensors) |
|
val_size = val_data[0].shape[0] |
|
i = 0 |
|
while i < val_size: |
|
step = min(self.batch_size, val_size - i) |
|
if self.model.uses_learning_phase: |
|
|
|
batch_val = [x[i:i + step] for x in val_data[:-1]] |
|
batch_val.append(val_data[-1]) |
|
else: |
|
batch_val = [x[i:i + step] for x in val_data] |
|
assert len(batch_val) == len(tensors) |
|
feed_dict = dict(zip(tensors, batch_val)) |
|
result = self.sess.run([self.merged], feed_dict=feed_dict) |
|
summary_str = result[0] |
|
self.writer.add_summary(summary_str, epoch) |
|
i += self.batch_size |
|
|
|
if self.embeddings_freq and self.embeddings_ckpt_path: |
|
if epoch % self.embeddings_freq == 0: |
|
self.saver.save(self.sess, |
|
self.embeddings_ckpt_path, |
|
epoch) |
|
|
|
for name, value in logs.items(): |
|
if name in ['batch', 'size']: |
|
continue |
|
summary = tf.Summary() |
|
summary_value = summary.value.add() |
|
            if isinstance(value, (np.ndarray, np.generic)):
                summary_value.simple_value = value.item()
            else:
                summary_value.simple_value = value
|
summary_value.tag = name |
|
self.writer.add_summary(summary, epoch) |
|
self.writer.flush() |
|
|
|
def on_train_end(self, _): |
|
self.writer.close() |
|
|
|
|
|
class CSVLogger(Callback): |
|
"""Callback that streams epoch results to a csv file. |
|
|
|
Supports all values that can be represented as a string, |
|
including 1D iterables such as np.ndarray. |
|
|
|
# Example |
|
|
|
```python |
|
csv_logger = CSVLogger('training.log') |
|
model.fit(X_train, Y_train, callbacks=[csv_logger]) |
|
``` |
|
|
|
# Arguments |
|
filename: filename of the csv file, e.g. 'run/log.csv'. |
|
separator: string used to separate elements in the csv file. |
|
        append: True: append if file exists (useful for continuing
            training). False: overwrite existing file.
        output_on_train_end: an additional stream to which the full log is
            written when training ends, e.g.
            `CSVLogger(filename='./mylog.csv', output_on_train_end=os.sys.stdout)`.
|
""" |
|
|
|
def __init__(self, filename, separator=',', append=False, output_on_train_end=None): |
|
self.sep = separator |
|
self.filename = filename |
|
self.append = append |
|
self.writer = None |
|
self.keys = None |
|
self.append_header = True |
|
self.file_flags = 'b' if six.PY2 and os.name == 'nt' else '' |
|
self.output_on_train_end = output_on_train_end |
|
super(CSVLogger, self).__init__() |
|
|
|
def on_train_begin(self, logs=None): |
|
if self.append: |
|
if os.path.exists(self.filename): |
|
with open(self.filename, 'r' + self.file_flags) as f: |
|
self.append_header = not bool(len(f.readline())) |
|
self.csv_file = open(self.filename, 'a' + self.file_flags) |
|
else: |
|
self.csv_file = open(self.filename, 'w' + self.file_flags) |
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
logs = logs or {} |
|
|
|
def handle_value(k): |
|
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0 |
|
if isinstance(k, six.string_types): |
|
return k |
|
elif isinstance(k, Iterable) and not is_zero_dim_ndarray: |
|
return '"[%s]"' % (', '.join(map(str, k))) |
|
else: |
|
return k |
|
|
|
if self.keys is None: |
|
self.keys = sorted(logs.keys()) |
|
|
|
if self.model is not None and getattr(self.model, 'stop_training', False): |
|
|
|
logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys]) |
|
|
|
if not self.writer: |
|
class CustomDialect(csv.excel): |
|
delimiter = self.sep |
|
|
|
self.writer = csv.DictWriter(self.csv_file, |
|
fieldnames=['epoch'] + self.keys, dialect=CustomDialect) |
|
if self.append_header: |
|
self.writer.writeheader() |
|
|
|
row_dict = OrderedDict({'epoch': epoch}) |
|
row_dict.update((key, handle_value(logs[key])) for key in self.keys) |
|
self.writer.writerow(row_dict) |
|
self.csv_file.flush() |
|
|
|
def on_train_end(self, logs=None): |
|
self.csv_file.close() |
|
        if self.output_on_train_end is not None and os.path.exists(self.filename):
|
with open(self.filename, 'r' + self.file_flags) as f: |
|
print(f.read(), file=self.output_on_train_end) |
|
self.writer = None |
|
|
|
|
|
class LambdaCallback(Callback): |
|
r"""Callback for creating simple, custom callbacks on-the-fly. |
|
|
|
This callback is constructed with anonymous functions that will be called |
|
at the appropriate time. Note that the callbacks expects positional |
|
arguments, as: |
|
|
|
- `on_epoch_begin` and `on_epoch_end` expect two positional arguments: |
|
`epoch`, `logs` |
|
- `on_batch_begin` and `on_batch_end` expect two positional arguments: |
|
`batch`, `logs` |
|
- `on_train_begin` and `on_train_end` expect one positional argument: |
|
`logs` |
|
|
|
# Arguments |
|
on_epoch_begin: called at the beginning of every epoch. |
|
on_epoch_end: called at the end of every epoch. |
|
on_batch_begin: called at the beginning of every batch. |
|
on_batch_end: called at the end of every batch. |
|
on_train_begin: called at the beginning of model training. |
|
on_train_end: called at the end of model training. |
|
|
|
# Example |
|
|
|
```python |
|
# Print the batch number at the beginning of every batch. |
|
batch_print_callback = LambdaCallback( |
|
on_batch_begin=lambda batch,logs: print(batch)) |
|
|
|
# Stream the epoch loss to a file in JSON format. The file content |
|
# is not well-formed JSON but rather has a JSON object per line. |
|
import json |
|
json_log = open('loss_log.json', mode='wt', buffering=1) |
|
json_logging_callback = LambdaCallback( |
|
on_epoch_end=lambda epoch, logs: json_log.write( |
|
json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'), |
|
on_train_end=lambda logs: json_log.close() |
|
) |
|
|
|
# Terminate some processes after having finished model training. |
|
processes = ... |
|
cleanup_callback = LambdaCallback( |
|
on_train_end=lambda logs: [ |
|
p.terminate() for p in processes if p.is_alive()]) |
|
|
|
model.fit(..., |
|
callbacks=[batch_print_callback, |
|
json_logging_callback, |
|
cleanup_callback]) |
|
``` |
|
""" |
|
|
|
def __init__(self, |
|
on_epoch_begin=None, |
|
on_epoch_end=None, |
|
on_batch_begin=None, |
|
on_batch_end=None, |
|
on_train_begin=None, |
|
on_train_end=None, |
|
**kwargs): |
|
super(LambdaCallback, self).__init__() |
|
self.__dict__.update(kwargs) |
|
if on_epoch_begin is not None: |
|
self.on_epoch_begin = on_epoch_begin |
|
else: |
|
self.on_epoch_begin = lambda epoch, logs: None |
|
if on_epoch_end is not None: |
|
self.on_epoch_end = on_epoch_end |
|
else: |
|
self.on_epoch_end = lambda epoch, logs: None |
|
if on_batch_begin is not None: |
|
self.on_batch_begin = on_batch_begin |
|
else: |
|
self.on_batch_begin = lambda batch, logs: None |
|
if on_batch_end is not None: |
|
self.on_batch_end = on_batch_end |
|
else: |
|
self.on_batch_end = lambda batch, logs: None |
|
if on_train_begin is not None: |
|
self.on_train_begin = on_train_begin |
|
else: |
|
self.on_train_begin = lambda logs: None |
|
if on_train_end is not None: |
|
self.on_train_end = on_train_end |
|
else: |
|
self.on_train_end = lambda logs: None |
|
|
|
|
|
|
class TQDMCallback(Callback): |
|
def __init__(self, outer_description="Training", |
|
inner_description_initial="Epoch: {epoch}", |
|
inner_description_update="Epoch: {epoch} - {metrics}", |
|
metric_format="{name}: {value:0.3f}", |
|
separator=", ", |
|
leave_inner=True, |
|
leave_outer=True, |
|
show_inner=True, |
|
show_outer=True, |
|
output_file=stderr, |
|
initial=0): |
|
""" |
|
Construct a callback that will create and update progress bars. |
|
|
|
:param outer_description: string for outer progress bar |
|
:param inner_description_initial: initial format for epoch ("Epoch: {epoch}") |
|
:param inner_description_update: format after metrics collected ("Epoch: {epoch} - {metrics}") |
|
:param metric_format: format for each metric name/value pair ("{name}: {value:0.3f}") |
|
:param separator: separator between metrics (", ") |
|
:param leave_inner: True to leave inner bars |
|
:param leave_outer: True to leave outer bars |
|
:param show_inner: False to hide inner bars |
|
:param show_outer: False to hide outer bar |
|
:param output_file: output file (default sys.stderr) |
|
:param initial: Initial counter state |
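
        Example (a minimal sketch; the progress bars replace the default
        console output, so `verbose=0` is assumed)::

            model.fit(X_train, Y_train, verbose=0, callbacks=[TQDMCallback()])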
|
""" |
|
self.outer_description = outer_description |
|
self.inner_description_initial = inner_description_initial |
|
self.inner_description_update = inner_description_update |
|
self.metric_format = metric_format |
|
self.separator = separator |
|
self.leave_inner = leave_inner |
|
self.leave_outer = leave_outer |
|
self.show_inner = show_inner |
|
self.show_outer = show_outer |
|
self.output_file = output_file |
|
self.tqdm_outer = None |
|
self.tqdm_inner = None |
|
self.epoch = None |
|
self.running_logs = None |
|
self.inner_count = None |
|
self.initial = initial |
|
|
|
def tqdm(self, desc, total, leave, initial=0): |
|
""" |
|
Extension point. Override to provide custom options to tqdm initializer. |
|
:param desc: Description string |
|
:param total: Total number of updates |
|
:param leave: Leave progress bar when done |
|
:param initial: Initial counter state |
|
:return: new progress bar |
|
""" |
|
return tqdm(desc=desc, total=total, leave=leave, file=self.output_file, initial=initial) |
|
|
|
def build_tqdm_outer(self, desc, total): |
|
""" |
|
Extension point. Override to provide custom options to outer progress bars (Epoch loop) |
|
:param desc: Description |
|
:param total: Number of epochs |
|
:return: new progress bar |
|
""" |
|
return self.tqdm(desc=desc, total=total, leave=self.leave_outer, initial=self.initial) |
|
|
|
def build_tqdm_inner(self, desc, total): |
|
""" |
|
Extension point. Override to provide custom options to inner progress bars (Batch loop) |
|
:param desc: Description |
|
:param total: Number of batches |
|
:return: new progress bar |
|
""" |
|
return self.tqdm(desc=desc, total=total, leave=self.leave_inner) |
|
|
|
def on_epoch_begin(self, epoch, logs={}): |
|
self.epoch = epoch |
|
desc = self.inner_description_initial.format(epoch=self.epoch) |
|
self.mode = 0 |
|
if 'samples' in self.params: |
|
self.inner_total = self.params['samples'] |
|
elif 'nb_sample' in self.params: |
|
self.inner_total = self.params['nb_sample'] |
|
else: |
|
self.mode = 1 |
|
self.inner_total = self.params['steps'] |
|
if self.show_inner: |
|
self.tqdm_inner = self.build_tqdm_inner(desc=desc, total=self.inner_total) |
|
self.inner_count = 0 |
|
self.running_logs = {} |
|
|
|
def on_epoch_end(self, epoch, logs={}): |
|
metrics = self.format_metrics(logs) |
|
desc = self.inner_description_update.format(epoch=epoch, metrics=metrics) |
|
if self.show_inner: |
|
self.tqdm_inner.desc = desc |
|
|
|
self.tqdm_inner.miniters = 0 |
|
self.tqdm_inner.mininterval = 0 |
|
self.tqdm_inner.update(self.inner_total - self.tqdm_inner.n) |
|
self.tqdm_inner.close() |
|
if self.show_outer: |
|
self.tqdm_outer.update(1) |
|
|
|
def on_batch_begin(self, batch, logs={}): |
|
pass |
|
|
|
def on_batch_end(self, batch, logs={}): |
|
if self.mode == 0: |
|
update = logs['size'] |
|
else: |
|
update = 1 |
|
self.inner_count += update |
|
if self.inner_count < self.inner_total: |
|
self.append_logs(logs) |
|
metrics = self.format_metrics(self.running_logs) |
|
desc = self.inner_description_update.format(epoch=self.epoch, metrics=metrics) |
|
if self.show_inner: |
|
self.tqdm_inner.desc = desc |
|
self.tqdm_inner.update(update) |
|
|
|
def on_train_begin(self, logs={}): |
|
if self.show_outer: |
|
epochs = (self.params['epochs'] if 'epochs' in self.params |
|
else self.params['nb_epoch']) |
|
self.tqdm_outer = self.build_tqdm_outer(desc=self.outer_description, |
|
total=epochs) |
|
|
|
def on_train_end(self, logs={}): |
|
if self.show_outer: |
|
self.tqdm_outer.close() |
|
|
|
def append_logs(self, logs): |
|
metrics = self.params['metrics'] |
|
for metric, value in six.iteritems(logs): |
|
if metric in metrics: |
|
if metric in self.running_logs: |
|
self.running_logs[metric].append(value[()]) |
|
else: |
|
self.running_logs[metric] = [value[()]] |
|
|
|
def format_metrics(self, logs): |
|
metrics = self.params['metrics'] |
|
        strings = [self.metric_format.format(name=metric, value=np.mean(logs[metric], axis=None))
                   for metric in metrics if metric in logs]
|
return self.separator.join(strings) |