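"""Environment-aware replay buffer wrappers.

Each class here adapts one of the buffers from
``src.rlkit.data_management.simple_replay_buffer`` to a gym environment:
it infers observation/action dimensions via ``get_dim`` and one-hot
encodes actions drawn from ``Discrete`` action spaces before storage.
"""
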
import numpy as np
from gym.spaces import Discrete

from src.rlkit.data_management.simple_replay_buffer import (
    SimpleReplayBuffer,
    EnsembleSimpleReplayBuffer,
    RandomReplayBuffer,
    GaussianReplayBuffer,
)
from src.rlkit.envs.env_utils import get_dim

class EnvReplayBuffer(SimpleReplayBuffer):
    def __init__(
            self,
            max_replay_buffer_size,
            env,
            env_info_sizes=None
    ):
        """
        :param max_replay_buffer_size: maximum number of transitions the
            buffer can hold
        :param env: environment whose observation and action spaces define
            the buffer dimensions
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space
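        # Prefer the env's declared info sizes when it provides them;
        # otherwise store no extra env_info fields.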
        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
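        # Discrete actions arrive as integer indices; store them as one-hot
        # vectors so every buffer row has a fixed width of self._action_dim.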
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            **kwargs
        )


class EnsembleEnvReplayBuffer(EnsembleSimpleReplayBuffer):
    def __init__(
            self,
            max_replay_buffer_size,
            env,
            num_ensemble,
            log_dir,
            env_info_sizes=None
    ):
        """
        :param max_replay_buffer_size: maximum number of transitions the
            buffer can hold
        :param env: environment whose observation and action spaces define
            the buffer dimensions
        :param num_ensemble: number of ensemble members the buffer tracks
        :param log_dir: log directory, passed through to the underlying
            EnsembleSimpleReplayBuffer
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes,
            num_ensemble=num_ensemble,
            log_dir=log_dir,
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, mask, **kwargs):
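        # `mask` is the extra per-sample field required by
        # EnsembleSimpleReplayBuffer; it is forwarded unchanged.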
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            mask=mask,
            **kwargs
        )


class RandomEnvReplayBuffer(RandomReplayBuffer):
    def __init__(
            self,
            max_replay_buffer_size,
            env,
            single_flag,
            equal_flag,
            lower,
            upper,
            env_info_sizes=None
    ):
        """
        :param max_replay_buffer_size: maximum number of transitions the
            buffer can hold
        :param env: environment whose observation and action spaces define
            the buffer dimensions

        ``single_flag``, ``equal_flag``, ``lower``, and ``upper`` are passed
        through to the underlying RandomReplayBuffer unchanged.
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes,
            single_flag=single_flag,
            equal_flag=equal_flag,
            lower=lower,
            upper=upper,
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            **kwargs
        )


class GaussianEnvReplayBuffer(GaussianReplayBuffer):
    def __init__(
            self,
            max_replay_buffer_size,
            env,
            prob,
            std,
            env_info_sizes=None
    ):
        """
        :param max_replay_buffer_size: maximum number of transitions the
            buffer can hold
        :param env: environment whose observation and action spaces define
            the buffer dimensions

        ``prob`` and ``std`` are passed through to the underlying
        GaussianReplayBuffer unchanged.
        """
        self.env = env
        self._ob_space = env.observation_space
        self._action_space = env.action_space

        if env_info_sizes is None:
            if hasattr(env, 'info_sizes'):
                env_info_sizes = env.info_sizes
            else:
                env_info_sizes = dict()

        super().__init__(
            max_replay_buffer_size=max_replay_buffer_size,
            observation_dim=get_dim(self._ob_space),
            action_dim=get_dim(self._action_space),
            env_info_sizes=env_info_sizes,
            prob=prob,
            std=std,
        )

    def add_sample(self, observation, action, reward, terminal,
                   next_observation, **kwargs):
        if isinstance(self._action_space, Discrete):
            new_action = np.zeros(self._action_dim)
            new_action[action] = 1
        else:
            new_action = action
        return super().add_sample(
            observation=observation,
            action=new_action,
            reward=reward,
            next_observation=next_observation,
            terminal=terminal,
            **kwargs
        )
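

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the buffer API). It
# assumes the classic gym step/reset API and that the underlying
# SimpleReplayBuffer follows the rlkit convention of accepting an `env_info`
# dict per sample and exposing `random_batch(batch_size)`:
#
#     import gym
#
#     env = gym.make('CartPole-v1')
#     buffer = EnvReplayBuffer(max_replay_buffer_size=int(1e6), env=env)
#
#     obs = env.reset()
#     for _ in range(1000):
#         action = env.action_space.sample()
#         next_obs, reward, done, info = env.step(action)
#         buffer.add_sample(
#             observation=obs,
#             action=action,  # int for Discrete spaces; one-hot encoded on storage
#             reward=reward,
#             terminal=done,
#             next_observation=next_obs,
#             env_info={},
#         )
#         obs = env.reset() if done else next_obs
#
#     batch = buffer.random_batch(256)  # dict of numpy arrays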