# Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Evaluates text classification model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import math import time # Dependency imports import tensorflow as tf import graphs flags = tf.app.flags FLAGS = flags.FLAGS flags.DEFINE_string('master', '', 'BNS name prefix of the Tensorflow eval master, ' 'or "local".') flags.DEFINE_string('eval_dir', '/tmp/text_eval', 'Directory where to write event logs.') flags.DEFINE_string('eval_data', 'test', 'Specify which dataset is used. ' '("train", "valid", "test") ') flags.DEFINE_string('checkpoint_dir', '/tmp/text_train', 'Directory where to read model checkpoints.') flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run the eval.') flags.DEFINE_integer('num_examples', 32, 'Number of examples to run.') flags.DEFINE_bool('run_once', False, 'Whether to run eval only once.') def restore_from_checkpoint(sess, saver): """Restore model from checkpoint. Args: sess: Session. saver: Saver for restoring the checkpoint. Returns: bool: Whether the checkpoint was found and restored """ ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) if not ckpt or not ckpt.model_checkpoint_path: tf.logging.info('No checkpoint found at %s', FLAGS.checkpoint_dir) return False saver.restore(sess, ckpt.model_checkpoint_path) return True def run_eval(eval_ops, summary_writer, saver): """Runs evaluation over FLAGS.num_examples examples. Args: eval_ops: dict summary_writer: Summary writer. saver: Saver. Returns: dict, with value being the average over all examples. """ sv = tf.train.Supervisor( logdir=FLAGS.eval_dir, saver=None, summary_op=None, summary_writer=None) with sv.managed_session( master=FLAGS.master, start_standard_services=False) as sess: if not restore_from_checkpoint(sess, saver): return sv.start_queue_runners(sess) metric_names, ops = zip(*eval_ops.items()) value_ops, update_ops = zip(*ops) value_ops_dict = dict(zip(metric_names, value_ops)) # Run update ops num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size)) tf.logging.info('Running %d batches for evaluation.', num_batches) for i in range(num_batches): if (i + 1) % 10 == 0: tf.logging.info('Running batch %d/%d...', i + 1, num_batches) if (i + 1) % 50 == 0: _log_values(sess, value_ops_dict) sess.run(update_ops) _log_values(sess, value_ops_dict, summary_writer=summary_writer) def _log_values(sess, value_ops, summary_writer=None): """Evaluate, log, and write summaries of the eval metrics in value_ops.""" metric_names, value_ops = zip(*value_ops.items()) values = sess.run(value_ops) tf.logging.info('Eval metric values:') summary = tf.summary.Summary() for name, val in zip(metric_names, values): summary.value.add(tag=name, simple_value=val) tf.logging.info('%s = %.3f', name, val) if summary_writer is not None: global_step_val = sess.run(tf.train.get_global_step()) tf.logging.info('Finished eval for step ' + str(global_step_val)) summary_writer.add_summary(summary, global_step_val) def main(_): tf.logging.set_verbosity(tf.logging.INFO) tf.gfile.MakeDirs(FLAGS.eval_dir) tf.logging.info('Building eval graph...') output = graphs.get_model().eval_graph(FLAGS.eval_data) eval_ops, moving_averaged_variables = output saver = tf.train.Saver(moving_averaged_variables) summary_writer = tf.summary.FileWriter( FLAGS.eval_dir, graph=tf.get_default_graph()) while True: run_eval(eval_ops, summary_writer, saver) if FLAGS.run_once: break time.sleep(FLAGS.eval_interval_secs) if __name__ == '__main__': tf.app.run()