# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Tracks and saves training progress (models and other data such as the current
location in the lm1b corpus) for later reloading.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from base import utils
from corpus_processing import unlabeled_data


def _average_dev_score(results):
  """Returns the mean of the f1/las/accuracy metrics in one history entry.

  Args:
    results: iterable of (metric_name, value) pairs.

  Returns:
    The mean value of all metrics whose names contain "f1", "las", or
    "accuracy". Returns None when the entry contains any metric whose name
    includes "train", or when it has no such scored metric. Returning None
    in the second case avoids dividing by zero.
  """
  pairs = list(results)  # materialize once; the pairs are scanned twice
  if any("train" in metric for metric, _ in pairs):
    return None
  scores = [value for metric, value in pairs
            if "f1" in metric or "las" in metric or "accuracy" in metric]
  if not scores:
    return None
  return sum(scores) / len(scores)


class TrainingProgress(object):
  """Checkpoints the model and training state, and restores them on startup.

  The saved training state is a pickled tuple of
  (history, current_file, current_line). `history` is a list of result
  entries, each a sequence of (metric, value) pairs. current_file and
  current_line record the unlabeled-data reader's position so it can be
  recreated at that position.
  """

  def __init__(self, config, sess, checkpoint_saver, best_model_saver,
               restore_if_possible=True):
    """Creates the checkpoint directory and restores prior progress if present.

    Args:
      config: configuration object. This class reads its `checkpoints_dir`,
        `progress`, `checkpoint`, and `best_model_checkpoint` attributes.
      sess: session passed to `checkpoint_saver.restore` when resuming.
      checkpoint_saver: saver used for the regular checkpoints (both restore
        and save).
      best_model_saver: saver used only by `save_if_best_dev_model`.
      restore_if_possible: if True and the file `config.progress` exists,
        resume from it. Otherwise start with an empty history and a fresh
        data reader.
    """
    self.config = config
    self.checkpoint_saver = checkpoint_saver
    self.best_model_saver = best_model_saver

    tf.gfile.MakeDirs(config.checkpoints_dir)
    if restore_if_possible and tf.gfile.Exists(config.progress):
      history, current_file, current_line = utils.load_cpickle(
          config.progress, memoized=False)
      self.history = history
      # Recreate the reader at the file/line position saved by `write`.
      self.unlabeled_data_reader = unlabeled_data.UnlabeledDataReader(
          config, current_file, current_line)
      # NOTE(review): assumes the saved history is non-empty and that its last
      # entry contains a "step" metric; otherwise this line raises.
      utils.log("Continuing from global step", dict(self.history[-1])["step"],
                "(lm1b file {:}, line {:})".format(current_file, current_line))
      self.checkpoint_saver.restore(sess, tf.train.latest_checkpoint(
          self.config.checkpoints_dir))
    else:
      utils.log("No previous checkpoint found - starting from scratch")
      self.history = []
      self.unlabeled_data_reader = (
          unlabeled_data.UnlabeledDataReader(config))

  def write(self, sess, global_step):
    """Saves a model checkpoint, then pickles the training state.

    Args:
      sess: session passed to `checkpoint_saver.save`.
      global_step: passed to `checkpoint_saver.save` as `global_step`.
    """
    self.checkpoint_saver.save(sess, self.config.checkpoint,
                               global_step=global_step)
    utils.write_cpickle(
        (self.history, self.unlabeled_data_reader.current_file,
         self.unlabeled_data_reader.current_line),
        self.config.progress)

  def save_if_best_dev_model(self, sess, global_step):
    """Saves the model if the latest history entry has the best dev score.

    History entries that contain any "train" metric are ignored. So are
    entries with no f1/las/accuracy metric; before this change, such an
    entry would have caused a division by zero. Every other entry is scored
    by the mean of those metrics. The model is saved with `best_model_saver`
    only when the most recent entry scores at least as high as every earlier
    scored entry. Ties count as a new best.

    Args:
      sess: session passed to `best_model_saver.save`.
      global_step: passed to `best_model_saver.save` as `global_step`.
    """
    best_avg_score = 0
    last_index = len(self.history) - 1
    for i, results in enumerate(self.history):
      avg_score = _average_dev_score(results)
      if avg_score is None:
        continue
      if avg_score >= best_avg_score:
        best_avg_score = avg_score
        if i == last_index:
          utils.log("New best model! Saving...")
          self.best_model_saver.save(sess, self.config.best_model_checkpoint,
                                     global_step=global_step)