NCTC / models /research /seq2species /run_training.py
NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
11.3 kB
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines training scheme for neural networks for Seq2Species prediction.
Defines and runs the loop for training an (optionally) depthwise separable
convolutional model for predicting taxonomic labels from short reads of DNA.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import flags
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
import build_model
import configuration
import input as seq2species_input
from protos import seq2label_pb2
import seq2label_utils
# Define non-tunable parameters.
# --num_filters and --hparams are both handed to configuration.parse_hparams
# in main(); --batch_size is forwarded to run_training().
flags.DEFINE_integer('num_filters', 1, 'Number of filters for conv model')
flags.DEFINE_string('hparams', '',
'Comma-separated list of name=value hyperparameter '
"pairs ('hp1=value1,hp2=value2'). Unspecified "
'hyperparameters will be filled with defaults.')
flags.DEFINE_integer('batch_size', 512, 'Size of batches during training.')
# --min_train_steps and --max_task_loss together form the early-stopping
# criterion checked after every step in run_training().
flags.DEFINE_integer('min_train_steps', 1000,
'Minimum number of training steps to run.')
flags.DEFINE_float('max_task_loss', 10.0,
"Terminate trial if task loss doesn't fall below this "
'within --min_train_steps.')
flags.DEFINE_integer('n_print_progress_every', 1000,
'Print training progress every '
'--n_print_progress_every global steps.')
# --targets and --noise_rate are passed to the input dataset and are also
# recorded in the model info written by the chief task.
flags.DEFINE_list('targets', ['species'],
'Names of taxonomic ranks to use as training targets.')
flags.DEFINE_float(
'noise_rate', 0.0, 'Rate [0.0, 1.0] at which to inject '
'base-flipping noise into input read sequences.')
# Define paths to logs and data.
flags.DEFINE_list(
'train_files', [], 'Full paths to the TFRecords containing the '
'training examples.')
flags.DEFINE_string(
'metadata_path', '', 'Full path of the text proto containing configuration '
'information about the set of training examples.')
flags.DEFINE_string('logdir', '/tmp/seq2species',
'Directory to which to write logs.')
# Define supervisor/checkpointing options.
# run_training() treats task 0 as the chief and forwards --master,
# --save_model_secs, --recovery_wait_secs and --save_summaries_secs to
# tf.train.MonitoredTrainingSession; --ps_tasks feeds
# tf.train.replica_device_setter in main().
flags.DEFINE_integer('task', 0, 'Task ID of the replica running the training.')
flags.DEFINE_string('master', '', 'Name of the TF master to use.')
flags.DEFINE_integer(
'save_model_secs', 900, 'Rate at which to save model parameters. '
'Set to 0 to disable checkpointing.')
flags.DEFINE_integer('recovery_wait_secs', 30,
'Wait to recover model from checkpoint '
'before timing out.')
flags.DEFINE_integer('save_summaries_secs', 900,
'Rate at which to save Tensorboard summaries.')
flags.DEFINE_integer('ps_tasks', 0,
'Number of tasks in the ps job; 0 if no ps is used.')
# Handle through which the values of the flags defined above are read.
FLAGS = flags.FLAGS
# Seed passed as `random_seed` to the training InputDataset in main().
RANDOM_SEED = 42
def wait_until(time_sec):
    """Blocks execution until a given wall-clock time.

    Sleeps instead of spinning, so the calling thread does not occupy a CPU
    core for the whole wait.

    Args:
      time_sec: time, in seconds on the same clock as `time.time()`, until
        which to wait. A time in the past returns immediately.
    """
    # Upper bound on a single sleep so the clock is re-read regularly (e.g.
    # after a wall-clock adjustment) and very large targets cannot overflow
    # time.sleep().
    max_nap_sec = 1.0
    while True:
        remaining = time_sec - time.time()
        # `not remaining > 0` also exits for NaN, as `time.time() < NaN` did
        # in the original spin loop.
        if not remaining > 0:
            return
        time.sleep(min(remaining, max_nap_sec))
def update_measures(measures, new_measures, loss_val, max_loss=None):
    """Updates tracking of experimental measures and infeasibility.

    Args:
      measures: dict; mapping from measure name to measure value. Modified in
        place.
      new_measures: dict; mapping from measure name to new measure values.
      loss_val: float; value of loss metric by which to determine feasibility.
      max_loss: float; maximum value at which to consider the loss feasible.
        A falsy value (None or 0) disables the threshold by substituting the
        largest finite float32.

    Side Effects:
      Stores a plain Python bool under measures['is_infeasible'] (True when
      the loss is at or above the threshold, or is NaN/infinite), then merges
      new_measures into measures.
    """
    if not max_loss:
        max_loss = np.finfo('f').max
    # Comparisons on numpy scalars yield numpy.bool_; coerce so that callers
    # always see a built-in bool regardless of the loss's scalar type.
    measures['is_infeasible'] = bool(
        loss_val >= max_loss or not np.isfinite(loss_val))
    measures.update(new_measures)
