NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
18.8 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer for coordinating single or multi-replica training.
Main point of entry for running models. Specifies most of
the parameters used by different algorithms.
"""
import tensorflow as tf
import numpy as np
import random
import os
import pickle
from six.moves import xrange
import controller
import model
import policy
import baseline
import objective
import full_episode_objective
import trust_region
import optimizers
import replay_buffer
import expert_paths
import gym_wrapper
import env_spec
# Short aliases for the TensorFlow sub-modules used throughout this file.
app = tf.app
flags = tf.flags
logging = tf.logging
gfile = tf.gfile

FLAGS = flags.FLAGS

# Environment, batching and training-length flags.
flags.DEFINE_string('env', 'Copy-v0', 'environment name')
flags.DEFINE_integer('batch_size', 100, 'batch size')
flags.DEFINE_integer('replay_batch_size', None,
                     'replay batch size; defaults to batch_size')
flags.DEFINE_integer('num_samples', 1,
                     'number of samples from each random seed initialization')
flags.DEFINE_integer('max_step', 200, 'max number of steps to train on')
flags.DEFINE_integer('cutoff_agent', 0,
                     'number of steps at which to cut-off agent. '
                     'Defaults to always cutoff')
flags.DEFINE_integer('num_steps', 100000, 'number of training steps')
flags.DEFINE_integer('validation_frequency', 100,
                     'every so many steps, output some stats')

# Target network and objective selection.
flags.DEFINE_float('target_network_lag', 0.95,
                   'This exponential decay on online network yields target '
                   'network')
flags.DEFINE_string('sample_from', 'online',
                    'Sample actions from "online" network or "target" network')
flags.DEFINE_string('objective', 'pcl',
                    'pcl/upcl/a3c/trpo/reinforce/urex')
flags.DEFINE_bool('trust_region_p', False,
                  'use trust region for policy optimization')
flags.DEFINE_string('value_opt', None,
                    'leave as None to optimize it along with policy '
                    '(using critic_weight). Otherwise set to '
                    '"best_fit" (least squares regression), "lbfgs", or "grad"')
flags.DEFINE_float('max_divergence', 0.01,
                   'max divergence (i.e. KL) to allow during '
                   'trust region optimization')

# Optimization hyper-parameters.
flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
flags.DEFINE_float('clip_norm', 5.0, 'clip norm')
flags.DEFINE_float('clip_adv', 0.0, 'Clip advantages at this value. '
                   'Leave as 0 to not clip at all.')
flags.DEFINE_float('critic_weight', 0.1, 'critic weight')
# The trailing space after the first sentence matters: adjacent literals are
# concatenated verbatim, so without it the help reads "regularizer.If using".
flags.DEFINE_float('tau', 0.1, 'entropy regularizer. '
                   'If using decaying tau, this is the final value.')
flags.DEFINE_float('tau_decay', None,
                   'decay tau by this much every 100 steps')
flags.DEFINE_float('tau_start', 0.1,
                   'start tau at this value')
flags.DEFINE_float('eps_lambda', 0.0, 'relative entropy regularizer.')
flags.DEFINE_bool('update_eps_lambda', False,
                  'Update lambda automatically based on last 100 episodes.')
flags.DEFINE_float('gamma', 1.0, 'discount')
flags.DEFINE_integer('rollout', 10, 'rollout')
flags.DEFINE_bool('use_target_values', False,
                  'use target network for value estimates')

# Network architecture and inputs.
flags.DEFINE_bool('fixed_std', True,
                  'fix the std in Gaussian distributions')
flags.DEFINE_bool('input_prev_actions', True,
                  'input previous actions to policy network')
flags.DEFINE_bool('recurrent', True,
                  'use recurrent connections')
flags.DEFINE_bool('input_time_step', False,
                  'input time step into value calculations')

# Batch construction and replay buffer.
flags.DEFINE_bool('use_online_batch', True,
                  'train on batches as they are sampled')
flags.DEFINE_bool('batch_by_steps', False,
                  'ensure each training batch has batch_size * max_step steps')
flags.DEFINE_bool('unify_episodes', False,
                  'Make sure replay buffer holds entire episodes, '
                  'even across distinct sampling steps')
flags.DEFINE_integer('replay_buffer_size', 5000, 'replay buffer size')
flags.DEFINE_float('replay_buffer_alpha', 0.5, 'replay buffer alpha param')
flags.DEFINE_integer('replay_buffer_freq', 0,
                     'replay buffer frequency (only supports -1/0/1)')
flags.DEFINE_string('eviction', 'rand',
                    'how to evict from replay buffer: rand/rank/fifo')
flags.DEFINE_string('prioritize_by', 'rewards',
                    'Prioritize replay buffer by "rewards" or "step"')
flags.DEFINE_integer('num_expert_paths', 0,
                     'number of expert paths to seed replay buffer with')

flags.DEFINE_integer('internal_dim', 256, 'RNN internal dim')
flags.DEFINE_integer('value_hidden_layers', 0,
                     'number of hidden layers in value estimate')
flags.DEFINE_integer('tf_seed', 42, 'random seed for tensorflow')

# Trajectory saving / loading.
flags.DEFINE_string('save_trajectories_dir', None,
                    'directory to save trajectories to, if desired')
flags.DEFINE_string('load_trajectories_file', None,
                    'file to load expert trajectories from')

# supervisor flags
flags.DEFINE_bool('supervisor', False, 'use supervisor training')
flags.DEFINE_integer('task_id', 0, 'task id')
flags.DEFINE_integer('ps_tasks', 0, 'number of ps tasks')
flags.DEFINE_integer('num_replicas', 1, 'number of replicas used')
flags.DEFINE_string('master', 'local', 'name of master')
flags.DEFINE_string('save_dir', '', 'directory to save model to')
flags.DEFINE_string('load_path', '',
                    'path of saved model to load (if none in save_dir)')
class Trainer(object):
  """Coordinates single or multi-replica training.

  The constructor copies the command-line flags onto the instance, builds the
  training and evaluation environments and checks that the chosen flag
  combination is consistent. The `get_*` methods are factories for the
  individual components; several of them are handed (as bound methods) to the
  model and controller constructors. `run` builds the graph and executes the
  training loop.
  """

  def __init__(self):
    """Read FLAGS, create the environments and validate the configuration.

    Raises:
      AssertionError: if the flag values are mutually inconsistent (see the
        individual assertions below).
    """
    self.batch_size = FLAGS.batch_size
    self.replay_batch_size = FLAGS.replay_batch_size
    if self.replay_batch_size is None:
      self.replay_batch_size = self.batch_size
    self.num_samples = FLAGS.num_samples

    self.env_str = FLAGS.env
    # Training and evaluation use separate wrapper instances that are
    # constructed with identical arguments.
    self.env = gym_wrapper.GymWrapper(
        self.env_str,
        distinct=FLAGS.batch_size // self.num_samples,
        count=self.num_samples)
    self.eval_env = gym_wrapper.GymWrapper(
        self.env_str,
        distinct=FLAGS.batch_size // self.num_samples,
        count=self.num_samples)
    self.env_spec = env_spec.EnvSpec(self.env.get_one())

    self.max_step = FLAGS.max_step
    self.cutoff_agent = FLAGS.cutoff_agent
    self.num_steps = FLAGS.num_steps
    self.validation_frequency = FLAGS.validation_frequency

    self.target_network_lag = FLAGS.target_network_lag
    self.sample_from = FLAGS.sample_from
    assert self.sample_from in ['online', 'target']

    self.critic_weight = FLAGS.critic_weight
    self.objective = FLAGS.objective
    self.trust_region_p = FLAGS.trust_region_p
    self.value_opt = FLAGS.value_opt
    # Trust-region policy optimization is only accepted together with the
    # 'pcl' or 'trpo' objective ...
    assert not self.trust_region_p or self.objective in ['pcl', 'trpo']
    # ... and the 'trpo' objective is only accepted with trust region enabled.
    assert self.objective != 'trpo' or self.trust_region_p
    # A dedicated value optimizer is only accepted when critic_weight is 0.
    # The literal string 'None' is treated the same as no value optimizer.
    assert self.value_opt is None or self.value_opt == 'None' or \
        self.critic_weight == 0.0
    self.max_divergence = FLAGS.max_divergence

    self.learning_rate = FLAGS.learning_rate
    self.clip_norm = FLAGS.clip_norm
    self.clip_adv = FLAGS.clip_adv
    self.tau = FLAGS.tau
    self.tau_decay = FLAGS.tau_decay
    self.tau_start = FLAGS.tau_start
    self.eps_lambda = FLAGS.eps_lambda
    self.update_eps_lambda = FLAGS.update_eps_lambda
    self.gamma = FLAGS.gamma
    self.rollout = FLAGS.rollout
    self.use_target_values = FLAGS.use_target_values
    self.fixed_std = FLAGS.fixed_std
    self.input_prev_actions = FLAGS.input_prev_actions
    self.recurrent = FLAGS.recurrent
    # Trust region and recurrent policies may not be combined.
    assert not self.trust_region_p or not self.recurrent
    self.input_time_step = FLAGS.input_time_step
    assert not self.input_time_step or (self.cutoff_agent <= self.max_step)

    self.use_online_batch = FLAGS.use_online_batch
    self.batch_by_steps = FLAGS.batch_by_steps
    self.unify_episodes = FLAGS.unify_episodes
    if self.unify_episodes:
      assert self.batch_size == 1

    self.replay_buffer_size = FLAGS.replay_buffer_size
    self.replay_buffer_alpha = FLAGS.replay_buffer_alpha
    self.replay_buffer_freq = FLAGS.replay_buffer_freq
    assert self.replay_buffer_freq in [-1, 0, 1]
    self.eviction = FLAGS.eviction
    self.prioritize_by = FLAGS.prioritize_by
    assert self.prioritize_by in ['rewards', 'step']
    self.num_expert_paths = FLAGS.num_expert_paths

    self.internal_dim = FLAGS.internal_dim
    self.value_hidden_layers = FLAGS.value_hidden_layers
    self.tf_seed = FLAGS.tf_seed

    self.save_trajectories_dir = FLAGS.save_trajectories_dir
    # File name is the environment name with '-' replaced by '_', placed
    # inside save_trajectories_dir; None when no directory was given.
    self.save_trajectories_file = (
        os.path.join(
            self.save_trajectories_dir, self.env_str.replace('-', '_'))
        if self.save_trajectories_dir else None)
    self.load_trajectories_file = FLAGS.load_trajectories_file

    # Snapshot of every non-dunder, non-callable attribute visible through
    # dir(self) at this point. This must remain the last statement so that
    # it captures everything assigned above.
    self.hparams = dict((attr, getattr(self, attr))
                        for attr in dir(self)
                        if not attr.startswith('__') and
                        not callable(getattr(self, attr)))

  def hparams_string(self):
    """Return the hparams as 'name: value' lines sorted by name."""
    return '\n'.join('%s: %s' % item for item in sorted(self.hparams.items()))

  def get_objective(self):
    """Build the training objective selected by `self.objective`.

    When `tau_decay` is set, tau is replaced by a tensor: an exponential
    decay starting at `tau_start` (decay steps 100, rate `tau_decay`, driven
    by `self.global_step`) that is floored at `self.tau`. In that case
    `self.global_step` must already exist; it is assigned in `run`.

    Returns:
      An instance of objective.PCL (for 'pcl'/'upcl'), objective.TRPO
      ('trpo'), objective.ActorCritic ('a3c'),
      full_episode_objective.Reinforce ('reinforce') or
      full_episode_objective.UREX ('urex').

    Raises:
      AssertionError: if `tau_decay` is set and `tau_start < tau`, or if
        `self.objective` is not one of the names listed above.
    """
    tau = self.tau
    if self.tau_decay is not None:
      assert self.tau_start >= self.tau
      tau = tf.maximum(
          tf.train.exponential_decay(
              self.tau_start, self.global_step, 100, self.tau_decay),
          self.tau)

    if self.objective in ['pcl', 'a3c', 'trpo', 'upcl']:
      cls = (objective.PCL if self.objective in ['pcl', 'upcl'] else
             objective.TRPO if self.objective == 'trpo' else
             objective.ActorCritic)
      policy_weight = 1.0

      return cls(self.learning_rate,
                 clip_norm=self.clip_norm,
                 policy_weight=policy_weight,
                 critic_weight=self.critic_weight,
                 tau=tau, gamma=self.gamma, rollout=self.rollout,
                 eps_lambda=self.eps_lambda, clip_adv=self.clip_adv,
                 use_target_values=self.use_target_values)
    elif self.objective in ['reinforce', 'urex']:
      cls = (full_episode_objective.Reinforce
             if self.objective == 'reinforce' else
             full_episode_objective.UREX)
      return cls(self.learning_rate,
                 clip_norm=self.clip_norm,
                 num_samples=self.num_samples,
                 tau=tau, bonus_weight=1.0)  # TODO: bonus weight?
    else:
      # Raised explicitly instead of `assert False`: an assert statement is
      # removed under `python -O`, which would make this method silently
      # return None for an unknown objective.
      raise AssertionError('Unknown objective %s' % self.objective)

  def get_policy(self):
    """Return policy.Policy when recurrent, otherwise policy.MLPPolicy."""
    if self.recurrent:
      cls = policy.Policy
    else:
      cls = policy.MLPPolicy
    return cls(self.env_spec, self.internal_dim,
               fixed_std=self.fixed_std,
               recurrent=self.recurrent,
               input_prev_actions=self.input_prev_actions)

  def get_baseline(self):
    """Return the value baseline (UnifiedBaseline for the 'upcl' objective)."""
    cls = (baseline.UnifiedBaseline if self.objective == 'upcl' else
           baseline.Baseline)
    return cls(self.env_spec, self.internal_dim,
               input_prev_actions=self.input_prev_actions,
               input_time_step=self.input_time_step,
               input_policy_state=self.recurrent,  # may want to change this
               n_hidden_layers=self.value_hidden_layers,
               hidden_dim=self.internal_dim,
               tau=self.tau)

  def get_trust_region_p_opt(self):
    """Return a TrustRegionOptimization if trust_region_p is set, else None."""
    if self.trust_region_p:
      return trust_region.TrustRegionOptimization(
          max_divergence=self.max_divergence)
    else:
      return None

  def get_value_opt(self):
    """Return the separate value optimizer named by `value_opt`.

    Returns:
      A GradOptimization ('grad'), LbfgsOptimization ('lbfgs') or
      BestFitOptimization ('best_fit') instance; None for any other value.
    """
    if self.value_opt == 'grad':
      return optimizers.GradOptimization(
          learning_rate=self.learning_rate, max_iter=5, mix_frac=0.05)
    elif self.value_opt == 'lbfgs':
      return optimizers.LbfgsOptimization(max_iter=25, mix_frac=0.1)
    elif self.value_opt == 'best_fit':
      return optimizers.BestFitOptimization(mix_frac=1.0)
    else:
      return None

  def get_model(self):
    """Return a model.Model wired up with this trainer's factory methods.

    Reads `self.global_step`, which is assigned in `run`.
    """
    cls = model.Model
    return cls(self.env_spec, self.global_step,
               target_network_lag=self.target_network_lag,
               sample_from=self.sample_from,
               get_policy=self.get_policy,
               get_baseline=self.get_baseline,
               get_objective=self.get_objective,
               get_trust_region_p_opt=self.get_trust_region_p_opt,
               get_value_opt=self.get_value_opt)

  def get_replay_buffer(self):
    """Return a PrioritizedReplayBuffer, or None if replay_buffer_freq <= 0.

    Raises:
      AssertionError: if a replay buffer is requested with an objective
        other than 'pcl' or 'upcl'.
    """
    if self.replay_buffer_freq <= 0:
      return None
    else:
      assert self.objective in ['pcl', 'upcl'], 'Can\'t use replay buffer with %s' % (
          self.objective)
      cls = replay_buffer.PrioritizedReplayBuffer
      return cls(self.replay_buffer_size,
                 alpha=self.replay_buffer_alpha,
                 eviction_strategy=self.eviction)

  def get_buffer_seeds(self):
    """Return the result of expert_paths.sample_expert_paths for this env."""
    return expert_paths.sample_expert_paths(
        self.num_expert_paths, self.env_str, self.env_spec,
        load_trajectories_file=self.load_trajectories_file)

  def get_controller(self, env):
    """Get controller.

    Args:
      env: the environment wrapper the controller should act on.

    Returns:
      A controller.Controller configured from this trainer's settings and
      given this trainer's model / replay-buffer / seed factories.
    """
    cls = controller.Controller
    return cls(env, self.env_spec, self.internal_dim,
               use_online_batch=self.use_online_batch,
               batch_by_steps=self.batch_by_steps,
               unify_episodes=self.unify_episodes,
               replay_batch_size=self.replay_batch_size,
               max_step=self.max_step,
               cutoff_agent=self.cutoff_agent,
               save_trajectories_file=self.save_trajectories_file,
               use_trust_region=self.trust_region_p,
               use_value_opt=self.value_opt not in [None, 'None'],
               update_eps_lambda=self.update_eps_lambda,
               prioritize_by=self.prioritize_by,
               get_model=self.get_model,
               get_replay_buffer=self.get_replay_buffer,
               get_buffer_seeds=self.get_buffer_seeds)

  def do_before_step(self, step):
    """Hook invoked by `run` at the start of every training iteration.

    The default implementation does nothing.

    Args:
      step: the loop index of the iteration about to run.
    """
    pass

  def _build_graph(self):
    """Create the train and eval controllers and a Saver.

    Sets `self.controller`, `self.model` and `self.eval_controller`. The
    eval controller is constructed inside a `reuse=True` variable scope.
    Requires `self.global_step` to be set already.

    Returns:
      A tf.train.Saver created with max_to_keep=10.
    """
    self.controller = self.get_controller(self.env)
    self.model = self.controller.model
    self.controller.setup()
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      self.eval_controller = self.get_controller(self.eval_env)
      self.eval_controller.setup(train=False)

    return tf.train.Saver(max_to_keep=10)

  def run(self):
    """Run training.

    Builds the graph (optionally under a Supervisor), restores a checkpoint
    if one is configured, then alternates training and greedy evaluation
    until `num_steps` is reached or the supervisor asks to stop. Stores the
    supervisor and session on `self.sv` and `self.sess`.
    """
    is_chief = FLAGS.task_id == 0 or not FLAGS.supervisor
    sv = None

    def init_fn(sess, saver):
      """Restore variables from save_dir's checkpoint or from load_path."""
      ckpt = None
      # `sv` is read from the enclosing scope at call time. The save_dir
      # checkpoint is only looked up here when no Supervisor was created.
      if FLAGS.save_dir and sv is None:
        load_dir = FLAGS.save_dir
        ckpt = tf.train.get_checkpoint_state(load_dir)
      if ckpt and ckpt.model_checkpoint_path:
        logging.info('restoring from %s', ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
      elif FLAGS.load_path:
        logging.info('restoring from %s', FLAGS.load_path)
        saver.restore(sess, FLAGS.load_path)

    if FLAGS.supervisor:
      # NOTE(review): the Saver, Supervisor and session are assumed to be
      # created inside the device scope - confirm against upstream.
      with tf.device(tf.ReplicaDeviceSetter(FLAGS.ps_tasks, merge_devices=True)):
        self.global_step = tf.contrib.framework.get_or_create_global_step()
        tf.set_random_seed(FLAGS.tf_seed)
        saver = self._build_graph()
        step = self.model.global_step
        sv = tf.Supervisor(logdir=FLAGS.save_dir,
                           is_chief=is_chief,
                           saver=saver,
                           save_model_secs=600,
                           summary_op=None,  # we define it ourselves
                           save_summaries_secs=60,
                           global_step=step,
                           init_fn=lambda sess: init_fn(sess, saver))
        sess = sv.PrepareSession(FLAGS.master)
    else:
      tf.set_random_seed(FLAGS.tf_seed)
      self.global_step = tf.contrib.framework.get_or_create_global_step()
      saver = self._build_graph()
      sess = tf.Session()
      sess.run(tf.initialize_all_variables())
      init_fn(sess, saver)

    self.sv = sv
    self.sess = sess

    logging.info('hparams:\n%s', self.hparams_string())

    # A restored model may already be at (or past) the requested step count.
    model_step = sess.run(self.model.global_step)
    if model_step >= self.num_steps:
      logging.info('training has reached final step')
      return

    # Statistics accumulated since the last validation log line.
    losses = []
    rewards = []
    all_ep_rewards = []
    for step in xrange(1 + self.num_steps):
      if sv is not None and sv.ShouldStop():
        logging.info('stopping supervisor')
        break

      self.do_before_step(step)

      (loss, summary,
       total_rewards, episode_rewards) = self.controller.train(sess)
      _, greedy_episode_rewards = self.eval_controller.eval(sess)
      self.controller.greedy_episode_rewards = greedy_episode_rewards
      losses.append(loss)
      rewards.append(total_rewards)
      all_ep_rewards.extend(episode_rewards)

      # Summaries are written on a random ~10% of steps, and only by the
      # chief when a supervisor with a summary writer is present.
      if (random.random() < 0.1 and summary and episode_rewards and
          is_chief and sv and sv._summary_writer):
        sv.summary_computed(sess, summary)

      model_step = sess.run(self.model.global_step)
      if is_chief and step % self.validation_frequency == 0:
        logging.info('at training step %d, model step %d: '
                     'avg loss %f, avg reward %f, '
                     'episode rewards: %f, greedy rewards: %f',
                     step, model_step,
                     np.mean(losses), np.mean(rewards),
                     np.mean(all_ep_rewards),
                     np.mean(greedy_episode_rewards))

        losses = []
        rewards = []
        all_ep_rewards = []

      if model_step >= self.num_steps:
        logging.info('training has reached final step')
        break

    # The final save only happens in supervisor mode, on the chief.
    if is_chief and sv is not None:
      logging.info('saving final model to %s', sv.save_path)
      sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
def main(unused_argv):
  """Enable INFO-level logging, then construct a Trainer and run it.

  Args:
    unused_argv: positional command-line arguments; ignored.
  """
  logging.set_verbosity(logging.INFO)
  Trainer().run()


if __name__ == '__main__':
  # `app` is the module-level alias for tf.app.
  app.run()