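"""Gym-aware replay buffer wrappers.

Each class reads observation/action dimensions from a gym env and one-hot
encodes discrete actions before handing transitions to its parent buffer.
"""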
import numpy as np
from gym.spaces import Discrete

from src.rlkit.data_management.simple_replay_buffer import (
    EnsembleSimpleReplayBuffer,
    GaussianReplayBuffer,
    RandomReplayBuffer,
    SimpleReplayBuffer,
)
from src.rlkit.envs.env_utils import get_dim


class EnvReplayBuffer(SimpleReplayBuffer):
def __init__(
self,
max_replay_buffer_size,
env,
env_info_sizes=None
):
"""
:param max_replay_buffer_size:
:param env:
"""
self.env = env
self._ob_space = env.observation_space
self._action_space = env.action_space
if env_info_sizes is None:
if hasattr(env, 'info_sizes'):
env_info_sizes = env.info_sizes
else:
env_info_sizes = dict()
super().__init__(
max_replay_buffer_size=max_replay_buffer_size,
observation_dim=get_dim(self._ob_space),
action_dim=get_dim(self._action_space),
env_info_sizes=env_info_sizes
)

    def add_sample(self, observation, action, reward, terminal,
next_observation, **kwargs):
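        # Discrete actions arrive as integer indices; store them as one-hot
        # vectors so they match the buffer's fixed action_dim.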
if isinstance(self._action_space, Discrete):
new_action = np.zeros(self._action_dim)
new_action[action] = 1
else:
new_action = action
return super().add_sample(
observation=observation,
action=new_action,
reward=reward,
next_observation=next_observation,
terminal=terminal,
**kwargs
)
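

# Ensemble variant: `num_ensemble` and `log_dir` are forwarded verbatim to
# EnsembleSimpleReplayBuffer, and add_sample routes an extra per-sample
# `mask` through to it; the mask's semantics live in the parent class.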
class EnsembleEnvReplayBuffer(EnsembleSimpleReplayBuffer):
def __init__(
self,
max_replay_buffer_size,
env,
num_ensemble,
log_dir,
env_info_sizes=None
):
"""
:param max_replay_buffer_size:
:param env:
"""
self.env = env
self._ob_space = env.observation_space
self._action_space = env.action_space
if env_info_sizes is None:
if hasattr(env, 'info_sizes'):
env_info_sizes = env.info_sizes
else:
env_info_sizes = dict()
super().__init__(
max_replay_buffer_size=max_replay_buffer_size,
observation_dim=get_dim(self._ob_space),
action_dim=get_dim(self._action_space),
env_info_sizes=env_info_sizes,
num_ensemble=num_ensemble,
log_dir=log_dir,
)

    def add_sample(self, observation, action, reward, terminal,
next_observation, mask, **kwargs):
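        # Same one-hot conversion as EnvReplayBuffer; `mask` passes through
        # to the ensemble buffer unchanged.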
if isinstance(self._action_space, Discrete):
new_action = np.zeros(self._action_dim)
new_action[action] = 1
else:
new_action = action
return super().add_sample(
observation=observation,
action=new_action,
reward=reward,
next_observation=next_observation,
terminal=terminal,
mask=mask,
**kwargs
)
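

# Random variant: `single_flag`, `equal_flag`, `lower`, and `upper` are
# forwarded verbatim to RandomReplayBuffer, which defines their semantics.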
class RandomEnvReplayBuffer(RandomReplayBuffer):
def __init__(
self,
max_replay_buffer_size,
env,
single_flag,
equal_flag,
lower,
upper,
env_info_sizes=None
):
"""
:param max_replay_buffer_size:
:param env:
"""
self.env = env
self._ob_space = env.observation_space
self._action_space = env.action_space
if env_info_sizes is None:
if hasattr(env, 'info_sizes'):
env_info_sizes = env.info_sizes
else:
env_info_sizes = dict()
super().__init__(
max_replay_buffer_size=max_replay_buffer_size,
observation_dim=get_dim(self._ob_space),
action_dim=get_dim(self._action_space),
env_info_sizes=env_info_sizes,
single_flag=single_flag,
equal_flag=equal_flag,
lower=lower,
upper=upper,
)

    def add_sample(self, observation, action, reward, terminal,
next_observation, **kwargs):
if isinstance(self._action_space, Discrete):
new_action = np.zeros(self._action_dim)
new_action[action] = 1
else:
new_action = action
return super().add_sample(
observation=observation,
action=new_action,
reward=reward,
next_observation=next_observation,
terminal=terminal,
**kwargs
)
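

# Gaussian variant: `prob` and `std` are forwarded verbatim to
# GaussianReplayBuffer, which defines their semantics.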
class GaussianEnvReplayBuffer(GaussianReplayBuffer):
def __init__(
self,
max_replay_buffer_size,
env,
prob,
std,
env_info_sizes=None
):
"""
:param max_replay_buffer_size:
:param env:
"""
self.env = env
self._ob_space = env.observation_space
self._action_space = env.action_space
if env_info_sizes is None:
if hasattr(env, 'info_sizes'):
env_info_sizes = env.info_sizes
else:
env_info_sizes = dict()
super().__init__(
max_replay_buffer_size=max_replay_buffer_size,
observation_dim=get_dim(self._ob_space),
action_dim=get_dim(self._action_space),
env_info_sizes=env_info_sizes,
prob=prob,
std=std,
)

    def add_sample(self, observation, action, reward, terminal,
next_observation, **kwargs):
if isinstance(self._action_space, Discrete):
new_action = np.zeros(self._action_dim)
new_action[action] = 1
else:
new_action = action
return super().add_sample(
observation=observation,
action=new_action,
reward=reward,
next_observation=next_observation,
terminal=terminal,
**kwargs
)
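

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # classic gym (< 0.26) reset/step API and that this src.rlkit copy matches
    # upstream rlkit, whose SimpleReplayBuffer.add_sample requires an
    # `env_info` kwarg and exposes random_batch(batch_size).
    import gym

    env = gym.make("CartPole-v1")
    buffer = EnvReplayBuffer(max_replay_buffer_size=1000, env=env)

    obs = env.reset()
    for _ in range(20):
        action = env.action_space.sample()  # int for Discrete spaces
        next_obs, reward, done, info = env.step(action)
        buffer.add_sample(
            observation=obs,
            action=action,  # stored as a one-hot vector
            reward=reward,
            terminal=done,
            next_observation=next_obs,
            env_info=info,
        )
        obs = env.reset() if done else next_obs

    batch = buffer.random_batch(4)
    print(batch["actions"].shape)  # (4, 2): one-hot encoded actions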