# NCERL-Diverse-PCG/src/rlkit/data_management/shared_obs_dict_replay_buffer.py
# Author: baiyanlali-zhao; commit eaf2e33 ("init"); 4.4 kB.
import numpy as np
from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
import torch.multiprocessing as mp
import ctypes
class SharedObsDictRelabelingBuffer(ObsDictRelabelingBuffer):
    """
    Same as an ObsDictRelabelingBuffer, but obs, next_obs, _actions and
    _terminals are backed by multiprocessing arrays, and the replay buffer
    size is shared through a multiprocessing Value. The intended use case is
    for if one wants these arrays to be shared between processes.

    Synchronization: only the size is accessed under a lock (the ``mp.Value``
    wrapper acquires its lock on each ``.value`` read/write). The array data
    is exposed as numpy views over ``mp.Array(...).get_obj()``, which bypasses
    the arrays' locks, so concurrent reads/writes of array contents are NOT
    synchronized; coordinate externally (e.g. with ``arr.get_lock()``) if that
    matters. A read-modify-write such as ``self._size += 1`` is two separately
    locked operations, not one atomic update. Putting such large arrays in
    shared memory can be slow, but it seems ok empirically.

    This code also breaks some functionality for the subprocess. For example,
    random_batch is incorrect as _idx_to_future_obs_idx (and any other array
    not registered with _register_mp_array) is not shared. If the subprocess
    needs all of the functionality, every numpy array in the replay buffer
    must be registered with _register_mp_array.
    """

    def __init__(
            self,
            *args,
            **kwargs
    ):
        # Created before the parent __init__ because the _size property below
        # reads/writes it; presumably the parent sets self._size during init.
        self._shared_size = mp.Value(ctypes.c_long, 0)
        ObsDictRelabelingBuffer.__init__(self, *args, **kwargs)
        self._mp_array_info = {}
        self._shared_obs_info = {}
        self._shared_next_obs_info = {}
        for obs_key, obs_arr in self._obs.items():
            # next_obs is allocated with the same dtype/shape as obs.
            self._shared_obs_info[obs_key] = self._new_shared_array_info(
                obs_arr)
            self._shared_next_obs_info[obs_key] = self._new_shared_array_info(
                obs_arr)
            # Replace the parent's arrays with zero-filled shared views.
            self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
            self._next_obs[obs_key] = to_np(
                *self._shared_next_obs_info[obs_key])
        self._register_mp_array("_actions")
        self._register_mp_array("_terminals")

    @staticmethod
    def _new_shared_array_info(arr):
        """
        Allocate a zero-filled multiprocessing array big enough for ``arr``.

        The ctypes element type is derived from ``arr.dtype`` so the shared
        buffer holds exactly ``arr.nbytes`` bytes (float64 -> c_double and
        uint8 -> c_uint8, as before; other dtypes now get a matching type
        instead of c_double). dtypes with no ctypes equivalent fall back to a
        raw byte buffer of ``arr.nbytes`` bytes.

        :param arr: numpy array whose dtype/shape the shared array mirrors.
        :return: ``(mp_array, dtype, shape)``, the argument tuple of ``to_np``.
        """
        try:
            ctype = np.ctypeslib.as_ctypes_type(arr.dtype)
            length = arr.size
        except NotImplementedError:
            ctype = ctypes.c_uint8
            length = arr.nbytes
        return mp.Array(ctype, length), arr.dtype, arr.shape

    def _register_mp_array(self, arr_instance_var_name):
        """
        Use this function to register an array to be shared.

        Replaces the numpy array stored in attribute ``arr_instance_var_name``
        with a view over a new shared array of the same dtype and shape and
        records it in ``self._mp_array_info`` so ``get_mp_info`` /
        ``init_from_mp_info`` can hand it to a subprocess. This will wipe the
        array: the new shared memory is zero-filled.
        """
        assert hasattr(self, arr_instance_var_name), arr_instance_var_name
        arr = getattr(self, arr_instance_var_name)
        info = self._new_shared_array_info(arr)
        self._mp_array_info[arr_instance_var_name] = info
        setattr(self, arr_instance_var_name, to_np(*info))

    def init_from_mp_info(
            self,
            mp_info,
    ):
        """
        Re-attach this instance to shared memory created by another instance.

        The intended use is to have a subprocess serialize/copy a
        SharedObsDictRelabelingBuffer instance and call init_from_mp_info on
        the copy with the original instance's shared variables. This can't be
        done during serialization since multiprocessing shared objects can't
        be serialized and must be passed directly to the subprocess as an
        argument to the fork call.

        Requires ``self._obs`` and ``self._next_obs`` to already be dicts;
        their entries are replaced with views over the shared arrays.

        :param mp_info: tuple ``(shared_obs_info, shared_next_obs_info,
            mp_array_info, shared_size)`` as returned by ``get_mp_info``.
        """
        (shared_obs_info, shared_next_obs_info,
         mp_array_info, shared_size) = mp_info
        self._shared_obs_info = shared_obs_info
        self._shared_next_obs_info = shared_next_obs_info
        self._mp_array_info = mp_array_info
        for obs_key in self._shared_obs_info:
            self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
            self._next_obs[obs_key] = to_np(
                *self._shared_next_obs_info[obs_key])
        for arr_instance_var_name, info in self._mp_array_info.items():
            setattr(self, arr_instance_var_name, to_np(*info))
        self._shared_size = shared_size

    def get_mp_info(self):
        """
        Return the shared objects needed by ``init_from_mp_info``.

        :return: ``(shared_obs_info, shared_next_obs_info, mp_array_info,
            shared_size)``; each ``*_info`` dict maps a key/attribute name to
            a ``(mp_array, dtype, shape)`` tuple.
        """
        return (
            self._shared_obs_info,
            self._shared_next_obs_info,
            self._mp_array_info,
            self._shared_size,
        )

    @property
    def _size(self):
        # Read through the mp.Value wrapper, which takes its lock per access.
        return self._shared_size.value

    @_size.setter
    def _size(self, size):
        self._shared_size.value = size
def to_np(shared_arr, np_dtype, shape):
    """
    Return a numpy view (no copy) over the memory of a multiprocessing array.

    :param shared_arr: a lock-wrapped ``mp.Array``; its raw ctypes array is
        taken with ``get_obj()``, so access through the returned view does not
        acquire the array's lock.
    :param np_dtype: dtype used to interpret the raw bytes.
    :param shape: shape of the returned view; its element count must equal
        the buffer size divided by the item size of ``np_dtype``.
    :return: an ndarray of ``shape`` sharing memory with ``shared_arr``.
    """
    raw_buffer = shared_arr.get_obj()
    flat_view = np.frombuffer(raw_buffer, dtype=np_dtype)
    return flat_view.reshape(shape)