def _latest_checkpoint_path(saver, logdir):
    """Returns a checkpoint path to report, even if nothing was saved yet.

    Args:
      saver: the tf.train.Saver used by the training session.
      logdir: string; directory in which checkpoints are written.

    Returns:
      The newest entry of `saver.last_checkpoints` when there is one;
      otherwise the newest checkpoint found in `logdir`; otherwise a
      'model.ckpt' prefix inside `logdir`, so callers can still derive the
      log directory from the returned path.
    """
    saved_here = saver.last_checkpoints
    if saved_here:
        return saved_here[-1]
    # The list is empty when this process has not saved anything, e.g. when
    # checkpointing is disabled or on a non-chief task.
    on_disk = tf.train.latest_checkpoint(logdir)
    if on_disk:
        return on_disk
    # NOTE(review): 'model.ckpt' is assumed to match the basename the
    # monitored session uses for its checkpoints -- confirm.
    return os.path.join(logdir, 'model.ckpt')


def run_training(model, hparams, training_dataset, logdir, batch_size):
    """Trains the given model on random mini-batches of reads.

    Args:
      model: ConvolutionalNet instance containing the model graph and
        operations.
      hparams: tf.contrib.training.Hparams object containing the model's
        hyperparameters; see configuration.py for hyperparameter definitions.
      training_dataset: an `InputDataset` that can feed labelled examples.
      logdir: string; full path of directory to which to save checkpoints.
      batch_size: integer batch size.

    Yields:
      Tuple comprising a dictionary of experimental measures and the save path
      for train checkpoints and summaries. A tuple is yielded every
      --n_print_progress_every global steps and once more when training ends.
      The same dictionary object is updated and re-yielded each time. If no
      training step runs at all, the final dictionary only holds
      'is_infeasible' and the initial 'global_step'.
    """
    input_params = dict(batch_size=batch_size)
    features, labels = training_dataset.input_fn(input_params)
    model.build_graph(features, labels, tf.estimator.ModeKeys.TRAIN, batch_size)

    is_chief = FLAGS.task == 0
    scaffold = tf.train.Scaffold(
        saver=tf.train.Saver(
            tf.global_variables(),
            max_to_keep=5,
            keep_checkpoint_every_n_hours=1.0),
        init_op=tf.global_variables_initializer(),
        summary_op=model.summary_op)
    with tf.train.MonitoredTrainingSession(
            master=FLAGS.master,
            checkpoint_dir=logdir,
            is_chief=is_chief,
            scaffold=scaffold,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_checkpoint_secs=FLAGS.save_model_secs,
            max_wait_secs=FLAGS.recovery_wait_secs) as sess:
        global_step = sess.run(model.global_step)
        print('Initialized model at global step ', global_step)
        init_time = time.time()
        # Seed 'global_step' so the final yield is usable even when the loop
        # below never runs (e.g. the restored step already reached
        # train_steps).
        measures = {'is_infeasible': False, 'global_step': global_step}

        if is_chief:
            model_info = seq2label_utils.construct_seq2label_model_info(
                hparams, 'conv', FLAGS.targets, FLAGS.metadata_path,
                FLAGS.batch_size, FLAGS.num_filters, FLAGS.noise_rate)
            write_message(model_info, os.path.join(logdir, 'model_info.pbtxt'))

        ops = [
            model.accuracy, model.weighted_accuracy, model.total_loss,
            model.global_step, model.train_op
        ]

        def gather_measures(accuracy, weighted_accuracy, loss, step):
            """Updates the measures dictionary from the given batch results."""
            new_measures = {'train_loss': loss, 'global_step': step}
            for target in FLAGS.targets:
                new_measures.update({
                    ('train_accuracy/%s' % target): accuracy[target],
                    ('train_weighted_accuracy/%s' % target):
                        weighted_accuracy[target]
                })
            update_measures(
                measures, new_measures, loss, max_loss=FLAGS.max_task_loss)

        # Results of the most recent training step; None until one has run.
        last_batch = None
        while not sess.should_stop() and global_step < hparams.train_steps:
            accuracy, weighted_accuracy, loss, global_step, _ = sess.run(ops)
            last_batch = (accuracy, weighted_accuracy, loss, global_step)

            # Periodically track measures according to current mini-batch
            # performance.
            if global_step % FLAGS.n_print_progress_every == 0:
                # Log a message.
                log_message = ('\tstep: %d (%d sec), loss: %f' %
                               (global_step, time.time() - init_time, loss))
                for target in FLAGS.targets:
                    log_message += (', accuracy/%s: %f ' %
                                    (target, accuracy[target]))
                    log_message += (', weighted_accuracy/%s: %f ' %
                                    (target, weighted_accuracy[target]))
                print(log_message)

                # Gather new measures and update the measures dictionary.
                gather_measures(*last_batch)
                yield measures, _latest_checkpoint_path(scaffold.saver, logdir)

            # Check for additional stopping criteria.
            if not np.isfinite(loss) or (loss >= FLAGS.max_task_loss and
                                         global_step > FLAGS.min_train_steps):
                break

        # Always yield once at the end.
        if last_batch is not None:
            gather_measures(*last_batch)
        yield measures, _latest_checkpoint_path(scaffold.saver, logdir)
