import numpy as np
from gym.spaces import Discrete

from src.rlkit.data_management.simple_replay_buffer import (
    EnsembleSimpleReplayBuffer,
    GaussianReplayBuffer,
    RandomReplayBuffer,
    SimpleReplayBuffer,
)
from src.rlkit.envs.env_utils import get_dim


class EnvReplayBuffer(SimpleReplayBuffer):
    """Replay buffer whose dimensions are inferred from a gym environment.

    Actions from a Discrete action space are stored one-hot encoded.
    """

    def __init__(
            self,
            max_replay_buffer_size,
            env,
            env_info_sizes=None,
    ):
        """
        :param max_replay_buffer_size: maximum number of transitions to store
        :param env: gym environment; its observation and action spaces
            determine the buffer dimensions
        :param env_info_sizes: sizes of extra env_info entries to store;
            defaults to ``env.info_sizes`` if the environment defines it
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes,
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        # Discrete actions arrive as integer indices; convert them to
        # one-hot vectors so they fit the flat action array.
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            **kwargs
        )


class EnsembleEnvReplayBuffer(EnsembleSimpleReplayBuffer):
    """Ensemble variant of EnvReplayBuffer: each sample also carries a
    per-member bootstrap mask, forwarded to EnsembleSimpleReplayBuffer.
    """

    def __init__(
            self,
            max_replay_buffer_size,
            env,
            num_ensemble,
            log_dir,
            env_info_sizes=None,
    ):
        """
        :param max_replay_buffer_size: maximum number of transitions to store
        :param env: gym environment; determines the buffer dimensions
        :param num_ensemble: number of ensemble members
        :param log_dir: log directory, passed through to
            EnsembleSimpleReplayBuffer
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes,
            num_ensemble=num_ensemble,
            log_dir=log_dir,
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, mask, **kwargs):
        # One-hot encode discrete actions, as in EnvReplayBuffer.
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            mask=mask,
            **kwargs
        )


class RandomEnvReplayBuffer(RandomReplayBuffer):
    """EnvReplayBuffer variant backed by RandomReplayBuffer; the extra
    flags and bounds are passed through unchanged.
    """

    def __init__(
            self,
            max_replay_buffer_size,
            env,
            single_flag,
            equal_flag,
            lower,
            upper,
            env_info_sizes=None,
    ):
        """
        :param max_replay_buffer_size: maximum number of transitions to store
        :param env: gym environment; determines the buffer dimensions
        :param single_flag: passed through to RandomReplayBuffer
        :param equal_flag: passed through to RandomReplayBuffer
        :param lower: passed through to RandomReplayBuffer
        :param upper: passed through to RandomReplayBuffer
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes,
            single_flag=single_flag,
            equal_flag=equal_flag,
            lower=lower,
            upper=upper,
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        # One-hot encode discrete actions, as in EnvReplayBuffer.
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            **kwargs
        )


class GaussianEnvReplayBuffer(GaussianReplayBuffer):
    """EnvReplayBuffer variant backed by GaussianReplayBuffer; ``prob``
    and ``std`` are passed through unchanged.
    """

    def __init__(
            self,
            max_replay_buffer_size,
            env,
            prob,
            std,
            env_info_sizes=None,
    ):
        """
        :param max_replay_buffer_size: maximum number of transitions to store
        :param env: gym environment; determines the buffer dimensions
        :param prob: passed through to GaussianReplayBuffer
        :param std: passed through to GaussianReplayBuffer
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes,
            prob=prob,
            std=std,
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        # One-hot encode discrete actions, as in EnvReplayBuffer.
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            **kwargs
        )
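

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's API). It
# assumes the classic 4-tuple gym step interface and that SimpleReplayBuffer
# exposes random_batch(batch_size), as in upstream rlkit; adapt the names to
# your fork if they differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gym

    env = gym.make("CartPole-v1")  # Discrete actions exercise the one-hot path
    buffer = EnvReplayBuffer(max_replay_buffer_size=10000, env=env)

    obs = env.reset()
    for _ in range(100):
        action = env.action_space.sample()  # integer index for Discrete spaces
        next_obs, reward, done, info = env.step(action)
        buffer.add_sample(
            observation=obs,
            action=action,
            reward=reward,
            terminal=done,
            next_observation=next_obs,
            env_info=info,  # rlkit-style keyword; assumed, drop if unsupported
        )
        obs = env.reset() if done else next_obs

    # batch = buffer.random_batch(32)  # assumed rlkit sampling API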