import numpy as np
from rlkit.data_management.obs_dict_replay_buffer import ObsDictRelabelingBuffer
import torch.multiprocessing as mp
import ctypes
class SharedObsDictRelabelingBuffer(ObsDictRelabelingBuffer):
    """
    Same as an ObsDictRelabelingBuffer but the obs and next_obs are backed
    by multiprocessing arrays. The replay buffer size is also shared. The
    intended use case is for if one wants obs/next_obs to be shared between
    processes. Accesses are synchronized internally by locks (mp takes care
    of that). Technically, putting such large arrays in shared memory/requiring
    synchronized access can be extremely slow, but it seems ok empirically.

    This code also breaks a lot of functionality for the subprocess. For example,
    random_batch is incorrect as actions and _idx_to_future_obs_idx are not
    shared. If the subprocess needs all of the functionality, a mp.Array
    must be used for all numpy arrays in the replay buffer.
    """

    def __init__(
            self,
            *args,
            **kwargs
    ):
        # The shared size must exist BEFORE the base __init__ runs: the base
        # class assigns self._size during setup, which goes through the
        # _size property setter below and reads self._shared_size.
        self._shared_size = mp.Value(ctypes.c_long, 0)
        ObsDictRelabelingBuffer.__init__(self, *args, **kwargs)

        self._mp_array_info = {}
        self._shared_obs_info = {}
        self._shared_next_obs_info = {}

        for obs_key, obs_arr in self._obs.items():
            # Images are stored as uint8; everything else defaults to double.
            ctype = ctypes.c_double
            if obs_arr.dtype == np.uint8:
                ctype = ctypes.c_uint8

            # Keep (shared array, dtype, shape) so a subprocess can rebuild
            # an identically-shaped numpy view with to_np.
            self._shared_obs_info[obs_key] = (
                mp.Array(ctype, obs_arr.size),
                obs_arr.dtype,
                obs_arr.shape,
            )
            self._shared_next_obs_info[obs_key] = (
                mp.Array(ctype, obs_arr.size),
                obs_arr.dtype,
                obs_arr.shape,
            )
            # Re-point the buffer's numpy arrays at the shared memory,
            # discarding the arrays the base class allocated.
            self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
            self._next_obs[obs_key] = to_np(
                *self._shared_next_obs_info[obs_key])
        self._register_mp_array("_actions")
        self._register_mp_array("_terminals")

    def _register_mp_array(self, arr_instance_var_name):
        """
        Use this function to register an array to be shared. This will wipe arr.

        :param arr_instance_var_name: name of an existing numpy-array instance
            attribute (e.g. "_actions"); it is replaced by a shared-memory view.
        """
        assert hasattr(self, arr_instance_var_name), arr_instance_var_name
        arr = getattr(self, arr_instance_var_name)

        ctype = ctypes.c_double
        if arr.dtype == np.uint8:
            ctype = ctypes.c_uint8

        self._mp_array_info[arr_instance_var_name] = (
            mp.Array(ctype, arr.size), arr.dtype, arr.shape,
        )
        setattr(
            self,
            arr_instance_var_name,
            to_np(*self._mp_array_info[arr_instance_var_name])
        )

    def init_from_mp_info(
            self,
            mp_info,
    ):
        """
        The intended use is to have a subprocess serialize/copy a
        SharedObsDictRelabelingBuffer instance and call init_from on the
        instance's shared variables. This can't be done during serialization
        since multiprocessing shared objects can't be serialized and must be
        passed directly to the subprocess as an argument to the fork call.

        :param mp_info: tuple as produced by get_mp_info().
        """
        shared_obs_info, shared_next_obs_info, mp_array_info, shared_size = mp_info

        self._shared_obs_info = shared_obs_info
        self._shared_next_obs_info = shared_next_obs_info
        self._mp_array_info = mp_array_info

        for obs_key in self._shared_obs_info.keys():
            self._obs[obs_key] = to_np(*self._shared_obs_info[obs_key])
            self._next_obs[obs_key] = to_np(
                *self._shared_next_obs_info[obs_key])
        for arr_instance_var_name in self._mp_array_info.keys():
            setattr(
                self,
                arr_instance_var_name,
                to_np(*self._mp_array_info[arr_instance_var_name])
            )
        self._shared_size = shared_size

    def get_mp_info(self):
        """Return the shared handles to hand to a subprocess at fork time."""
        return (
            self._shared_obs_info,
            self._shared_next_obs_info,
            self._mp_array_info,
            self._shared_size,
        )

    # BUG FIX: these two defs were both plain methods named _size, so the
    # setter-shaped one silently shadowed the getter and the shared size was
    # never read. They are a property getter/setter pair over the shared
    # mp.Value so parent and subprocess always agree on the buffer size.
    @property
    def _size(self):
        return self._shared_size.value

    @_size.setter
    def _size(self, size):
        self._shared_size.value = size
def to_np(shared_arr, np_dtype, shape):
    """Wrap a multiprocessing Array in a numpy view with the given dtype/shape.

    No data is copied: writes through the returned array are visible in
    ``shared_arr`` and vice versa.
    """
    flat_view = np.frombuffer(shared_arr.get_obj(), dtype=np_dtype)
    return flat_view.reshape(shape)