# Provenance (residue from a web-page copy of this file; not code):
#   author: baiyanlali-zhao
#   commit: eaf2e33 ("init")
#   original size: 4.82 kB
import abc
class ReplayBuffer(metaclass=abc.ABCMeta):
    """
    A class used to save and replay data.

    Abstract interface: subclasses must implement ``add_sample``,
    ``terminate_episode``, ``num_steps_can_sample`` and ``random_batch``.
    ``add_path``/``add_paths`` are implemented here on top of ``add_sample``
    and ``terminate_episode``. ``get_diagnostics``, ``get_snapshot`` and
    ``end_epoch`` default to trivial implementations that subclasses may
    override.
    """

    @abc.abstractmethod
    def add_sample(self, observation, action, reward, next_observation,
                   terminal, **kwargs):
        """
        Add a transition tuple.

        Extra per-step data arrives as keyword arguments; ``add_path``
        passes ``agent_info`` and ``env_info`` this way.
        """
        pass

    @abc.abstractmethod
    def terminate_episode(self):
        """
        Let the replay buffer know that the episode has terminated in case some
        special book-keeping has to happen.
        :return:
        """
        pass

    @abc.abstractmethod
    def num_steps_can_sample(self, **kwargs):
        """
        :return: # of unique items that can be sampled.
        """
        pass

    def add_path(self, path):
        """
        Add a path to the replay buffer.

        This default implementation naively goes through every step, calling
        ``add_sample`` once per step, but you may want to optimize this.

        NOTE: You should NOT call "terminate_episode" after calling add_path.
        It's assumed that this function handles the episode termination:
        ``terminate_episode`` is called exactly once after the last step, even
        when the path contains no steps.

        :param path: Dict like one outputted by rlkit.samplers.util.rollout.
            It must contain the per-step sequences "observations", "actions",
            "rewards", "next_observations", "terminals", "agent_infos" and
            "env_infos". They are walked in lockstep with ``zip``, so if
            their lengths differ the extra trailing entries are silently
            ignored.
        """
        steps = zip(
            path["observations"],
            path["actions"],
            path["rewards"],
            path["next_observations"],
            path["terminals"],
            path["agent_infos"],
            path["env_infos"],
        )
        for (obs, action, reward, next_obs,
             terminal, agent_info, env_info) in steps:
            self.add_sample(
                observation=obs,
                action=action,
                reward=reward,
                next_observation=next_obs,
                terminal=terminal,
                agent_info=agent_info,
                env_info=env_info,
            )
        self.terminate_episode()

    def add_paths(self, paths):
        """
        Add every path in ``paths`` via ``add_path``, in order.

        :param paths: Iterable of path dicts (see ``add_path``).
        """
        for path in paths:
            self.add_path(path)

    @abc.abstractmethod
    def random_batch(self, batch_size):
        """
        Return a batch of size `batch_size`.
        :param batch_size: Number of items requested.
        :return: A batch whose format is defined by the subclass.
        """
        pass

    def get_diagnostics(self):
        """Return diagnostic values as a dict; the default has none."""
        return {}

    def get_snapshot(self):
        """Return snapshot data as a dict; the default has none."""
        return {}

    def end_epoch(self, epoch):
        """
        Hook for end-of-epoch book-keeping; no-op by default.

        NOTE(review): presumably invoked by the training loop with the epoch
        index; the caller is not visible in this file.
        """
        return
class EnsembleReplayBuffer(metaclass=abc.ABCMeta):
    """
    A class used to save and replay data.

    Same abstract interface as ``ReplayBuffer`` (it is a separate class, not a
    subclass), except that every step of a path also carries a mask, which
    ``add_path`` forwards to ``add_sample`` as the ``mask`` keyword.
    NOTE(review): presumably a per-ensemble-member mask; confirm against the
    concrete implementations.
    """

    @abc.abstractmethod
    def add_sample(self, observation, action, reward, next_observation,
                   terminal, **kwargs):
        """
        Add a transition tuple.

        Extra per-step data arrives as keyword arguments; ``add_path``
        passes ``mask``, ``agent_info`` and ``env_info`` this way.
        """
        pass

    @abc.abstractmethod
    def terminate_episode(self):
        """
        Let the replay buffer know that the episode has terminated in case some
        special book-keeping has to happen.
        :return:
        """
        pass

    @abc.abstractmethod
    def num_steps_can_sample(self, **kwargs):
        """
        :return: # of unique items that can be sampled.
        """
        pass

    def add_path(self, path):
        """
        Add a path to the replay buffer.

        This default implementation naively goes through every step, calling
        ``add_sample`` once per step, but you may want to optimize this.

        NOTE: You should NOT call "terminate_episode" after calling add_path.
        It's assumed that this function handles the episode termination:
        ``terminate_episode`` is called exactly once after the last step, even
        when the path contains no steps.

        :param path: Dict like one outputted by rlkit.samplers.util.rollout,
            plus a "masks" entry. It must contain the per-step sequences
            "observations", "actions", "rewards", "next_observations",
            "terminals", "agent_infos", "env_infos" and "masks". They are
            walked in lockstep with ``zip``, so if their lengths differ the
            extra trailing entries are silently ignored.
        """
        steps = zip(
            path["observations"],
            path["actions"],
            path["rewards"],
            path["next_observations"],
            path["terminals"],
            path["agent_infos"],
            path["env_infos"],
            path["masks"],
        )
        for (obs, action, reward, next_obs,
             terminal, agent_info, env_info, mask) in steps:
            self.add_sample(
                observation=obs,
                action=action,
                reward=reward,
                next_observation=next_obs,
                terminal=terminal,
                mask=mask,
                agent_info=agent_info,
                env_info=env_info,
            )
        self.terminate_episode()

    def add_paths(self, paths):
        """
        Add every path in ``paths`` via ``add_path``, in order.

        :param paths: Iterable of path dicts (see ``add_path``).
        """
        for path in paths:
            self.add_path(path)

    @abc.abstractmethod
    def random_batch(self, batch_size):
        """
        Return a batch of size `batch_size`.
        :param batch_size: Number of items requested.
        :return: A batch whose format is defined by the subclass.
        """
        pass

    def get_diagnostics(self):
        """Return diagnostic values as a dict; the default has none."""
        return {}

    def get_snapshot(self):
        """Return snapshot data as a dict; the default has none."""
        return {}

    def end_epoch(self, epoch):
        """
        Hook for end-of-epoch book-keeping; no-op by default.

        NOTE(review): presumably invoked by the training loop with the epoch
        index; the caller is not visible in this file.
        """
        return