NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
16.9 kB
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Policy neural network.
Implements network which takes in input and produces actions
and log probabilities given a sampling distribution parameterization.
"""
import tensorflow as tf
import numpy as np
class Policy(object):
def __init__(self, env_spec, internal_dim,
fixed_std=True, recurrent=True,
input_prev_actions=True):
self.env_spec = env_spec
self.internal_dim = internal_dim
self.rnn_state_dim = self.internal_dim
self.fixed_std = fixed_std
self.recurrent = recurrent
self.input_prev_actions = input_prev_actions
self.matrix_init = tf.truncated_normal_initializer(stddev=0.01)
self.vector_init = tf.constant_initializer(0.0)
@property
def input_dim(self):
return (self.env_spec.total_obs_dim +
self.env_spec.total_sampled_act_dim * self.input_prev_actions)
@property
def output_dim(self):
return self.env_spec.total_sampling_act_dim
def get_cell(self):
"""Get RNN cell."""
self.cell_input_dim = self.internal_dim // 2
cell = tf.contrib.rnn.LSTMCell(self.cell_input_dim,
state_is_tuple=False,
reuse=tf.get_variable_scope().reuse)
cell = tf.contrib.rnn.OutputProjectionWrapper(
cell, self.output_dim,
reuse=tf.get_variable_scope().reuse)
return cell
def core(self, obs, prev_internal_state, prev_actions):
"""Core neural network taking in inputs and outputting sampling
distribution parameters."""
batch_size = tf.shape(obs[0])[0]
if not self.recurrent:
prev_internal_state = tf.zeros([batch_size, self.rnn_state_dim])
cell = self.get_cell()
b = tf.get_variable('input_bias', [self.cell_input_dim],
initializer=self.vector_init)
cell_input = tf.nn.bias_add(tf.zeros([batch_size, self.cell_input_dim]), b)
for i, (obs_dim, obs_type) in enumerate(self.env_spec.obs_dims_and_types):
w = tf.get_variable('w_state%d' % i, [obs_dim, self.cell_input_dim],
initializer=self.matrix_init)
if self.env_spec.is_discrete(obs_type):
cell_input += tf.matmul(tf.one_hot(obs[i], obs_dim), w)
elif self.env_spec.is_box(obs_type):
cell_input += tf.matmul(obs[i], w)
else:
assert False
if self.input_prev_actions:
if self.env_spec.combine_actions: # TODO(ofir): clean this up
prev_action = prev_actions[0]
for i, action_dim in enumerate(self.env_spec.orig_act_dims):
act = tf.mod(prev_action, action_dim)
w = tf.get_variable('w_prev_action%d' % i, [action_dim, self.cell_input_dim],
initializer=self.matrix_init)
cell_input += tf.matmul(tf.one_hot(act, action_dim), w)
prev_action = tf.to_int32(prev_action / action_dim)
else:
for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types):
w = tf.get_variable('w_prev_action%d' % i, [act_dim, self.cell_input_dim],
initializer=self.matrix_init)
if self.env_spec.is_discrete(act_type):
cell_input += tf.matmul(tf.one_hot(prev_actions[i], act_dim), w)
elif self.env_spec.is_box(act_type):
cell_input += tf.matmul(prev_actions[i], w)
else:
assert False
output, next_state = cell(cell_input, prev_internal_state)
return output, next_state
def sample_action(self, logits, sampling_dim,
act_dim, act_type, greedy=False):
"""Sample an action from a distribution."""
if self.env_spec.is_discrete(act_type):
if greedy:
act = tf.argmax(logits, 1)
else:
act = tf.reshape(tf.multinomial(logits, 1), [-1])
elif self.env_spec.is_box(act_type):
means = logits[:, :sampling_dim / 2]
std = logits[:, sampling_dim / 2:]
if greedy:
act = means
else:
batch_size = tf.shape(logits)[0]
act = means + std * tf.random_normal([batch_size, act_dim])
else:
assert False
return act
def entropy(self, logits,
sampling_dim, act_dim, act_type):
"""Calculate entropy of distribution."""
if self.env_spec.is_discrete(act_type):
entropy = tf.reduce_sum(
-tf.nn.softmax(logits) * tf.nn.log_softmax(logits), -1)
elif self.env_spec.is_box(act_type):
means = logits[:, :sampling_dim / 2]
std = logits[:, sampling_dim / 2:]
entropy = tf.reduce_sum(
0.5 * (1 + tf.log(2 * np.pi * tf.square(std))), -1)
else:
assert False
return entropy
def self_kl(self, logits,
sampling_dim, act_dim, act_type):
"""Calculate KL of distribution with itself.
Used layer only for the gradients.
"""
if self.env_spec.is_discrete(act_type):
probs = tf.nn.softmax(logits)
log_probs = tf.nn.log_softmax(logits)
self_kl = tf.reduce_sum(
tf.stop_gradient(probs) *
(tf.stop_gradient(log_probs) - log_probs), -1)
elif self.env_spec.is_box(act_type):
means = logits[:, :sampling_dim / 2]
std = logits[:, sampling_dim / 2:]
my_means = tf.stop_gradient(means)
my_std = tf.stop_gradient(std)
self_kl = tf.reduce_sum(
tf.log(std / my_std) +
(tf.square(my_std) + tf.square(my_means - means)) /
(2.0 * tf.square(std)) - 0.5,
-1)
else:
assert False
return self_kl
def log_prob_action(self, action, logits,
sampling_dim, act_dim, act_type):
"""Calculate log-prob of action sampled from distribution."""
if self.env_spec.is_discrete(act_type):
act_log_prob = tf.reduce_sum(
tf.one_hot(action, act_dim) * tf.nn.log_softmax(logits), -1)
elif self.env_spec.is_box(act_type):
means = logits[:, :sampling_dim / 2]
std = logits[:, sampling_dim / 2:]
act_log_prob = (- 0.5 * tf.log(2 * np.pi * tf.square(std))
- 0.5 * tf.square(action - means) / tf.square(std))
act_log_prob = tf.reduce_sum(act_log_prob, -1)
else:
assert False
return act_log_prob
def sample_actions(self, output, actions=None, greedy=False):
"""Sample all actions given output of core network."""
sampled_actions = []
logits = []
log_probs = []
entropy = []
self_kl = []
start_idx = 0
for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types):
sampling_dim = self.env_spec.sampling_dim(act_dim, act_type)
if self.fixed_std and self.env_spec.is_box(act_type):
act_logits = output[:, start_idx:start_idx + act_dim]
log_std = tf.get_variable('std%d' % i, [1, sampling_dim // 2])
# fix standard deviations to variable
act_logits = tf.concat(
[act_logits,
1e-6 + tf.exp(log_std) + 0 * act_logits], 1)
else:
act_logits = output[:, start_idx:start_idx + sampling_dim]
if actions is None:
act = self.sample_action(act_logits, sampling_dim,
act_dim, act_type,
greedy=greedy)
else:
act = actions[i]
ent = self.entropy(act_logits, sampling_dim, act_dim, act_type)
kl = self.self_kl(act_logits, sampling_dim, act_dim, act_type)
act_log_prob = self.log_prob_action(
act, act_logits,
sampling_dim, act_dim, act_type)
sampled_actions.append(act)
logits.append(act_logits)
log_probs.append(act_log_prob)
entropy.append(ent)
self_kl.append(kl)
start_idx += sampling_dim
assert start_idx == self.env_spec.total_sampling_act_dim
return sampled_actions, logits, log_probs, entropy, self_kl
def get_kl(self, my_logits, other_logits):
"""Calculate KL between one policy output and another."""
kl = []
for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types):
sampling_dim = self.env_spec.sampling_dim(act_dim, act_type)
single_my_logits = my_logits[i]
single_other_logits = other_logits[i]
if self.env_spec.is_discrete(act_type):
my_probs = tf.nn.softmax(single_my_logits)
my_log_probs = tf.nn.log_softmax(single_my_logits)
other_log_probs = tf.nn.log_softmax(single_other_logits)
my_kl = tf.reduce_sum(my_probs * (my_log_probs - other_log_probs), -1)
elif self.env_spec.is_box(act_type):
my_means = single_my_logits[:, :sampling_dim / 2]
my_std = single_my_logits[:, sampling_dim / 2:]
other_means = single_other_logits[:, :sampling_dim / 2]
other_std = single_other_logits[:, sampling_dim / 2:]
my_kl = tf.reduce_sum(
tf.log(other_std / my_std) +
(tf.square(my_std) + tf.square(my_means - other_means)) /
(2.0 * tf.square(other_std)) - 0.5,
-1)
else:
assert False
kl.append(my_kl)
return kl
def single_step(self, prev, cur, greedy=False):
"""Single RNN step. Equivalently, single-time-step sampled actions."""
prev_internal_state, prev_actions, _, _, _, _ = prev
obs, actions = cur # state observed and action taken at this time step
# feed into RNN cell
output, next_state = self.core(
obs, prev_internal_state, prev_actions)
# sample actions with values and log-probs
(actions, logits, log_probs,
entropy, self_kl) = self.sample_actions(
output, actions=actions, greedy=greedy)
return (next_state, tuple(actions), tuple(logits), tuple(log_probs),
tuple(entropy), tuple(self_kl))
def sample_step(self, obs, prev_internal_state, prev_actions, greedy=False):
"""Sample single step from policy."""
(next_state, sampled_actions, logits, log_probs,
entropies, self_kls) = self.single_step(
(prev_internal_state, prev_actions, None, None, None, None),
(obs, None), greedy=greedy)
return next_state, sampled_actions
def multi_step(self, all_obs, initial_state, all_actions):
"""Calculate log-probs and other calculations on batch of episodes."""
batch_size = tf.shape(initial_state)[0]
time_length = tf.shape(all_obs[0])[0]
initial_actions = [act[0] for act in all_actions]
all_actions = [tf.concat([act[1:], act[0:1]], 0)
for act in all_actions] # "final" action is dummy
(internal_states, _, logits, log_probs,
entropies, self_kls) = tf.scan(
self.single_step,
(all_obs, all_actions),
initializer=self.get_initializer(
batch_size, initial_state, initial_actions))
# remove "final" computations
log_probs = [log_prob[:-1] for log_prob in log_probs]
entropies = [entropy[:-1] for entropy in entropies]
self_kls = [self_kl[:-1] for self_kl in self_kls]
return internal_states, logits, log_probs, entropies, self_kls
def get_initializer(self, batch_size, initial_state, initial_actions):
"""Get initializer for RNN."""
logits_init = []
log_probs_init = []
for act_dim, act_type in self.env_spec.act_dims_and_types:
sampling_dim = self.env_spec.sampling_dim(act_dim, act_type)
logits_init.append(tf.zeros([batch_size, sampling_dim]))
log_probs_init.append(tf.zeros([batch_size]))
entropy_init = [tf.zeros([batch_size]) for _ in self.env_spec.act_dims]
self_kl_init = [tf.zeros([batch_size]) for _ in self.env_spec.act_dims]
return (initial_state,
tuple(initial_actions),
tuple(logits_init), tuple(log_probs_init),
tuple(entropy_init),
tuple(self_kl_init))
def calculate_kl(self, my_logits, other_logits):
"""Calculate KL between one policy and another on batch of episodes."""
batch_size = tf.shape(my_logits[0])[1]
time_length = tf.shape(my_logits[0])[0]
reshaped_my_logits = [
tf.reshape(my_logit, [batch_size * time_length, -1])
for my_logit in my_logits]
reshaped_other_logits = [
tf.reshape(other_logit, [batch_size * time_length, -1])
for other_logit in other_logits]
kl = self.get_kl(reshaped_my_logits, reshaped_other_logits)
kl = [tf.reshape(kkl, [time_length, batch_size])
for kkl in kl]
return kl
class MLPPolicy(Policy):
"""Non-recurrent policy."""
def get_cell(self):
self.cell_input_dim = self.internal_dim
def mlp(cell_input, prev_internal_state):
w1 = tf.get_variable('w1', [self.cell_input_dim, self.internal_dim])
b1 = tf.get_variable('b1', [self.internal_dim])
w2 = tf.get_variable('w2', [self.internal_dim, self.internal_dim])
b2 = tf.get_variable('b2', [self.internal_dim])
w3 = tf.get_variable('w3', [self.internal_dim, self.internal_dim])
b3 = tf.get_variable('b3', [self.internal_dim])
proj = tf.get_variable(
'proj', [self.internal_dim, self.output_dim])
hidden = cell_input
hidden = tf.tanh(tf.nn.bias_add(tf.matmul(hidden, w1), b1))
hidden = tf.tanh(tf.nn.bias_add(tf.matmul(hidden, w2), b2))
output = tf.matmul(hidden, proj)
return output, hidden
return mlp
def single_step(self, obs, actions, prev_actions, greedy=False):
"""Single step."""
batch_size = tf.shape(obs[0])[0]
prev_internal_state = tf.zeros([batch_size, self.internal_dim])
output, next_state = self.core(
obs, prev_internal_state, prev_actions)
# sample actions with values and log-probs
(actions, logits, log_probs,
entropy, self_kl) = self.sample_actions(
output, actions=actions, greedy=greedy)
return (next_state, tuple(actions), tuple(logits), tuple(log_probs),
tuple(entropy), tuple(self_kl))
def sample_step(self, obs, prev_internal_state, prev_actions, greedy=False):
"""Sample single step from policy."""
(next_state, sampled_actions, logits, log_probs,
entropies, self_kls) = self.single_step(obs, None, prev_actions,
greedy=greedy)
return next_state, sampled_actions
def multi_step(self, all_obs, initial_state, all_actions):
"""Calculate log-probs and other calculations on batch of episodes."""
batch_size = tf.shape(initial_state)[0]
time_length = tf.shape(all_obs[0])[0]
# first reshape inputs as a single batch
reshaped_obs = []
for obs, (obs_dim, obs_type) in zip(all_obs, self.env_spec.obs_dims_and_types):
if self.env_spec.is_discrete(obs_type):
reshaped_obs.append(tf.reshape(obs, [time_length * batch_size]))
elif self.env_spec.is_box(obs_type):
reshaped_obs.append(tf.reshape(obs, [time_length * batch_size, obs_dim]))
reshaped_act = []
reshaped_prev_act = []
for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types):
act = tf.concat([all_actions[i][1:], all_actions[i][0:1]], 0)
prev_act = all_actions[i]
if self.env_spec.is_discrete(act_type):
reshaped_act.append(tf.reshape(act, [time_length * batch_size]))
reshaped_prev_act.append(
tf.reshape(prev_act, [time_length * batch_size]))
elif self.env_spec.is_box(act_type):
reshaped_act.append(
tf.reshape(act, [time_length * batch_size, act_dim]))
reshaped_prev_act.append(
tf.reshape(prev_act, [time_length * batch_size, act_dim]))
# now inputs go into single step as one large batch
(internal_states, _, logits, log_probs,
entropies, self_kls) = self.single_step(
reshaped_obs, reshaped_act, reshaped_prev_act)
# reshape the outputs back to original time-major format
internal_states = tf.reshape(internal_states, [time_length, batch_size, -1])
logits = [tf.reshape(logit, [time_length, batch_size, -1])
for logit in logits]
log_probs = [tf.reshape(log_prob, [time_length, batch_size])[:-1]
for log_prob in log_probs]
entropies = [tf.reshape(ent, [time_length, batch_size])[:-1]
for ent in entropies]
self_kls = [tf.reshape(self_kl, [time_length, batch_size])[:-1]
for self_kl in self_kls]
return internal_states, logits, log_probs, entropies, self_kls