# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Training executable for detection models. | |
This executable is used to train DetectionModels. There are two ways of | |
configuring the training job: | |
1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file | |
can be specified by --pipeline_config_path. | |
Example usage: | |
./train \ | |
--logtostderr \ | |
--train_dir=path/to/train_dir \ | |
--pipeline_config_path=pipeline_config.pbtxt | |
2) Three configuration files can be provided: a model_pb2.DetectionModel | |
configuration file to define what type of DetectionModel is being trained, an | |
input_reader_pb2.InputReader file to specify what training data will be used and | |
a train_pb2.TrainConfig file to configure training parameters. | |
Example usage: | |
./train \ | |
--logtostderr \ | |
--train_dir=path/to/train_dir \ | |
--model_config_path=model_config.pbtxt \ | |
--train_config_path=train_config.pbtxt \ | |
--input_config_path=train_input_config.pbtxt | |
""" | |

import functools
import json
import os
import tensorflow.compat.v1 as tf
from tensorflow.python.util.deprecation import deprecated

from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.legacy import trainer
from object_detection.utils import config_util

tf.logging.set_verbosity(tf.logging.INFO)

flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU. Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1,
                     'Number of worker+trainer replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')
flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored.')
flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS


@deprecated(None, 'Use object_detection/model_main.py.')
def main(_):
  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0:
    tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      # Only task 0 writes a copy of the config next to the checkpoints.
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)
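
  # model_fn is a no-argument callable: each call builds a DetectionModel from
  # model_config in training mode. get_next below is its counterpart on the
  # input side: each call builds a fresh initializable iterator over the
  # training dataset and returns its next-element tensors.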
  def get_next(config):
    return dataset_builder.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)

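  # Distributed settings come from the TF_CONFIG environment variable. An
  # illustrative value (hosts and ports are placeholders):
  #   {"cluster": {"master": ["host0:2222"],
  #                "worker": ["host1:2222", "host2:2222"],
  #                "ps": ["host3:2222"]},
  #    "task": {"type": "worker", "index": 0}}
  # If TF_CONFIG is unset, the job falls back to a single local master task.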
  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
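  # type() builds a throwaway class whose attributes mirror task_data, so the
  # fields read naturally as task_info.type and task_info.index below.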
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # The total number of worker replicas includes the "worker"s and the
    # "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')
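
  # When a parameter server is present, every task starts a tf.train.Server.
  # ps tasks block on server.join() and never reach trainer.train(); worker
  # and master tasks fall through and train against server.target.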
  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training.
    server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target
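
  # An optional graph rewriter (in practice, used for quantization-aware
  # training in the object_detection API) can transform the training graph.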
  graph_rewriter_fn = None
  if 'graph_rewriter_config' in configs:
    graph_rewriter_fn = graph_rewriter_builder.build(
        configs['graph_rewriter_config'], is_training=True)

  trainer.train(
      create_input_dict_fn,
      model_fn,
      train_config,
      master,
      task,
      FLAGS.num_clones,
      worker_replicas,
      FLAGS.clone_on_cpu,
      ps_tasks,
      worker_job_name,
      is_chief,
      FLAGS.train_dir,
      graph_hook_fn=graph_rewriter_fn)


if __name__ == '__main__':
  tf.app.run()