baiyanlali-zhao's picture
init
eaf2e33
import numpy as np
def rollout(env, agent, max_path_length=np.inf, render=False):
    """
    Run one episode of ``agent`` in ``env`` and collect it as a path dict.

    The episode ends when ``env.step`` reports ``done`` or after
    ``max_path_length`` steps, whichever comes first.

    The following keys hold arrays whose first dimension is time (one row
    per step). Per-step data that would otherwise be 1-D (scalar
    observations/actions, rewards, terminals) is turned into a column, so
    each of these arrays is at least 2-D:
        - observations
        - actions
        - rewards
        - next_observations
        - terminals
    The next two keys hold lists of per-step info objects, where the index
    into the list is the time step:
        - agent_infos
        - env_infos

    :param env: Environment whose ``reset()`` returns an observation and whose
        ``step(action)`` returns ``(next_obs, reward, done, env_info)``.
    :param agent: Object whose ``get_action(obs)`` returns
        ``(action, agent_info)``.
    :param max_path_length: Maximum number of steps to take.
    :param render: If True, call ``env.render()`` after the reset and after
        every step that did not end the episode.
    :return: Dict with the keys listed above. If no step is taken
        (``max_path_length <= 0``), every array in it is empty.
    """
    observations = []
    actions = []
    rewards = []
    terminals = []
    agent_infos = []
    env_infos = []
    o = env.reset()
    next_o = None
    path_length = 0
    if render:
        env.render()
    while path_length < max_path_length:
        a, agent_info = agent.get_action(o)
        next_o, r, d, env_info = env.step(a)
        observations.append(o)
        rewards.append(r)
        terminals.append(d)
        actions.append(a)
        agent_infos.append(agent_info)
        env_infos.append(env_info)
        path_length += 1
        if d:
            break
        o = next_o
        if render:
            env.render()

    actions = np.array(actions)
    if actions.ndim == 1:
        # Scalar actions: store as a (T, 1) column.
        actions = np.expand_dims(actions, 1)
    observations = np.array(observations)
    if observations.ndim == 1:
        # Scalar observations: store as a (T, 1) column.
        observations = np.expand_dims(observations, 1)

    if path_length == 0:
        # No step was taken, so there is no successor observation at all.
        next_observations = observations.copy()
    else:
        # next_observations[t] is the observation reached after step t: rows
        # 1.. of `observations`, followed by the final `next_o`. Reshape that
        # last observation into one row with the same trailing shape as
        # `observations`, so scalar, vector and multi-dimensional
        # observations all concatenate cleanly along the time axis.
        last_next_o = np.asarray(next_o).reshape(
            (1,) + observations.shape[1:]
        )
        next_observations = np.concatenate(
            (observations[1:], last_next_o), axis=0
        )
    return dict(
        observations=observations,
        actions=actions,
        rewards=np.array(rewards).reshape(-1, 1),
        next_observations=next_observations,
        terminals=np.array(terminals).reshape(-1, 1),
        agent_infos=agent_infos,
        env_infos=env_infos,
    )
def split_paths(paths):
    """
    Concatenate the per-step arrays of several paths into flat batches.

    :param paths: List of paths, each a dict such as the one returned by
        ``rollout``.
    :return: Tuple ``(rewards, terminals, observations, actions,
        next_observations)``. Each element has shape ``batch_size x DIM``,
        where ``batch_size`` is the total number of steps over all paths.
        Rewards and terminal flags are reshaped to a single column per path
        before stacking.
    :raises AssertionError: if any stacked array is not 2-D (for example
        image observations).
    """
    reward_parts = [p["rewards"].reshape(-1, 1) for p in paths]
    terminal_parts = [p["terminals"].reshape(-1, 1) for p in paths]
    action_parts = [p["actions"] for p in paths]
    obs_parts = [p["observations"] for p in paths]
    next_obs_parts = [p["next_observations"] for p in paths]

    # Stack each group along the time axis, in the order of the result tuple.
    rewards, terminals, obs, actions, next_obs = (
        np.vstack(parts)
        for parts in (
            reward_parts,
            terminal_parts,
            obs_parts,
            action_parts,
            next_obs_parts,
        )
    )
    # Callers expect flat (batch, dim) matrices only.
    for stacked in (rewards, terminals, obs, actions, next_obs):
        assert stacked.ndim == 2
    return rewards, terminals, obs, actions, next_obs
def split_paths_to_dict(paths):
    """
    Same as ``split_paths`` but return the stacked arrays as a dict.

    :param paths: List of path dicts.
    :return: Dict with keys ``rewards``, ``terminals``, ``observations``,
        ``actions`` and ``next_observations``, in that order.
    """
    # Key order mirrors the tuple order returned by split_paths.
    field_names = (
        "rewards",
        "terminals",
        "observations",
        "actions",
        "next_observations",
    )
    return dict(zip(field_names, split_paths(paths)))
def get_stat_in_paths(paths, dict_name, scalar_name):
    """
    Collect one statistic from every path.

    :param paths: List of path dicts.
    :param dict_name: Key of the per-path info container, e.g.
        ``"env_infos"`` or ``"agent_infos"``.
    :param scalar_name: Key of the statistic inside the info container.
    :return: List with one entry per path. If ``paths[0][dict_name]`` is a
        dict (the rllab interface), each entry is
        ``path[dict_name][scalar_name]``. Otherwise ``path[dict_name]`` is
        iterated as a sequence of per-step info dicts, and each entry is the
        list of their ``scalar_name`` values. An empty ``paths`` returns
        ``np.array([[]])``, kept for backward compatibility.
    """
    if len(paths) == 0:
        return np.array([[]])
    # Only the first path is inspected to pick the interface.
    # isinstance (rather than an exact type check) also accepts dict
    # subclasses such as OrderedDict. Those would otherwise be iterated
    # key-by-key in the branch below and fail.
    if isinstance(paths[0][dict_name], dict):
        # Support rllab interface
        return [path[dict_name][scalar_name] for path in paths]
    return [
        [info[scalar_name] for info in path[dict_name]]
        for path in paths
    ]