NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
4.62 kB
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluates text classification model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import time
# Dependency imports
import tensorflow as tf
import graphs
flags = tf.app.flags
FLAGS = flags.FLAGS

# Location of the eval job, the data split, and the input/output directories.
flags.DEFINE_string(
    'master', '',
    'BNS name prefix of the Tensorflow eval master, or "local".')
flags.DEFINE_string(
    'eval_dir', '/tmp/text_eval',
    'Directory where to write event logs.')
flags.DEFINE_string(
    'eval_data', 'test',
    'Specify which dataset is used. ("train", "valid", "test") ')
flags.DEFINE_string(
    'checkpoint_dir', '/tmp/text_train',
    'Directory where to read model checkpoints.')

# How much data each eval pass covers and how often evaluation repeats.
flags.DEFINE_integer(
    'eval_interval_secs', 60, 'How often to run the eval.')
flags.DEFINE_integer(
    'num_examples', 32, 'Number of examples to run.')
flags.DEFINE_bool(
    'run_once', False, 'Whether to run eval only once.')
def restore_from_checkpoint(sess, saver):
  """Loads the checkpoint recorded in FLAGS.checkpoint_dir into `sess`.

  Args:
    sess: Session whose variables are restored.
    saver: Saver that performs the restore.

  Returns:
    True if a checkpoint path was found and restored. False otherwise, in
    which case an informational message is logged.
  """
  state = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
  checkpoint_path = state.model_checkpoint_path if state else None
  if checkpoint_path:
    saver.restore(sess, checkpoint_path)
    return True
  tf.logging.info('No checkpoint found at %s', FLAGS.checkpoint_dir)
  return False
def run_eval(eval_ops, summary_writer, saver):
  """Runs evaluation over FLAGS.num_examples examples.

  Restores the latest checkpoint and runs every metric update op once per
  batch. Intermediate metric values are logged every 50 batches. At the end,
  the metric values are logged and written to `summary_writer`.

  Args:
    eval_ops: dict<metric name, tuple(value, update_op)>.
    summary_writer: Summary writer that receives the final metric values.
    saver: Saver used to restore the checkpoint.

  Returns:
    dict<metric name, value> with the metric values read after all batches
    have been run, or None if no checkpoint was found.
  """
  sv = tf.train.Supervisor(
      logdir=FLAGS.eval_dir, saver=None, summary_op=None, summary_writer=None)
  with sv.managed_session(
      master=FLAGS.master, start_standard_services=False) as sess:
    if not restore_from_checkpoint(sess, saver):
      return None
    sv.start_queue_runners(sess)

    metric_names, ops = zip(*eval_ops.items())
    value_ops, update_ops = zip(*ops)
    value_ops_dict = dict(zip(metric_names, value_ops))

    # Run update ops.
    # NOTE(review): FLAGS.batch_size is not defined among this file's flags;
    # presumably it is registered by the `graphs` module. Verify.
    num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
    tf.logging.info('Running %d batches for evaluation.', num_batches)
    for i in range(num_batches):
      batch_num = i + 1
      if batch_num % 10 == 0:
        tf.logging.info('Running batch %d/%d...', batch_num, num_batches)
      sess.run(update_ops)
      # Log after the update so the values cover `batch_num` batches, which
      # matches the progress counter. Skip the last batch; the final log
      # below reports it.
      if batch_num % 50 == 0 and batch_num < num_batches:
        _log_values(sess, value_ops_dict)

    _log_values(sess, value_ops_dict, summary_writer=summary_writer)
    # Return the final values as documented. Session.run with a dict of
    # fetches returns a dict with the same keys.
    return sess.run(value_ops_dict)
def _log_values(sess, value_ops, summary_writer=None):
  """Evaluates and logs the eval metrics in `value_ops`, optionally as summaries.

  Args:
    sess: Session used to evaluate the metric value ops.
    value_ops: dict<metric name, value op>.
    summary_writer: Optional summary writer. When given, the values are also
      written as a summary at the current global step.

  Returns:
    dict<metric name, evaluated value>.
  """
  metric_names, ops = zip(*value_ops.items())
  values = sess.run(ops)
  tf.logging.info('Eval metric values:')
  summary = tf.summary.Summary()
  for name, val in zip(metric_names, values):
    summary.value.add(tag=name, simple_value=val)
    tf.logging.info('%s = %.3f', name, val)
  if summary_writer is not None:
    global_step = tf.train.get_global_step()
    # get_global_step() returns None when the graph has no global step, and
    # sess.run(None) would raise. In that case, write the summary without a step.
    if global_step is not None:
      global_step_val = sess.run(global_step)
    else:
      global_step_val = None
    tf.logging.info('Finished eval for step %s', global_step_val)
    summary_writer.add_summary(summary, global_step_val)
  return dict(zip(metric_names, values))
def main(_):
  """Builds the eval graph and evaluates the latest checkpoint.

  Evaluates once when --run_once is set. Otherwise, evaluation repeats every
  --eval_interval_secs seconds until the process is stopped.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  tf.logging.info('Building eval graph...')
  eval_ops, moving_averaged_variables = (
      graphs.get_model().eval_graph(FLAGS.eval_data))
  # The Saver restores exactly the variables returned by eval_graph.
  # NOTE(review): presumably this maps moving-average shadow names to model
  # variables. Confirm in graphs.eval_graph.
  saver = tf.train.Saver(moving_averaged_variables)
  summary_writer = tf.summary.FileWriter(
      FLAGS.eval_dir, graph=tf.get_default_graph())
  try:
    while True:
      run_eval(eval_ops, summary_writer, saver)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)
  finally:
    # Flush pending events to disk and release the event file. This runs
    # even when the loop exits early or run_eval raises.
    summary_writer.close()
if __name__ == '__main__':
  # Parse flags and dispatch to main(); passed explicitly rather than relying
  # on tf.app.run's lookup of __main__.main.
  tf.app.run(main=main)