def write_message(message, filename):
    """Serializes a proto message as text and stores it at the given path.

    Args:
      message: the proto message to save.
      filename: full path of file to which to save the text proto.

    Side Effects:
      Outputs a text proto file to the given filename.
    """
    # Serialize before opening, so a failed conversion never touches the file.
    serialized = text_format.MessageToString(message)
    sink = tf.gfile.GFile(filename, 'w')
    with sink:
        sink.write(serialized)
def write_measures(measures, checkpoint_file, init_time):
    """Writes performance measures to file.

    Args:
      measures: dict; mapping from measure name to measure value. Must contain
        the keys 'global_step' and 'is_infeasible'.
      checkpoint_file: string; full save path for checkpoints and summaries.
      init_time: float; start time (a `time.time()` value) for work on the
        current experiment.

    Side Effects:
      Writes given dictionary of performance measures for the current
      experiment to a 'measures.pbtxt' file in the checkpoint directory.
    """
    # Save experiment measures.
    print('global_step: ', measures['global_step'])
    experiment_measures = seq2label_pb2.Seq2LabelExperimentMeasures(
        checkpoint_path=checkpoint_file,
        steps=measures['global_step'],
        # Coerce in case the flag is a numpy boolean, which some protobuf
        # versions are understood to reject.
        experiment_infeasible=bool(measures['is_infeasible']),
        wall_time=time.time() - init_time)  # Inaccurate for restarts.
    # dict.items() exists on both Python 2 and 3; iteritems() is Python 2 only.
    for name, value in measures.items():
        if name not in ('is_infeasible', 'global_step'):
            experiment_measures.measures.add(name=name, value=value)
    measures_file = os.path.join(
        os.path.dirname(checkpoint_file), 'measures.pbtxt')
    write_message(experiment_measures, measures_file)
    print('Wrote ', measures_file,
          ' containing the following experiment measures:\n',
          experiment_measures)
def main(unused_argv):
    """Builds the dataset and model, runs training, and saves final measures.

    Args:
      unused_argv: ignored positional command-line arguments.
    """
    dataset_info = seq2species_input.load_dataset_info(FLAGS.metadata_path)
    start_time = time.time()

    # Resolve hyperparameters from flags and echo them.
    hparams = configuration.parse_hparams(FLAGS.hparams, FLAGS.num_filters)
    print('Current Hyperparameters:')
    for name, value in hparams.values().items():
        print('\t', name, ': ', value)

    # Build the input pipeline and the model in a fresh graph.
    print('Constructing TensorFlow Graph.')
    tf.reset_default_graph()
    train_data = seq2species_input.InputDataset.from_tfrecord_files(
        FLAGS.train_files,
        'train',
        FLAGS.targets,
        dataset_info,
        noise_rate=FLAGS.noise_rate,
        random_seed=RANDOM_SEED)
    device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
    with tf.device(device_fn):
        model = build_model.ConvolutionalNet(
            hparams, dataset_info, targets=FLAGS.targets)

    # Drain the training generator, remembering only its most recent result.
    latest = (None, None)
    print('Starting model training.')
    progress = run_training(
        model, hparams, train_data, FLAGS.logdir, batch_size=FLAGS.batch_size)
    for step_measures, step_checkpoint in progress:
        latest = (step_measures, step_checkpoint)

    # Persist the final measures next to the checkpoints.
    final_measures, final_checkpoint = latest
    write_measures(final_measures, final_checkpoint, start_time)
# Script entry point: hand main() to tf.app.run when this file is executed
# directly rather than imported.
if __name__ == '__main__':
    tf.app.run(main)