baiyanlali-zhao's picture
init
eaf2e33
raw
history blame
1.24 kB
import numpy as np
class PathBuilder(dict):
    """
    Accumulate per-step values into lists keyed by name, then stack them.

    Usage:
    ```
    path_builder = PathBuilder()
    path_builder.add_all(
        observations=1,
        actions=2,
        next_observations=3,
        ...
    )
    path_builder.add_all(
        observations=4,
        actions=5,
        next_observations=6,
        ...
    )
    path = path_builder.get_all_stacked()
    path['observations']
    # output: array([1, 4])
    path['actions']
    # output: array([2, 5])
    ```
    Note that the key should be "actions" and not "action" since the
    resulting dictionary will have those keys.

    ``len(path_builder)`` is the number of ``add_all`` calls made so far,
    not the number of keys stored in the underlying dict.
    """

    def __init__(self):
        super().__init__()
        # Number of add_all() calls so far; this is what __len__ reports.
        self._path_length = 0

    def add_all(self, **key_to_value):
        """Append each keyword value to the list stored under its key.

        A key seen for the first time gets a new one-element list. The path
        length is incremented exactly once per call, regardless of how many
        keywords are passed.
        """
        for k, v in key_to_value.items():
            if k in self:
                self[k].append(v)
            else:
                self[k] = [v]
        self._path_length += 1

    def get_all_stacked(self):
        """Return a new plain dict mapping each key to ``stack_list`` of its values.

        ``stack_list`` returns a list unchanged when its first element is a
        dict, and converts it to a NumPy array otherwise. The builder itself
        is not modified.
        """
        return {k: stack_list(v) for k, v in self.items()}

    def __len__(self):
        """Return the number of ``add_all`` calls, not the number of keys."""
        return self._path_length
def stack_list(lst):
    """Stack a list of per-step values collected by ``PathBuilder``.

    If the first element is a dict, the list is returned unchanged, so dicts
    are not packed into a NumPy object array. Otherwise the values are
    converted with ``np.array``. An empty input yields an empty NumPy array
    instead of raising ``IndexError`` on ``lst[0]``.
    """
    # len() rather than truthiness so that array-like inputs also work.
    if len(lst) == 0:
        return np.array(lst)
    if isinstance(lst[0], dict):
        return lst
    return np.array(lst)