Spaces:
Running
Running
File size: 1,391 Bytes
48c5871 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
# from collections import deque
import numpy as np
import random
import torch
import pickle as pickle
class rpm(object):
    """Fixed-capacity replay memory backed by a list used as a ring buffer.

    Items are appended until ``buffer_size`` entries are stored; after that
    each new item overwrites the slot at ``index``, which advances cyclically.
    """

    def __init__(self, buffer_size):
        """Create an empty memory that holds at most ``buffer_size`` items."""
        self.buffer_size = buffer_size
        self.buffer = []
        # Slot that the next append overwrites once the buffer is full.
        self.index = 0

    def append(self, obj):
        """Store ``obj``, overwriting an existing entry once the memory is full.

        If ``buffer`` holds more than ``buffer_size`` entries (only possible
        when it was replaced or extended from outside this class), the
        leading surplus entries are dropped first and ``obj`` is then stored
        as usual instead of being discarded.
        """
        if self.size() > self.buffer_size:
            print('buffer size larger than set value, trimming...')
            self.buffer = self.buffer[(self.size() - self.buffer_size):]
            # Restart the write cursor at the front of the trimmed list
            # (the oldest entry, assuming the list was in insertion order).
            self.index = 0

        if self.buffer_size <= 0:
            # A memory with no capacity cannot store anything; bail out
            # rather than indexing into an empty list.
            return

        if self.size() == self.buffer_size:
            self.buffer[self.index] = obj
            self.index = (self.index + 1) % self.buffer_size
        else:
            self.buffer.append(obj)

    def size(self):
        """Return the number of items currently stored."""
        return len(self.buffer)

    def sample_batch(self, batch_size, device, only_state=False):
        """Draw a random batch without replacement and stack it into tensors.

        At most ``self.size()`` items are drawn when fewer than
        ``batch_size`` are stored. Each stored item is indexed positionally
        and its fields are passed to ``torch.stack``, so items are expected
        to be sequences of same-shaped tensors (at least five fields, or at
        least four when ``only_state`` is true).

        Args:
            batch_size: Requested number of items.
            device: Target passed to ``Tensor.to`` for every returned tensor.
            only_state: When true, stack and return only field 3 of each
                item. NOTE(review): the meaning of field 3 is defined by the
                caller that fills the memory -- confirm there.

        Returns:
            A single tensor when ``only_state`` is true, otherwise a 5-tuple
            of tensors, one per item field, each stacked along dim 0.
        """
        count = min(batch_size, self.size())
        batch = random.sample(self.buffer, count)

        if only_state:
            states = torch.stack([item[3] for item in batch], dim=0)
            return states.to(device)

        # Number of leading fields of each stored item that are returned.
        item_count = 5
        return tuple(
            torch.stack([item[i] for item in batch], dim=0).to(device)
            for i in range(item_count)
        )
|