import numpy as np


def rollout(env, agent, max_path_length=np.inf, render=False):
    """
    Run one episode of ``agent`` in ``env`` and collect the transitions.

    The values for the following keys will be arrays of at least 2 dimensions,
    with the first dimension corresponding to the time dimension:
     - observations
     - actions
     - rewards
     - next_observations
     - terminals

    The next two elements will be lists of dictionaries, with the index into
    the list being the index into the time dimension:
     - agent_infos
     - env_infos

    :param env: Environment providing ``reset()``, ``step(action)`` (which
        returns ``(next_obs, reward, done, env_info)``) and ``render()``.
    :param agent: Policy providing ``get_action(obs)``, which returns
        ``(action, agent_info)``.
    :param max_path_length: Maximum number of steps before the episode is
        cut off, even if ``done`` was never returned.
    :param render: If True, call ``env.render()`` after the reset and after
        every non-terminal step.
    :return: dict with the keys listed above.
    """
    observations = []
    actions = []
    rewards = []
    terminals = []
    agent_infos = []
    env_infos = []
    o = env.reset()
    next_o = None
    path_length = 0
    if render:
        env.render()
    while path_length < max_path_length:
        a, agent_info = agent.get_action(o)
        next_o, r, d, env_info = env.step(a)
        observations.append(o)
        rewards.append(r)
        terminals.append(d)
        actions.append(a)
        agent_infos.append(agent_info)
        env_infos.append(env_info)
        path_length += 1
        if d:
            break
        o = next_o
        if render:
            env.render()

    actions = np.array(actions)
    if len(actions.shape) == 1:
        # Scalar actions: promote to a (T, 1) column.
        actions = np.expand_dims(actions, 1)
    observations = np.array(observations)
    if len(observations.shape) == 1:
        # Scalar observations: promote to (T, 1). Wrap the final next
        # observation the same way so that it stacks under observations[1:].
        # For array observations no wrapping is needed (and wrapping would
        # add an extra dimension that breaks the vstack below).
        observations = np.expand_dims(observations, 1)
        next_o = np.array([next_o])
    if path_length == 0:
        # No step was taken (e.g. max_path_length <= 0), so there is no final
        # next observation; keep next_observations the same (empty) length.
        next_observations = observations.copy()
    else:
        # next_observations[t] is observations[t + 1], except for the last
        # step, whose successor is the observation returned by the final step.
        next_observations = np.vstack(
            (
                observations[1:, :],
                np.expand_dims(next_o, 0),
            )
        )
    return dict(
        observations=observations,
        actions=actions,
        rewards=np.array(rewards).reshape(-1, 1),
        next_observations=next_observations,
        terminals=np.array(terminals).reshape(-1, 1),
        agent_infos=agent_infos,
        env_infos=env_infos,
    )


def split_paths(paths):
    """
    Stack multiple obs/actions/etc. from different paths.

    :param paths: List of paths, where one path is something returned from
        the rollout function above.
    :return: Tuple ``(rewards, terminals, obs, actions, next_obs)``. Every
        element will have shape batch_size X DIM, including the rewards and
        terminal flags.
    """
    # Rewards/terminals are forced into column form before stacking so that
    # they come out as (batch_size, 1) like the other arrays.
    rewards = np.vstack([path["rewards"].reshape(-1, 1) for path in paths])
    terminals = np.vstack([path["terminals"].reshape(-1, 1) for path in paths])
    obs = np.vstack([path["observations"] for path in paths])
    actions = np.vstack([path["actions"] for path in paths])
    next_obs = np.vstack([path["next_observations"] for path in paths])
    assert len(rewards.shape) == 2
    assert len(terminals.shape) == 2
    assert len(obs.shape) == 2
    assert len(actions.shape) == 2
    assert len(next_obs.shape) == 2
    return rewards, terminals, obs, actions, next_obs


def split_paths_to_dict(paths):
    """
    Same as ``split_paths`` but return the stacked arrays keyed by name.

    :param paths: List of paths as returned by ``rollout``.
    :return: dict with keys ``rewards``, ``terminals``, ``observations``,
        ``actions`` and ``next_observations``.
    """
    rewards, terminals, obs, actions, next_obs = split_paths(paths)
    return dict(
        rewards=rewards,
        terminals=terminals,
        observations=obs,
        actions=actions,
        next_observations=next_obs,
    )


def get_stat_in_paths(paths, dict_name, scalar_name):
    """
    Extract ``scalar_name`` from the ``dict_name`` entry of every path.

    Two layouts are supported:
     - ``path[dict_name]`` is a dict (rllab interface): returns
       ``[path[dict_name][scalar_name] for path in paths]``.
     - ``path[dict_name]`` is a per-timestep list of dicts (e.g. the
       ``env_infos`` produced by ``rollout``): returns one list of values
       per path.

    :return: The list described above, or ``np.array([[]])`` when ``paths``
        is empty.
    """
    if len(paths) == 0:
        return np.array([[]])

    # isinstance (rather than an exact type check) so that dict subclasses
    # such as OrderedDict are also treated as the rllab dict layout.
    if isinstance(paths[0][dict_name], dict):
        # Support rllab interface
        return [path[dict_name][scalar_name] for path in paths]

    return [
        [info[scalar_name] for info in path[dict_name]]
        for path in paths
    ]