Jyothirmai committed on
Commit 6ea8a2e · verified · 1 Parent(s): 26e26de

Delete callbacks.py

Files changed (1):
  1. callbacks.py +0 -1066
callbacks.py DELETED
@@ -1,1066 +0,0 @@
"""Callbacks: utilities called at certain points during model training.

# Adapted from

- https://github.com/keras-team/keras
- https://github.com/bstriner/keras-tqdm/blob/master/keras_tqdm/tqdm_callback.py

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import csv
import six

import numpy as np
import time
import json
import warnings
from sys import stderr
from tqdm import tqdm

from collections import deque
from collections import OrderedDict
try:
    # `Iterable` moved to collections.abc in Python 3.3.
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable  # Python 2 fallback

try:
    import requests
except ImportError:
    requests = None

class CallbackList(object):
    """Container abstracting a list of callbacks.

    # Arguments
        callbacks: List of `Callback` instances.
        queue_length: Queue length for keeping
            running statistics over callback execution time.
    """

    def __init__(self, callbacks=None, queue_length=10):
        callbacks = callbacks or []
        self.callbacks = [c for c in callbacks]
        self.queue_length = queue_length

    def append(self, callback):
        self.callbacks.append(callback)

    def set_params(self, params):
        for callback in self.callbacks:
            callback.set_params(params)

    def set_model(self, model):
        for callback in self.callbacks:
            callback.set_model(model)

    def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch.

        # Arguments
            epoch: integer, index of epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)
        self._delta_t_batch = 0.
        self._delta_ts_batch_begin = deque([], maxlen=self.queue_length)
        self._delta_ts_batch_end = deque([], maxlen=self.queue_length)

    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        # Arguments
            epoch: integer, index of epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)

    def on_batch_begin(self, batch, logs=None):
        """Called right before processing a batch.

        # Arguments
            batch: integer, index of batch within the current epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        t_before_callbacks = time.time()
        for callback in self.callbacks:
            callback.on_batch_begin(batch, logs)
        self._delta_ts_batch_begin.append(time.time() - t_before_callbacks)
        delta_t_median = np.median(self._delta_ts_batch_begin)
        if (self._delta_t_batch > 0. and
                delta_t_median > 0.95 * self._delta_t_batch and
                delta_t_median > 0.1):
            warnings.warn('Method on_batch_begin() is slow compared '
                          'to the batch update (%f). Check your callbacks.'
                          % delta_t_median)
        self._t_enter_batch = time.time()

    def on_batch_end(self, batch, logs=None):
        """Called at the end of a batch.

        # Arguments
            batch: integer, index of batch within the current epoch.
            logs: dictionary of logs.
        """
        logs = logs or {}
        if not hasattr(self, '_t_enter_batch'):
            self._t_enter_batch = time.time()
        self._delta_t_batch = time.time() - self._t_enter_batch
        t_before_callbacks = time.time()
        for callback in self.callbacks:
            callback.on_batch_end(batch, logs)
        self._delta_ts_batch_end.append(time.time() - t_before_callbacks)
        delta_t_median = np.median(self._delta_ts_batch_end)
        if (self._delta_t_batch > 0. and
                delta_t_median > 0.95 * self._delta_t_batch and
                delta_t_median > 0.1):
            warnings.warn('Method on_batch_end() is slow compared '
                          'to the batch update (%f). Check your callbacks.'
                          % delta_t_median)

    def on_train_begin(self, logs=None):
        """Called at the beginning of training.

        # Arguments
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs=None):
        """Called at the end of training.

        # Arguments
            logs: dictionary of logs.
        """
        logs = logs or {}
        for callback in self.callbacks:
            callback.on_train_end(logs)

    def __iter__(self):
        return iter(self.callbacks)

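# Usage sketch (editor's addition, not part of the original file): a
# hand-rolled training loop can drive a CallbackList directly. The names
# `model`, `num_epochs`, `batches`, and `train_step` are hypothetical
# placeholders for whatever the surrounding training code defines.
#
#     callbacks = CallbackList([History(), TerminateOnNaN()])
#     callbacks.set_model(model)
#     callbacks.set_params({'epochs': num_epochs, 'metrics': ['loss']})
#     callbacks.on_train_begin()
#     for epoch in range(num_epochs):
#         callbacks.on_epoch_begin(epoch)
#         for batch, data in enumerate(batches):
#             callbacks.on_batch_begin(batch)
#             loss = train_step(data)  # hypothetical step function
#             callbacks.on_batch_end(batch, {'loss': loss, 'size': len(data)})
#         callbacks.on_epoch_end(epoch, {'loss': loss})
#     callbacks.on_train_end()
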
class Callback(object):
    """Abstract base class used to build new callbacks.

    # Properties
        params: dict. Training parameters
            (eg. verbosity, batch size, number of epochs...).
        model: instance of `keras.models.Model`.
            Reference of the model being trained.

    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch.

    Currently, the `.fit()` method of the `Sequential` model class
    will include the following quantities in the `logs` that
    it passes to its callbacks:

        on_epoch_end: logs include `acc` and `loss`, and
            optionally include `val_loss`
            (if validation is enabled in `fit`), and `val_acc`
            (if validation and accuracy monitoring are enabled).
        on_batch_begin: logs include `size`,
            the number of samples in the current batch.
        on_batch_end: logs include `loss`, and optionally `acc`
            (if accuracy monitoring is enabled).
    """

    def __init__(self):
        self.validation_data = None
        self.model = None

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass

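# Usage sketch (editor's addition): a minimal custom callback built on the
# `Callback` interface above; override only the hooks you need.
#
#     class LossPrinter(Callback):
#         def on_epoch_end(self, epoch, logs=None):
#             logs = logs or {}
#             print('epoch %d: loss=%s' % (epoch, logs.get('loss')))
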
class BaseLogger(Callback):
    """Callback that accumulates epoch averages of metrics.

    This callback is automatically applied to every Keras model.
    """

    def on_epoch_begin(self, epoch, logs=None):
        self.seen = 0
        self.totals = {}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        for k, v in logs.items():
            if k in self.totals:
                self.totals[k] += v * batch_size
            else:
                self.totals[k] = v * batch_size

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            for k in self.params['metrics']:
                if k in self.totals:
                    # Make value available to next callbacks.
                    logs[k] = self.totals[k] / self.seen

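# Worked example (editor's addition): the accumulation above is a
# batch-size-weighted mean. With two batches of sizes 32 and 16 and batch
# losses 0.5 and 0.2, totals = 0.5 * 32 + 0.2 * 16 = 19.2 and seen = 48, so
# the epoch-level loss written into `logs` is 19.2 / 48 = 0.4.
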
class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN loss is encountered.
    """

    def __init__(self):
        super(TerminateOnNaN, self).__init__()

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        if loss is not None:
            if np.isnan(loss) or np.isinf(loss):
                print('Batch %d: Invalid loss, terminating training' % (batch))
                self.model.stop_training = True

class History(Callback):
    """Callback that records events into a `History` object.

    This callback is automatically applied to
    every Keras model. The `History` object
    gets returned by the `fit` method of models.
    """

    def on_train_begin(self, logs=None):
        self.epoch = []
        self.history = {}

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch.append(epoch)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

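# Usage sketch (editor's addition): after training, `history.history` maps
# each logged metric name to a per-epoch list.
#
#     history = History()
#     # ... run training with `history` in the callback list ...
#     # history.epoch            -> [0, 1, 2, ...]
#     # history.history['loss']  -> [2.31, 1.87, 1.42, ...]  (illustrative values)
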
class ModelCheckpoint(Callback):
    """Save the model after every epoch.

    `filepath` can contain named formatting options,
    which will be filled with the value of `epoch` and
    keys in `logs` (passed in `on_epoch_end`).

    For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
    then the model checkpoints will be saved with the epoch number and
    the validation loss in the filename.

    # Arguments
        filepath: string, path to save the model file.
        monitor: quantity to monitor.
        verbose: verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`,
            the latest best model according to
            the quantity monitored will not be overwritten.
        mode: one of {auto, min, max}.
            If `save_best_only=True`, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        save_weights_only: if True, then only the model's weights will be
            saved (`torch.save(self.model.state_dict(), filepath)`), else
            the full model is saved (`torch.save(self.model, filepath)`).
        period: Interval (number of epochs) between checkpoints.
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        super(ModelCheckpoint, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
            self.best = np.Inf
        elif mode == 'max':
            self.monitor_op = np.greater
            self.best = -np.Inf
        else:
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
                self.monitor_op = np.greater
                self.best = -np.Inf
            else:
                self.monitor_op = np.less
                self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        import torch
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch + 1, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch + 1, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            torch.save(self.model.state_dict(), filepath)
                        else:
                            # Save the full model, not just the weights.
                            torch.save(self.model, filepath)
                    else:
                        if self.verbose > 0:
                            print('\nEpoch %05d: %s did not improve' %
                                  (epoch + 1, self.monitor))
            else:
                if self.verbose > 0:
                    print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
                if self.save_weights_only:
                    torch.save(self.model.state_dict(), filepath)
                else:
                    # Save the full model, not just the weights.
                    torch.save(self.model, filepath)

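# Usage sketch (editor's addition): checkpoint only improving models, with
# the epoch and monitored value formatted into the filename. `val_loss` must
# appear in the `logs` dict passed to on_epoch_end by the training loop.
#
#     checkpoint = ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.pt',
#                                  monitor='val_loss', mode='min',
#                                  save_best_only=True, verbose=1)
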
class EarlyStopping(Callback):
    """Stop training when a monitored quantity has stopped improving.

    # Arguments
        monitor: quantity to be monitored.
        min_delta: minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than min_delta will count as no
            improvement.
        patience: number of epochs with no improvement
            after which training will be stopped.
        verbose: verbosity mode.
        mode: one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, the direction is automatically inferred
            from the name of the monitored quantity.
    """

    def __init__(self, monitor='val_loss',
                 min_delta=0, patience=0, verbose=0, mode='auto'):
        super(EarlyStopping, self).__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        # Allow instances to be re-used.
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                'Early stopping conditioned on metric `%s` '
                'which is not available. Available metrics are: %s' %
                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
            )
            return
        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

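# Usage sketch (editor's addition): stop after 3 epochs without a val_loss
# improvement of at least 0.01. Note the sign trick in __init__: in `min`
# mode min_delta is negated, so `current - min_delta < best` is equivalent
# to requiring current < best - 0.01.
#
#     early_stop = EarlyStopping(monitor='val_loss', min_delta=0.01,
#                                patience=3, mode='min', verbose=1)
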
class RemoteMonitor(Callback):
    """Callback used to stream events to a server.

    Requires the `requests` library.
    Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
    HTTP POST, with an `images` argument which is a
    JSON-encoded dictionary of event images.

    # Arguments
        root: String; root url of the target server.
        path: String; path relative to `root` to which the events will be sent.
        field: String; JSON field under which the images will be stored.
        headers: Dictionary; optional custom HTTP headers.
    """

    def __init__(self,
                 root='http://localhost:9000',
                 path='/publish/epoch/end/',
                 field='images',
                 headers=None):
        super(RemoteMonitor, self).__init__()

        self.root = root
        self.path = path
        self.field = field
        self.headers = headers

    def on_epoch_end(self, epoch, logs=None):
        if requests is None:
            raise ImportError('RemoteMonitor requires '
                              'the `requests` library.')
        logs = logs or {}
        send = {}
        send['epoch'] = epoch
        for k, v in logs.items():
            if isinstance(v, (np.ndarray, np.generic)):
                send[k] = v.item()
            else:
                send[k] = v
        try:
            requests.post(self.root + self.path,
                          {self.field: json.dumps(send)},
                          headers=self.headers)
        except requests.exceptions.RequestException:
            warnings.warn('Warning: could not reach RemoteMonitor '
                          'root server at ' + str(self.root))

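# Usage sketch (editor's addition): each epoch end POSTs a form field holding
# the JSON-encoded logs, roughly equivalent to the illustrative call below
# (the server behind the endpoint is whatever you run there):
#
#     requests.post('http://localhost:9000/publish/epoch/end/',
#                   {'images': json.dumps({'epoch': 0, 'loss': 0.42})})
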
class TensorBoard(Callback):
    """TensorBoard basic visualizations.

    [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
    is a visualization tool provided with TensorFlow.

    This callback writes a log for TensorBoard, which allows
    you to visualize dynamic graphs of your training and test
    metrics, as well as activation histograms for the different
    layers in your model.

    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:
    ```sh
    tensorboard --logdir=/full_path_to_your_logs
    ```

    When using a backend other than TensorFlow, TensorBoard will still work
    (if you have TensorFlow installed), but the only feature available will
    be the display of the losses and metrics plots.

    # Arguments
        log_dir: the path of the directory where to save the log
            files to be parsed by TensorBoard.
        histogram_freq: frequency (in epochs) at which to compute activation
            and weight histograms for the layers of the model. If set to 0,
            histograms won't be computed. Validation images (or split) must be
            specified for histogram visualizations.
        write_graph: whether to visualize the graph in TensorBoard.
            The log file can become quite large when
            write_graph is set to True.
        write_grads: whether to visualize gradient histograms in TensorBoard.
            `histogram_freq` must be greater than 0.
        batch_size: size of batch of inputs to feed to the network
            for histograms computation.
        write_images: whether to write model weights to visualize as
            an image in TensorBoard.
        embeddings_freq: frequency (in epochs) at which selected embedding
            layers will be saved.
        embeddings_layer_names: a list of names of layers to keep an eye on.
            If None or an empty list, all embedding layers will be watched.
        embeddings_metadata: a dictionary which maps layer name to a file name
            in which metadata for this embedding layer is saved. See the
            [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
            about the metadata file format. If the same metadata file is
            used for all embedding layers, a single string can be passed.
    """

    def __init__(self, log_dir='./logs',
                 histogram_freq=0,
                 batch_size=32,
                 write_graph=True,
                 write_grads=False,
                 write_images=False,
                 embeddings_freq=0,
                 embeddings_layer_names=None,
                 embeddings_metadata=None):
        super(TensorBoard, self).__init__()
        global tf, projector, K
        try:
            import tensorflow as tf
            from tensorflow.contrib.tensorboard.plugins import projector
            from keras import backend as K  # Keras backend helpers used below
        except ImportError:
            raise ImportError('You need the TensorFlow module installed to use TensorBoard.')

        if K.backend() != 'tensorflow':
            if histogram_freq != 0:
                warnings.warn('You are not using the TensorFlow backend. '
                              'histogram_freq was set to 0')
                histogram_freq = 0
            if write_graph:
                warnings.warn('You are not using the TensorFlow backend. '
                              'write_graph was set to False')
                write_graph = False
            if write_images:
                warnings.warn('You are not using the TensorFlow backend. '
                              'write_images was set to False')
                write_images = False
            if embeddings_freq != 0:
                warnings.warn('You are not using the TensorFlow backend. '
                              'embeddings_freq was set to 0')
                embeddings_freq = 0

        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        self.merged = None
        self.write_graph = write_graph
        self.write_grads = write_grads
        self.write_images = write_images
        self.embeddings_freq = embeddings_freq
        self.embeddings_layer_names = embeddings_layer_names
        self.embeddings_metadata = embeddings_metadata or {}
        self.batch_size = batch_size

    def set_model(self, model):
        self.model = model
        if K.backend() == 'tensorflow':
            self.sess = K.get_session()
        if self.histogram_freq and self.merged is None:
            for layer in self.model.layers:

                for weight in layer.weights:
                    mapped_weight_name = weight.name.replace(':', '_')
                    tf.summary.histogram(mapped_weight_name, weight)
                    if self.write_grads:
                        grads = model.optimizer.get_gradients(model.total_loss,
                                                              weight)

                        def is_indexed_slices(grad):
                            return type(grad).__name__ == 'IndexedSlices'
                        grads = [
                            grad.values if is_indexed_slices(grad) else grad
                            for grad in grads]
                        tf.summary.histogram('{}_grad'.format(mapped_weight_name), grads)
                    if self.write_images:
                        w_img = tf.squeeze(weight)
                        shape = K.int_shape(w_img)
                        if len(shape) == 2:  # dense layer kernel case
                            if shape[0] > shape[1]:
                                w_img = tf.transpose(w_img)
                                shape = K.int_shape(w_img)
                            w_img = tf.reshape(w_img, [1,
                                                       shape[0],
                                                       shape[1],
                                                       1])
                        elif len(shape) == 3:  # convnet case
                            if K.image_data_format() == 'channels_last':
                                # switch to channels_first to display
                                # every kernel as a separate image
                                w_img = tf.transpose(w_img, perm=[2, 0, 1])
                                shape = K.int_shape(w_img)
                            w_img = tf.reshape(w_img, [shape[0],
                                                       shape[1],
                                                       shape[2],
                                                       1])
                        elif len(shape) == 1:  # bias case
                            w_img = tf.reshape(w_img, [1,
                                                       shape[0],
                                                       1,
                                                       1])
                        else:
                            # not possible to handle 3D convnets etc.
                            continue

                        shape = K.int_shape(w_img)
                        assert len(shape) == 4 and shape[-1] in [1, 3, 4]
                        tf.summary.image(mapped_weight_name, w_img)

                if hasattr(layer, 'output'):
                    tf.summary.histogram('{}_out'.format(layer.name),
                                         layer.output)
        self.merged = tf.summary.merge_all()

        if self.write_graph:
            self.writer = tf.summary.FileWriter(self.log_dir,
                                                self.sess.graph)
        else:
            self.writer = tf.summary.FileWriter(self.log_dir)

        if self.embeddings_freq:
            embeddings_layer_names = self.embeddings_layer_names

            if not embeddings_layer_names:
                embeddings_layer_names = [layer.name for layer in self.model.layers
                                          if type(layer).__name__ == 'Embedding']

            embeddings = {layer.name: layer.weights[0]
                          for layer in self.model.layers
                          if layer.name in embeddings_layer_names}

            self.saver = tf.train.Saver(list(embeddings.values()))

            if not isinstance(self.embeddings_metadata, str):
                embeddings_metadata = self.embeddings_metadata
            else:
                embeddings_metadata = {layer_name: self.embeddings_metadata
                                       for layer_name in embeddings.keys()}

            config = projector.ProjectorConfig()
            self.embeddings_ckpt_path = os.path.join(self.log_dir,
                                                     'keras_embedding.ckpt')

            for layer_name, tensor in embeddings.items():
                embedding = config.embeddings.add()
                embedding.tensor_name = tensor.name

                if layer_name in embeddings_metadata:
                    embedding.metadata_path = embeddings_metadata[layer_name]

            projector.visualize_embeddings(self.writer, config)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        if not self.validation_data and self.histogram_freq:
            raise ValueError('If printing histograms, validation_data must be '
                             'provided, and cannot be a generator.')
        if self.validation_data and self.histogram_freq:
            if epoch % self.histogram_freq == 0:

                val_data = self.validation_data
                tensors = (self.model.inputs +
                           self.model.targets +
                           self.model.sample_weights)

                if self.model.uses_learning_phase:
                    tensors += [K.learning_phase()]

                assert len(val_data) == len(tensors)
                val_size = val_data[0].shape[0]
                i = 0
                while i < val_size:
                    step = min(self.batch_size, val_size - i)
                    if self.model.uses_learning_phase:
                        # do not slice the learning phase
                        batch_val = [x[i:i + step] for x in val_data[:-1]]
                        batch_val.append(val_data[-1])
                    else:
                        batch_val = [x[i:i + step] for x in val_data]
                    assert len(batch_val) == len(tensors)
                    feed_dict = dict(zip(tensors, batch_val))
                    result = self.sess.run([self.merged], feed_dict=feed_dict)
                    summary_str = result[0]
                    self.writer.add_summary(summary_str, epoch)
                    i += self.batch_size

        if self.embeddings_freq and self.embeddings_ckpt_path:
            if epoch % self.embeddings_freq == 0:
                self.saver.save(self.sess,
                                self.embeddings_ckpt_path,
                                epoch)

        for name, value in logs.items():
            if name in ['batch', 'size']:
                continue
            summary = tf.Summary()
            summary_value = summary.value.add()
            if isinstance(value, np.ndarray):
                summary_value.simple_value = value.item()
            else:
                # Plain Python floats have no .item(); assign directly.
                summary_value.simple_value = value
            summary_value.tag = name
            self.writer.add_summary(summary, epoch)
        self.writer.flush()

    def on_train_end(self, _):
        self.writer.close()

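# Usage sketch (editor's addition): write scalar logs for TensorBoard and
# inspect them with `tensorboard --logdir=./logs`. The logged values below
# are illustrative.
#
#     tb = TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=False)
#     # tb.set_model(model); then per epoch:
#     # tb.on_epoch_end(epoch, {'loss': 0.42, 'acc': 0.91})
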
class CSVLogger(Callback):
    """Callback that streams epoch results to a csv file.

    Supports all values that can be represented as a string,
    including 1D iterables such as np.ndarray.

    # Example

    ```python
    csv_logger = CSVLogger('training.log')
    model.fit(X_train, Y_train, callbacks=[csv_logger])
    ```

    # Arguments
        filename: filename of the csv file, e.g. 'run/log.csv'.
        separator: string used to separate elements in the csv file.
        append: True: append if file exists (useful for continuing
            training). False: overwrite existing file.
        output_on_train_end: an additional stream to write the full log
            to when training ends. An example is
            CSVLogger(filename='./mylog.csv', output_on_train_end=os.sys.stdout)
    """

    def __init__(self, filename, separator=',', append=False, output_on_train_end=None):
        self.sep = separator
        self.filename = filename
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        self.file_flags = 'b' if six.PY2 and os.name == 'nt' else ''
        self.output_on_train_end = output_on_train_end
        super(CSVLogger, self).__init__()

    def on_train_begin(self, logs=None):
        if self.append:
            if os.path.exists(self.filename):
                with open(self.filename, 'r' + self.file_flags) as f:
                    self.append_header = not bool(len(f.readline()))
            self.csv_file = open(self.filename, 'a' + self.file_flags)
        else:
            self.csv_file = open(self.filename, 'w' + self.file_flags)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        def handle_value(k):
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, six.string_types):
                return k
            elif isinstance(k, Iterable) and not is_zero_dim_ndarray:
                return '"[%s]"' % (', '.join(map(str, k)))
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())

        if self.model is not None and getattr(self.model, 'stop_training', False):
            # We set NA so that csv parsers do not fail for this last epoch.
            logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])

        if not self.writer:
            class CustomDialect(csv.excel):
                delimiter = self.sep

            self.writer = csv.DictWriter(self.csv_file,
                                         fieldnames=['epoch'] + self.keys,
                                         dialect=CustomDialect)
            if self.append_header:
                self.writer.writeheader()

        row_dict = OrderedDict({'epoch': epoch})
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        self.csv_file.close()
        # Only echo the log when an extra output stream was requested.
        if self.output_on_train_end and os.path.exists(self.filename):
            with open(self.filename, 'r' + self.file_flags) as f:
                print(f.read(), file=self.output_on_train_end)
        self.writer = None

class LambdaCallback(Callback):
    r"""Callback for creating simple, custom callbacks on-the-fly.

    This callback is constructed with anonymous functions that will be called
    at the appropriate time. Note that the callbacks expect positional
    arguments, as:

    - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
      `epoch`, `logs`
    - `on_batch_begin` and `on_batch_end` expect two positional arguments:
      `batch`, `logs`
    - `on_train_begin` and `on_train_end` expect one positional argument:
      `logs`

    # Arguments
        on_epoch_begin: called at the beginning of every epoch.
        on_epoch_end: called at the end of every epoch.
        on_batch_begin: called at the beginning of every batch.
        on_batch_end: called at the end of every batch.
        on_train_begin: called at the beginning of model training.
        on_train_end: called at the end of model training.

    # Example

    ```python
    # Print the batch number at the beginning of every batch.
    batch_print_callback = LambdaCallback(
        on_batch_begin=lambda batch, logs: print(batch))

    # Stream the epoch loss to a file in JSON format. The file content
    # is not well-formed JSON but rather has a JSON object per line.
    import json
    json_log = open('loss_log.json', mode='wt', buffering=1)
    json_logging_callback = LambdaCallback(
        on_epoch_end=lambda epoch, logs: json_log.write(
            json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
        on_train_end=lambda logs: json_log.close()
    )

    # Terminate some processes after having finished model training.
    processes = ...
    cleanup_callback = LambdaCallback(
        on_train_end=lambda logs: [
            p.terminate() for p in processes if p.is_alive()])

    model.fit(...,
              callbacks=[batch_print_callback,
                         json_logging_callback,
                         cleanup_callback])
    ```
    """

    def __init__(self,
                 on_epoch_begin=None,
                 on_epoch_end=None,
                 on_batch_begin=None,
                 on_batch_end=None,
                 on_train_begin=None,
                 on_train_end=None,
                 **kwargs):
        super(LambdaCallback, self).__init__()
        self.__dict__.update(kwargs)
        if on_epoch_begin is not None:
            self.on_epoch_begin = on_epoch_begin
        else:
            self.on_epoch_begin = lambda epoch, logs: None
        if on_epoch_end is not None:
            self.on_epoch_end = on_epoch_end
        else:
            self.on_epoch_end = lambda epoch, logs: None
        if on_batch_begin is not None:
            self.on_batch_begin = on_batch_begin
        else:
            self.on_batch_begin = lambda batch, logs: None
        if on_batch_end is not None:
            self.on_batch_end = on_batch_end
        else:
            self.on_batch_end = lambda batch, logs: None
        if on_train_begin is not None:
            self.on_train_begin = on_train_begin
        else:
            self.on_train_begin = lambda logs: None
        if on_train_end is not None:
            self.on_train_end = on_train_end
        else:
            self.on_train_end = lambda logs: None

class TQDMCallback(Callback):
    def __init__(self, outer_description="Training",
                 inner_description_initial="Epoch: {epoch}",
                 inner_description_update="Epoch: {epoch} - {metrics}",
                 metric_format="{name}: {value:0.3f}",
                 separator=", ",
                 leave_inner=True,
                 leave_outer=True,
                 show_inner=True,
                 show_outer=True,
                 output_file=stderr,
                 initial=0):
        """
        Construct a callback that will create and update progress bars.

        :param outer_description: string for outer progress bar
        :param inner_description_initial: initial format for epoch ("Epoch: {epoch}")
        :param inner_description_update: format after metrics collected ("Epoch: {epoch} - {metrics}")
        :param metric_format: format for each metric name/value pair ("{name}: {value:0.3f}")
        :param separator: separator between metrics (", ")
        :param leave_inner: True to leave inner bars
        :param leave_outer: True to leave outer bars
        :param show_inner: False to hide inner bars
        :param show_outer: False to hide outer bar
        :param output_file: output file (default sys.stderr)
        :param initial: Initial counter state
        """
        super(TQDMCallback, self).__init__()
        self.outer_description = outer_description
        self.inner_description_initial = inner_description_initial
        self.inner_description_update = inner_description_update
        self.metric_format = metric_format
        self.separator = separator
        self.leave_inner = leave_inner
        self.leave_outer = leave_outer
        self.show_inner = show_inner
        self.show_outer = show_outer
        self.output_file = output_file
        self.tqdm_outer = None
        self.tqdm_inner = None
        self.epoch = None
        self.running_logs = None
        self.inner_count = None
        self.initial = initial

    def tqdm(self, desc, total, leave, initial=0):
        """
        Extension point. Override to provide custom options to tqdm initializer.

        :param desc: Description string
        :param total: Total number of updates
        :param leave: Leave progress bar when done
        :param initial: Initial counter state
        :return: new progress bar
        """
        return tqdm(desc=desc, total=total, leave=leave, file=self.output_file, initial=initial)

    def build_tqdm_outer(self, desc, total):
        """
        Extension point. Override to provide custom options to outer progress bars (Epoch loop).

        :param desc: Description
        :param total: Number of epochs
        :return: new progress bar
        """
        return self.tqdm(desc=desc, total=total, leave=self.leave_outer, initial=self.initial)

    def build_tqdm_inner(self, desc, total):
        """
        Extension point. Override to provide custom options to inner progress bars (Batch loop).

        :param desc: Description
        :param total: Number of batches
        :return: new progress bar
        """
        return self.tqdm(desc=desc, total=total, leave=self.leave_inner)

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch
        desc = self.inner_description_initial.format(epoch=self.epoch)
        self.mode = 0  # samples
        if 'samples' in self.params:
            self.inner_total = self.params['samples']
        elif 'nb_sample' in self.params:
            self.inner_total = self.params['nb_sample']
        else:
            self.mode = 1  # steps
            self.inner_total = self.params['steps']
        if self.show_inner:
            self.tqdm_inner = self.build_tqdm_inner(desc=desc, total=self.inner_total)
        self.inner_count = 0
        self.running_logs = {}

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        metrics = self.format_metrics(logs)
        desc = self.inner_description_update.format(epoch=epoch, metrics=metrics)
        if self.show_inner:
            self.tqdm_inner.desc = desc
            # Set miniters and mininterval to 0 so the last update displays.
            self.tqdm_inner.miniters = 0
            self.tqdm_inner.mininterval = 0
            self.tqdm_inner.update(self.inner_total - self.tqdm_inner.n)
            self.tqdm_inner.close()
        if self.show_outer:
            self.tqdm_outer.update(1)

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        if self.mode == 0:
            update = logs['size']
        else:
            update = 1
        self.inner_count += update
        if self.inner_count < self.inner_total:
            self.append_logs(logs)
            metrics = self.format_metrics(self.running_logs)
            desc = self.inner_description_update.format(epoch=self.epoch, metrics=metrics)
            if self.show_inner:
                self.tqdm_inner.desc = desc
                self.tqdm_inner.update(update)

    def on_train_begin(self, logs=None):
        if self.show_outer:
            epochs = (self.params['epochs'] if 'epochs' in self.params
                      else self.params['nb_epoch'])
            self.tqdm_outer = self.build_tqdm_outer(desc=self.outer_description,
                                                    total=epochs)

    def on_train_end(self, logs=None):
        if self.show_outer:
            self.tqdm_outer.close()

    def append_logs(self, logs):
        metrics = self.params['metrics']
        for metric, value in six.iteritems(logs):
            if metric in metrics:
                if metric in self.running_logs:
                    self.running_logs[metric].append(value[()])
                else:
                    self.running_logs[metric] = [value[()]]

    def format_metrics(self, logs):
        metrics = self.params['metrics']
        strings = [self.metric_format.format(name=metric,
                                             value=np.mean(logs[metric], axis=None))
                   for metric in metrics if metric in logs]
        return self.separator.join(strings)
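
# Usage sketch (editor's addition): tqdm progress bars over epochs and
# batches; `set_params` supplies the totals the bars need. The parameter
# values below are illustrative.
#
#     progress = TQDMCallback(metric_format="{name}: {value:0.4f}")
#     progress.set_params({'epochs': 10, 'samples': 60000, 'metrics': ['loss']})
#     # Then drive progress.on_train_begin() / on_epoch_begin(...) / etc.,
#     # typically via a CallbackList alongside the other callbacks above.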