# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Trainer for coordinating single or multi-replica training.

Main point of entry for running models.  Specifies most of
the parameters used by different algorithms.
"""
import tensorflow as tf
import numpy as np
import random
import os
import pickle
from six.moves import xrange

import controller
import model
import policy
import baseline
import objective
import full_episode_objective
import trust_region
import optimizers
import replay_buffer
import expert_paths
import gym_wrapper
import env_spec

app = tf.app
flags = tf.flags
logging = tf.logging
gfile = tf.gfile

FLAGS = flags.FLAGS
flags.DEFINE_string('env', 'Copy-v0', 'environment name')
flags.DEFINE_integer('batch_size', 100, 'batch size')
flags.DEFINE_integer('replay_batch_size', None,
                     'replay batch size; defaults to batch_size')
flags.DEFINE_integer('num_samples', 1,
                     'number of samples from each random seed initialization')
flags.DEFINE_integer('max_step', 200, 'max number of steps to train on')
flags.DEFINE_integer('cutoff_agent', 0,
                     'number of steps at which to cut off the agent. '
                     'Defaults to always cutoff.')
flags.DEFINE_integer('num_steps', 100000, 'number of training steps')
flags.DEFINE_integer('validation_frequency', 100,
                     'every so many steps, output some stats')
flags.DEFINE_float('target_network_lag', 0.95,
                   'Exponential decay applied to the online network '
                   'to yield the target network.')
flags.DEFINE_string('sample_from', 'online',
                    'Sample actions from "online" network or "target" network')
flags.DEFINE_string('objective', 'pcl',
                    'pcl/upcl/a3c/trpo/reinforce/urex')
flags.DEFINE_bool('trust_region_p', False,
                  'use trust region for policy optimization')
flags.DEFINE_string('value_opt', None,
                    'leave as None to optimize it along with policy '
                    '(using critic_weight). Otherwise set to '
                    '"best_fit" (least squares regression), "lbfgs", or "grad"')
flags.DEFINE_float('max_divergence', 0.01,
                   'max divergence (i.e. KL) to allow during '
                   'trust region optimization')
flags.DEFINE_float('learning_rate', 0.01, 'learning rate')
flags.DEFINE_float('clip_norm', 5.0, 'clip norm')
flags.DEFINE_float('clip_adv', 0.0,
                   'Clip advantages at this value. '
                   'Leave as 0 to not clip at all.')
flags.DEFINE_float('critic_weight', 0.1, 'critic weight')
flags.DEFINE_float('tau', 0.1,
                   'Entropy regularizer. '
                   'If using decaying tau, this is the final value.')
flags.DEFINE_float('tau_decay', None,
                   'decay tau by this much every 100 steps')
flags.DEFINE_float('tau_start', 0.1,
                   'start tau at this value')
flags.DEFINE_float('eps_lambda', 0.0, 'relative entropy regularizer.')
flags.DEFINE_bool('update_eps_lambda', False,
                  'Update lambda automatically based on last 100 episodes.')
flags.DEFINE_float('gamma', 1.0, 'discount')
flags.DEFINE_integer('rollout', 10, 'rollout')
flags.DEFINE_bool('use_target_values', False,
                  'use target network for value estimates')
flags.DEFINE_bool('fixed_std', True,
                  'fix the std in Gaussian distributions')
flags.DEFINE_bool('input_prev_actions', True,
                  'input previous actions to policy network')
flags.DEFINE_bool('recurrent', True,
                  'use recurrent connections')
flags.DEFINE_bool('input_time_step', False,
                  'input time step into value calculations')
flags.DEFINE_bool('use_online_batch', True,
                  'train on batches as they are sampled')
flags.DEFINE_bool('batch_by_steps', False,
                  'ensure each training batch has batch_size * max_step steps')
flags.DEFINE_bool('unify_episodes', False,
                  'Make sure replay buffer holds entire episodes, '
                  'even across distinct sampling steps')
flags.DEFINE_integer('replay_buffer_size', 5000, 'replay buffer size')
flags.DEFINE_float('replay_buffer_alpha', 0.5, 'replay buffer alpha param')
flags.DEFINE_integer('replay_buffer_freq', 0,
                     'replay buffer frequency (only supports -1/0/1)')
flags.DEFINE_string('eviction', 'rand',
                    'how to evict from replay buffer: rand/rank/fifo')
flags.DEFINE_string('prioritize_by', 'rewards',
                    'Prioritize replay buffer by "rewards" or "step"')
flags.DEFINE_integer('num_expert_paths', 0,
                     'number of expert paths to seed replay buffer with')
flags.DEFINE_integer('internal_dim', 256, 'RNN internal dim')
flags.DEFINE_integer('value_hidden_layers', 0,
                     'number of hidden layers in value estimate')
flags.DEFINE_integer('tf_seed', 42, 'random seed for tensorflow')
flags.DEFINE_string('save_trajectories_dir', None,
                    'directory to save trajectories to, if desired')
flags.DEFINE_string('load_trajectories_file', None,
                    'file to load expert trajectories from')

# Supervisor (distributed training) flags.
flags.DEFINE_bool('supervisor', False, 'use supervisor training')
flags.DEFINE_integer('task_id', 0, 'task id')
flags.DEFINE_integer('ps_tasks', 0, 'number of ps tasks')
flags.DEFINE_integer('num_replicas', 1, 'number of replicas used')
flags.DEFINE_string('master', 'local', 'name of master')
flags.DEFINE_string('save_dir', '', 'directory to save model to')
flags.DEFINE_string('load_path', '',
                    'path of saved model to load (if none in save_dir)')
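
# Example invocations (a sketch; the "trainer.py" file name below is an
# assumption about how this module is saved and is not confirmed by the code).
# All flags referenced here are defined above.
#
#   # Single-machine PCL on the default Copy-v0 task:
#   python trainer.py --env=Copy-v0 --objective=pcl --rollout=10 \
#       --tau=0.1 --batch_size=100 --learning_rate=0.01
#
#   # UREX with several samples per random seed:
#   python trainer.py --env=Copy-v0 --objective=urex --num_samples=10 \
#       --batch_size=100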


class Trainer(object):
  """Coordinates single or multi-replica training."""

  def __init__(self):
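    """Read hyperparameters from FLAGS and build the train and eval envs."""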
    self.batch_size = FLAGS.batch_size
    self.replay_batch_size = FLAGS.replay_batch_size
    if self.replay_batch_size is None:
      self.replay_batch_size = self.batch_size
    self.num_samples = FLAGS.num_samples

    self.env_str = FLAGS.env
    self.env = gym_wrapper.GymWrapper(self.env_str,
                                      distinct=FLAGS.batch_size // self.num_samples,
                                      count=self.num_samples)
    self.eval_env = gym_wrapper.GymWrapper(
        self.env_str,
        distinct=FLAGS.batch_size // self.num_samples,
        count=self.num_samples)
    self.env_spec = env_spec.EnvSpec(self.env.get_one())

    self.max_step = FLAGS.max_step
    self.cutoff_agent = FLAGS.cutoff_agent
    self.num_steps = FLAGS.num_steps
    self.validation_frequency = FLAGS.validation_frequency

    self.target_network_lag = FLAGS.target_network_lag
    self.sample_from = FLAGS.sample_from
    assert self.sample_from in ['online', 'target']

    self.critic_weight = FLAGS.critic_weight
    self.objective = FLAGS.objective
    self.trust_region_p = FLAGS.trust_region_p
    self.value_opt = FLAGS.value_opt
    assert not self.trust_region_p or self.objective in ['pcl', 'trpo']
    assert self.objective != 'trpo' or self.trust_region_p
    assert self.value_opt is None or self.value_opt == 'None' or \
        self.critic_weight == 0.0
    self.max_divergence = FLAGS.max_divergence

    self.learning_rate = FLAGS.learning_rate
    self.clip_norm = FLAGS.clip_norm
    self.clip_adv = FLAGS.clip_adv
    self.tau = FLAGS.tau
    self.tau_decay = FLAGS.tau_decay
    self.tau_start = FLAGS.tau_start
    self.eps_lambda = FLAGS.eps_lambda
    self.update_eps_lambda = FLAGS.update_eps_lambda
    self.gamma = FLAGS.gamma
    self.rollout = FLAGS.rollout
    self.use_target_values = FLAGS.use_target_values
    self.fixed_std = FLAGS.fixed_std
    self.input_prev_actions = FLAGS.input_prev_actions
    self.recurrent = FLAGS.recurrent
    assert not self.trust_region_p or not self.recurrent
    self.input_time_step = FLAGS.input_time_step
    assert not self.input_time_step or (self.cutoff_agent <= self.max_step)

    self.use_online_batch = FLAGS.use_online_batch
    self.batch_by_steps = FLAGS.batch_by_steps
    self.unify_episodes = FLAGS.unify_episodes
    if self.unify_episodes:
      assert self.batch_size == 1

    self.replay_buffer_size = FLAGS.replay_buffer_size
    self.replay_buffer_alpha = FLAGS.replay_buffer_alpha
    self.replay_buffer_freq = FLAGS.replay_buffer_freq
    assert self.replay_buffer_freq in [-1, 0, 1]
    self.eviction = FLAGS.eviction
    self.prioritize_by = FLAGS.prioritize_by
    assert self.prioritize_by in ['rewards', 'step']
    self.num_expert_paths = FLAGS.num_expert_paths

    self.internal_dim = FLAGS.internal_dim
    self.value_hidden_layers = FLAGS.value_hidden_layers
    self.tf_seed = FLAGS.tf_seed

    self.save_trajectories_dir = FLAGS.save_trajectories_dir
    self.save_trajectories_file = (
        os.path.join(
            self.save_trajectories_dir, self.env_str.replace('-', '_'))
        if self.save_trajectories_dir else None)
    self.load_trajectories_file = FLAGS.load_trajectories_file

    self.hparams = dict((attr, getattr(self, attr))
                        for attr in dir(self)
                        if not attr.startswith('__') and
                        not callable(getattr(self, attr)))

  def hparams_string(self):
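    """Return the hyperparameters as a sorted, newline-separated string."""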
    return '\n'.join('%s: %s' % item for item in sorted(self.hparams.items()))

  def get_objective(self):
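    """Construct the training objective.

    PCL/TRPO/ActorCritic handle the step-based objectives
    ('pcl', 'upcl', 'trpo', 'a3c'); Reinforce/UREX handle the
    full-episode objectives.  The entropy regularizer tau is
    optionally decayed from tau_start towards its final value.
    """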
    tau = self.tau
    if self.tau_decay is not None:
      assert self.tau_start >= self.tau
      tau = tf.maximum(
          tf.train.exponential_decay(
              self.tau_start, self.global_step, 100, self.tau_decay),
          self.tau)

    if self.objective in ['pcl', 'a3c', 'trpo', 'upcl']:
      cls = (objective.PCL if self.objective in ['pcl', 'upcl'] else
             objective.TRPO if self.objective == 'trpo' else
             objective.ActorCritic)
      policy_weight = 1.0

      return cls(self.learning_rate,
                 clip_norm=self.clip_norm,
                 policy_weight=policy_weight,
                 critic_weight=self.critic_weight,
                 tau=tau, gamma=self.gamma, rollout=self.rollout,
                 eps_lambda=self.eps_lambda, clip_adv=self.clip_adv,
                 use_target_values=self.use_target_values)
    elif self.objective in ['reinforce', 'urex']:
      cls = (full_episode_objective.Reinforce
             if self.objective == 'reinforce' else
             full_episode_objective.UREX)
      return cls(self.learning_rate,
                 clip_norm=self.clip_norm,
                 num_samples=self.num_samples,
                 tau=tau, bonus_weight=1.0)  # TODO: bonus weight?
    else:
      assert False, 'Unknown objective %s' % self.objective

  def get_policy(self):
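    """Return the policy network (recurrent or MLP, per --recurrent)."""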
    if self.recurrent:
      cls = policy.Policy
    else:
      cls = policy.MLPPolicy
    return cls(self.env_spec, self.internal_dim,
               fixed_std=self.fixed_std,
               recurrent=self.recurrent,
               input_prev_actions=self.input_prev_actions)

  def get_baseline(self):
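    """Return the value baseline (UnifiedBaseline for upcl, else Baseline)."""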
    cls = (baseline.UnifiedBaseline if self.objective == 'upcl' else
           baseline.Baseline)
    return cls(self.env_spec, self.internal_dim,
               input_prev_actions=self.input_prev_actions,
               input_time_step=self.input_time_step,
               input_policy_state=self.recurrent,  # may want to change this
               n_hidden_layers=self.value_hidden_layers,
               hidden_dim=self.internal_dim,
               tau=self.tau)

  def get_trust_region_p_opt(self):
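    """Return a trust region optimizer for the policy, or None if disabled."""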
    if self.trust_region_p:
      return trust_region.TrustRegionOptimization(
          max_divergence=self.max_divergence)
    else:
      return None

  def get_value_opt(self):
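    """Return a separate value-function optimizer (grad/lbfgs/best_fit) or None."""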
    if self.value_opt == 'grad':
      return optimizers.GradOptimization(
          learning_rate=self.learning_rate, max_iter=5, mix_frac=0.05)
    elif self.value_opt == 'lbfgs':
      return optimizers.LbfgsOptimization(max_iter=25, mix_frac=0.1)
    elif self.value_opt == 'best_fit':
      return optimizers.BestFitOptimization(mix_frac=1.0)
    else:
      return None

  def get_model(self):
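    """Construct the Model tying together policy, baseline, and objective."""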
    cls = model.Model
    return cls(self.env_spec, self.global_step,
               target_network_lag=self.target_network_lag,
               sample_from=self.sample_from,
               get_policy=self.get_policy,
               get_baseline=self.get_baseline,
               get_objective=self.get_objective,
               get_trust_region_p_opt=self.get_trust_region_p_opt,
               get_value_opt=self.get_value_opt)

  def get_replay_buffer(self):
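    """Return a prioritized replay buffer, or None if replay is disabled."""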
    if self.replay_buffer_freq <= 0:
      return None
    else:
      assert self.objective in ['pcl', 'upcl'], 'Can\'t use replay buffer with %s' % (
          self.objective)
    cls = replay_buffer.PrioritizedReplayBuffer
    return cls(self.replay_buffer_size,
               alpha=self.replay_buffer_alpha,
               eviction_strategy=self.eviction)

  def get_buffer_seeds(self):
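    """Sample expert paths used to seed the replay buffer."""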
    return expert_paths.sample_expert_paths(
        self.num_expert_paths, self.env_str, self.env_spec,
        load_trajectories_file=self.load_trajectories_file)

  def get_controller(self, env):
    """Get controller coordinating sampling and training in env."""
    cls = controller.Controller
    return cls(env, self.env_spec, self.internal_dim,
               use_online_batch=self.use_online_batch,
               batch_by_steps=self.batch_by_steps,
               unify_episodes=self.unify_episodes,
               replay_batch_size=self.replay_batch_size,
               max_step=self.max_step,
               cutoff_agent=self.cutoff_agent,
               save_trajectories_file=self.save_trajectories_file,
               use_trust_region=self.trust_region_p,
               use_value_opt=self.value_opt not in [None, 'None'],
               update_eps_lambda=self.update_eps_lambda,
               prioritize_by=self.prioritize_by,
               get_model=self.get_model,
               get_replay_buffer=self.get_replay_buffer,
               get_buffer_seeds=self.get_buffer_seeds)

  def do_before_step(self, step):
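    """Hook run before each training step; no-op by default."""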
    pass

  def run(self):
    """Run training."""
    is_chief = FLAGS.task_id == 0 or not FLAGS.supervisor
    sv = None

    def init_fn(sess, saver):
      ckpt = None
      if FLAGS.save_dir and sv is None:
        load_dir = FLAGS.save_dir
        ckpt = tf.train.get_checkpoint_state(load_dir)
      if ckpt and ckpt.model_checkpoint_path:
        logging.info('restoring from %s', ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
      elif FLAGS.load_path:
        logging.info('restoring from %s', FLAGS.load_path)
        saver.restore(sess, FLAGS.load_path)

    if FLAGS.supervisor:
      with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks,
                                                    merge_devices=True)):
        self.global_step = tf.contrib.framework.get_or_create_global_step()
        tf.set_random_seed(FLAGS.tf_seed)
        self.controller = self.get_controller(self.env)
        self.model = self.controller.model
        self.controller.setup()
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
          self.eval_controller = self.get_controller(self.eval_env)
          self.eval_controller.setup(train=False)

        saver = tf.train.Saver(max_to_keep=10)
        step = self.model.global_step
        sv = tf.train.Supervisor(logdir=FLAGS.save_dir,
                                 is_chief=is_chief,
                                 saver=saver,
                                 save_model_secs=600,
                                 summary_op=None,  # we define it ourselves
                                 save_summaries_secs=60,
                                 global_step=step,
                                 init_fn=lambda sess: init_fn(sess, saver))
        sess = sv.prepare_or_wait_for_session(FLAGS.master)
    else:
      tf.set_random_seed(FLAGS.tf_seed)
      self.global_step = tf.contrib.framework.get_or_create_global_step()
      self.controller = self.get_controller(self.env)
      self.model = self.controller.model
      self.controller.setup()
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        self.eval_controller = self.get_controller(self.eval_env)
        self.eval_controller.setup(train=False)

      saver = tf.train.Saver(max_to_keep=10)
      sess = tf.Session()
      sess.run(tf.global_variables_initializer())
      init_fn(sess, saver)

    self.sv = sv
    self.sess = sess

    logging.info('hparams:\n%s', self.hparams_string())

    model_step = sess.run(self.model.global_step)
    if model_step >= self.num_steps:
      logging.info('training has reached final step')
      return

    losses = []
    rewards = []
    all_ep_rewards = []
    for step in xrange(1 + self.num_steps):
      if sv is not None and sv.should_stop():
        logging.info('stopping supervisor')
        break

      self.do_before_step(step)

      (loss, summary,
       total_rewards, episode_rewards) = self.controller.train(sess)
      _, greedy_episode_rewards = self.eval_controller.eval(sess)
      self.controller.greedy_episode_rewards = greedy_episode_rewards
      losses.append(loss)
      rewards.append(total_rewards)
      all_ep_rewards.extend(episode_rewards)

      if (random.random() < 0.1 and summary and episode_rewards and
          is_chief and sv and sv.summary_writer):
        sv.summary_computed(sess, summary)

      model_step = sess.run(self.model.global_step)
      if is_chief and step % self.validation_frequency == 0:
        logging.info('at training step %d, model step %d: '
                     'avg loss %f, avg reward %f, '
                     'episode rewards: %f, greedy rewards: %f',
                     step, model_step,
                     np.mean(losses), np.mean(rewards),
                     np.mean(all_ep_rewards),
                     np.mean(greedy_episode_rewards))
        losses = []
        rewards = []
        all_ep_rewards = []

      if model_step >= self.num_steps:
        logging.info('training has reached final step')
        break

    if is_chief and sv is not None:
      logging.info('saving final model to %s', sv.save_path)
      sv.saver.save(sess, sv.save_path, global_step=sv.global_step)


def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  trainer = Trainer()
  trainer.run()


if __name__ == '__main__':
  app.run()