# NCERL-Diverse-PCG/src/rlkit/data_management/simple_replay_buffer.py
import os
from collections import OrderedDict
import numpy as np
import torch
from src.rlkit.data_management.replay_buffer import ReplayBuffer, EnsembleReplayBuffer


class SimpleReplayBuffer(ReplayBuffer):
    """Fixed-size FIFO replay buffer backed by pre-allocated numpy arrays."""
def __init__(
self,
max_replay_buffer_size,
observation_dim,
action_dim,
env_info_sizes,
):
self._observation_dim = observation_dim
self._action_dim = action_dim
self._max_replay_buffer_size = max_replay_buffer_size
self._observations = np.zeros((max_replay_buffer_size, observation_dim))
# It's a bit memory inefficient to save the observations twice,
# but it makes the code *much* easier since you no longer have to
# worry about termination conditions.
self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
self._actions = np.zeros((max_replay_buffer_size, action_dim))
# Make everything a 2D np array to make it easier for other code to
# reason about the shape of the data
self._rewards = np.zeros((max_replay_buffer_size, 1))
# self._terminals[i] = a terminal was received at time i
self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')
# Define self._env_infos[key][i] to be the return value of env_info[key]
# at time i
self._env_infos = {}
for key, size in env_info_sizes.items():
self._env_infos[key] = np.zeros((max_replay_buffer_size, size))
        self._env_info_keys = list(env_info_sizes.keys())
self._top = 0
self._size = 0
def add_sample(self, observation, action, reward, next_observation,
terminal, env_info, **kwargs):
self._observations[self._top] = observation
self._actions[self._top] = action
self._rewards[self._top] = reward
self._terminals[self._top] = terminal
self._next_obs[self._top] = next_observation
for key in self._env_info_keys:
self._env_infos[key][self._top] = env_info[key]
self._advance()
def terminate_episode(self):
pass
def _advance(self):
self._top = (self._top + 1) % self._max_replay_buffer_size
if self._size < self._max_replay_buffer_size:
self._size += 1
def random_batch(self, batch_size):
indices = np.random.randint(0, self._size, batch_size)
batch = dict(
observations=self._observations[indices],
actions=self._actions[indices],
rewards=self._rewards[indices],
terminals=self._terminals[indices],
next_observations=self._next_obs[indices],
)
for key in self._env_info_keys:
assert key not in batch.keys()
batch[key] = self._env_infos[key][indices]
return batch
def rebuild_env_info_dict(self, idx):
return {
key: self._env_infos[key][idx]
for key in self._env_info_keys
}
def batch_env_info_dict(self, indices):
return {
key: self._env_infos[key][indices]
for key in self._env_info_keys
}
def num_steps_can_sample(self):
return self._size
def get_diagnostics(self):
return OrderedDict([
('size', self._size)
])
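

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how SimpleReplayBuffer is typically driven by a
# collector/trainer loop. The dimensions and step count below are made up for
# illustration; only the class defined above is assumed.
def _demo_simple_replay_buffer():
    buffer = SimpleReplayBuffer(
        max_replay_buffer_size=1000,
        observation_dim=4,
        action_dim=2,
        env_info_sizes={},
    )
    for _ in range(10):
        buffer.add_sample(
            observation=np.random.randn(4),
            action=np.random.randn(2),
            reward=0.0,
            next_observation=np.random.randn(4),
            terminal=False,
            env_info={},
        )
    batch = buffer.random_batch(batch_size=4)
    # Every entry is a 2D array: (batch_size, dim) for obs/actions,
    # (batch_size, 1) for rewards and terminals.
    assert batch['observations'].shape == (4, 4)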


class EnsembleSimpleReplayBuffer(EnsembleReplayBuffer):
    """Replay buffer that also stores a per-ensemble-member mask for every
    transition and can save/load its contents to disk."""
def __init__(
self,
max_replay_buffer_size,
observation_dim,
action_dim,
env_info_sizes,
num_ensemble,
log_dir,
):
self._observation_dim = observation_dim
self._action_dim = action_dim
self._max_replay_buffer_size = max_replay_buffer_size
self._observations = np.zeros((max_replay_buffer_size, observation_dim))
# It's a bit memory inefficient to save the observations twice,
# but it makes the code *much* easier since you no longer have to
# worry about termination conditions.
self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
self._actions = np.zeros((max_replay_buffer_size, action_dim))
# Make everything a 2D np array to make it easier for other code to
# reason about the shape of the data
self._rewards = np.zeros((max_replay_buffer_size, 1))
# self._terminals[i] = a terminal was received at time i
self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')
# Define self._env_infos[key][i] to be the return value of env_info[key]
# at time i
self._env_infos = {}
for key, size in env_info_sizes.items():
self._env_infos[key] = np.zeros((max_replay_buffer_size, size))
        self._env_info_keys = list(env_info_sizes.keys())
        # Per-transition mask with one column per ensemble member (e.g. a
        # bootstrap mask selecting which members train on each transition)
        self._mask = np.zeros((max_replay_buffer_size, num_ensemble))
self._top = 0
self._size = 0
        # Directory used by save_buffer/load_buffer; create it up front so
        # torch.save does not fail on a missing path.
        self.buffer_dir = os.path.join(log_dir, 'buffer')
        os.makedirs(self.buffer_dir, exist_ok=True)
def add_sample(self, observation, action, reward, next_observation,
terminal, mask, env_info, **kwargs):
self._observations[self._top] = observation
self._actions[self._top] = action
self._rewards[self._top] = reward
self._terminals[self._top] = terminal
self._next_obs[self._top] = next_observation
self._mask[self._top] = mask
for key in self._env_info_keys:
self._env_infos[key][self._top] = env_info[key]
self._advance()
def terminate_episode(self):
pass
def _advance(self):
self._top = (self._top + 1) % self._max_replay_buffer_size
if self._size < self._max_replay_buffer_size:
self._size += 1
def random_batch(self, batch_size):
indices = np.random.randint(0, self._size, batch_size)
batch = dict(
observations=self._observations[indices],
actions=self._actions[indices],
rewards=self._rewards[indices],
terminals=self._terminals[indices],
next_observations=self._next_obs[indices],
masks=self._mask[indices],
)
for key in self._env_info_keys:
assert key not in batch.keys()
batch[key] = self._env_infos[key][indices]
return batch
def rebuild_env_info_dict(self, idx):
return {
key: self._env_infos[key][idx]
for key in self._env_info_keys
}
def batch_env_info_dict(self, indices):
return {
key: self._env_infos[key][indices]
for key in self._env_info_keys
}
def num_steps_can_sample(self):
return self._size
def get_diagnostics(self):
return OrderedDict([
('size', self._size)
])
    def save_buffer(self, epoch):
        path = os.path.join(self.buffer_dir, 'replay_%d.pt' % epoch)
        # Only the filled portion of each array is persisted.
        payload = [
            self._observations[:self._size],
            self._actions[:self._size],
            self._rewards[:self._size],
            self._terminals[:self._size],
            self._next_obs[:self._size],
            self._mask[:self._size],
            self._size,
        ]
        torch.save(payload, path)
    def load_buffer(self, epoch):
        path = os.path.join(self.buffer_dir, 'replay_%d.pt' % epoch)
        payload = torch.load(path)
        size = payload[6]
        # Copy the saved transitions back into the pre-allocated arrays (the
        # payload only holds the filled portion) so the buffer can keep
        # accepting new samples after a reload.
        self._observations[:size] = payload[0]
        self._actions[:size] = payload[1]
        self._rewards[:size] = payload[2]
        self._terminals[:size] = payload[3]
        self._next_obs[:size] = payload[4]
        self._mask[:size] = payload[5]
        self._size = size
        self._top = size % self._max_replay_buffer_size
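

# --- Usage sketch (illustrative only) ---------------------------------------
# EnsembleSimpleReplayBuffer expects a per-transition `mask` of length
# `num_ensemble`. A common choice (e.g. in bootstrapped-ensemble methods) is a
# Bernoulli mask so each ensemble member trains on a random subset of the
# data; the probability 0.5 and the log_dir below are arbitrary illustrative
# values, not taken from the original code.
def _demo_ensemble_replay_buffer(log_dir='/tmp/ensemble_demo'):
    num_ensemble = 3
    buffer = EnsembleSimpleReplayBuffer(
        max_replay_buffer_size=1000,
        observation_dim=4,
        action_dim=2,
        env_info_sizes={},
        num_ensemble=num_ensemble,
        log_dir=log_dir,
    )
    for _ in range(10):
        buffer.add_sample(
            observation=np.random.randn(4),
            action=np.random.randn(2),
            reward=0.0,
            next_observation=np.random.randn(4),
            terminal=False,
            mask=np.random.binomial(1, 0.5, num_ensemble),
            env_info={},
        )
    batch = buffer.random_batch(batch_size=4)
    assert batch['masks'].shape == (4, num_ensemble)
    # Round-trip the buffer through disk.
    buffer.save_buffer(epoch=0)
    buffer.load_buffer(epoch=0)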


class RandomReplayBuffer(ReplayBuffer):
    """Replay buffer that rescales sampled observations by random factors
    drawn from a uniform distribution (simple observation randomization)."""
def __init__(
self,
max_replay_buffer_size,
observation_dim,
action_dim,
env_info_sizes,
single_flag,
equal_flag,
lower,
upper,
):
self._observation_dim = observation_dim
self._action_dim = action_dim
self._max_replay_buffer_size = max_replay_buffer_size
self._observations = np.zeros((max_replay_buffer_size, observation_dim))
# It's a bit memory inefficient to save the observations twice,
# but it makes the code *much* easier since you no longer have to
# worry about termination conditions.
self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
self._actions = np.zeros((max_replay_buffer_size, action_dim))
# Make everything a 2D np array to make it easier for other code to
# reason about the shape of the data
self._rewards = np.zeros((max_replay_buffer_size, 1))
# self._terminals[i] = a terminal was received at time i
self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')
# Define self._env_infos[key][i] to be the return value of env_info[key]
# at time i
self._env_infos = {}
for key, size in env_info_sizes.items():
self._env_infos[key] = np.zeros((max_replay_buffer_size, size))
        self._env_info_keys = list(env_info_sizes.keys())
self._top = 0
self._size = 0
# randomization
self.single_flag = single_flag
self.equal_flag = equal_flag
self.lower = lower
self.upper = upper
def add_sample(self, observation, action, reward, next_observation,
terminal, env_info, **kwargs):
self._observations[self._top] = observation
self._actions[self._top] = action
self._rewards[self._top] = reward
self._terminals[self._top] = terminal
self._next_obs[self._top] = next_observation
for key in self._env_info_keys:
self._env_infos[key][self._top] = env_info[key]
self._advance()
def terminate_episode(self):
pass
def _advance(self):
self._top = (self._top + 1) % self._max_replay_buffer_size
if self._size < self._max_replay_buffer_size:
self._size += 1
def random_batch(self, batch_size):
indices = np.random.randint(0, self._size, batch_size)
obs = self._observations[indices]
next_obs = self._next_obs[indices]
        if self.single_flag == 0:
            # One scaling factor per sample, shared across all observation dims.
            random_number_1 = np.random.uniform(self.lower, self.upper, obs.shape[0]).reshape(-1, 1)
            random_number_2 = np.random.uniform(self.lower, self.upper, obs.shape[0]).reshape(-1, 1)
        else:
            # An independent scaling factor for every observation dimension.
            random_number_1 = np.random.uniform(self.lower, self.upper, obs.shape[0] * obs.shape[1]).reshape(obs.shape[0], -1)
            random_number_2 = np.random.uniform(self.lower, self.upper, obs.shape[0] * obs.shape[1]).reshape(obs.shape[0], -1)
        if self.equal_flag == 0:
            # obs and next_obs are rescaled by the same factors.
            obs = obs * random_number_1
            next_obs = next_obs * random_number_1
        else:
            # obs and next_obs are rescaled by independently drawn factors.
            obs = obs * random_number_1
            next_obs = next_obs * random_number_2
batch = dict(
observations=obs,
actions=self._actions[indices],
rewards=self._rewards[indices],
terminals=self._terminals[indices],
next_observations=next_obs,
)
for key in self._env_info_keys:
assert key not in batch.keys()
batch[key] = self._env_infos[key][indices]
return batch
def rebuild_env_info_dict(self, idx):
return {
key: self._env_infos[key][idx]
for key in self._env_info_keys
}
def batch_env_info_dict(self, indices):
return {
key: self._env_infos[key][indices]
for key in self._env_info_keys
}
def num_steps_can_sample(self):
return self._size
def get_diagnostics(self):
return OrderedDict([
('size', self._size)
])
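

# --- Usage sketch (illustrative only) ---------------------------------------
# RandomReplayBuffer rescales sampled observations by factors drawn from
# U(lower, upper). With single_flag=0 one factor is shared across all
# dimensions of a sample; otherwise every dimension gets its own factor. With
# equal_flag=0 obs and next_obs share the same factors; otherwise they are
# drawn independently. The constructor arguments below are arbitrary.
def _demo_random_replay_buffer():
    buffer = RandomReplayBuffer(
        max_replay_buffer_size=1000,
        observation_dim=4,
        action_dim=2,
        env_info_sizes={},
        single_flag=0,
        equal_flag=0,
        lower=0.9,
        upper=1.1,
    )
    for _ in range(10):
        buffer.add_sample(
            observation=np.ones(4),
            action=np.zeros(2),
            reward=0.0,
            next_observation=np.ones(4),
            terminal=False,
            env_info={},
        )
    batch = buffer.random_batch(batch_size=4)
    # The stored observations are all ones, so the sampled observations are
    # exactly the random scaling factors, all within [lower, upper].
    assert np.all(batch['observations'] >= 0.9) and np.all(batch['observations'] <= 1.1)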


class GaussianReplayBuffer(ReplayBuffer):
    """Replay buffer that perturbs a random fraction of sampled observations
    with zero-mean Gaussian noise."""
def __init__(
self,
max_replay_buffer_size,
observation_dim,
action_dim,
env_info_sizes,
prob,
std,
):
self._observation_dim = observation_dim
self._action_dim = action_dim
self._max_replay_buffer_size = max_replay_buffer_size
self._observations = np.zeros((max_replay_buffer_size, observation_dim))
# It's a bit memory inefficient to save the observations twice,
# but it makes the code *much* easier since you no longer have to
# worry about termination conditions.
self._next_obs = np.zeros((max_replay_buffer_size, observation_dim))
self._actions = np.zeros((max_replay_buffer_size, action_dim))
# Make everything a 2D np array to make it easier for other code to
# reason about the shape of the data
self._rewards = np.zeros((max_replay_buffer_size, 1))
# self._terminals[i] = a terminal was received at time i
self._terminals = np.zeros((max_replay_buffer_size, 1), dtype='uint8')
# Define self._env_infos[key][i] to be the return value of env_info[key]
# at time i
self._env_infos = {}
for key, size in env_info_sizes.items():
self._env_infos[key] = np.zeros((max_replay_buffer_size, size))
        self._env_info_keys = list(env_info_sizes.keys())
self._top = 0
self._size = 0
# randomization
self.prob = prob
self.std = std
def add_sample(self, observation, action, reward, next_observation,
terminal, env_info, **kwargs):
self._observations[self._top] = observation
self._actions[self._top] = action
self._rewards[self._top] = reward
self._terminals[self._top] = terminal
self._next_obs[self._top] = next_observation
for key in self._env_info_keys:
self._env_infos[key][self._top] = env_info[key]
self._advance()
def terminate_episode(self):
pass
def _advance(self):
self._top = (self._top + 1) % self._max_replay_buffer_size
if self._size < self._max_replay_buffer_size:
self._size += 1
def random_batch(self, batch_size):
indices = np.random.randint(0, self._size, batch_size)
obs = self._observations[indices]
next_obs = self._next_obs[indices]
        num_batch, dim_input = obs.shape[0], obs.shape[1]
        # Zero-mean Gaussian noise, applied to roughly a `prob` fraction of the
        # batch; the same noise is added to obs and next_obs.
        noise = np.random.normal(0, self.std, num_batch * dim_input).reshape(num_batch, -1)
        mask = np.random.uniform(0, 1, num_batch).reshape(num_batch, -1) < self.prob
        noise = noise * mask
        obs = obs + noise
        next_obs = next_obs + noise
batch = dict(
observations=obs,
actions=self._actions[indices],
rewards=self._rewards[indices],
terminals=self._terminals[indices],
next_observations=next_obs,
)
for key in self._env_info_keys:
assert key not in batch.keys()
batch[key] = self._env_infos[key][indices]
return batch
def rebuild_env_info_dict(self, idx):
return {
key: self._env_infos[key][idx]
for key in self._env_info_keys
}
def batch_env_info_dict(self, indices):
return {
key: self._env_infos[key][indices]
for key in self._env_info_keys
}
def num_steps_can_sample(self):
return self._size
def get_diagnostics(self):
return OrderedDict([
('size', self._size)
])
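

# --- Usage sketch (illustrative only) ---------------------------------------
# GaussianReplayBuffer perturbs a sampled observation (and its next_obs, with
# the same noise) using zero-mean Gaussian noise of standard deviation `std`,
# but only for roughly a `prob` fraction of the samples in each batch. The
# parameter values below are arbitrary.
def _demo_gaussian_replay_buffer():
    buffer = GaussianReplayBuffer(
        max_replay_buffer_size=1000,
        observation_dim=4,
        action_dim=2,
        env_info_sizes={},
        prob=0.5,
        std=0.1,
    )
    for _ in range(10):
        buffer.add_sample(
            observation=np.zeros(4),
            action=np.zeros(2),
            reward=0.0,
            next_observation=np.zeros(4),
            terminal=False,
            env_info={},
        )
    batch = buffer.random_batch(batch_size=4)
    # Rows where the Bernoulli(prob) mask is 0 stay exactly at the stored
    # (zero) observations; the remaining rows carry the injected noise.
    assert batch['observations'].shape == (4, 4)


if __name__ == '__main__':
    # Quick manual smoke test of the usage sketches above.
    _demo_simple_replay_buffer()
    _demo_ensemble_replay_buffer()
    _demo_random_replay_buffer()
    _demo_gaussian_replay_buffer()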