# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Replay buffer. | |
Implements replay buffer in Python. | |
""" | |
import random

import numpy as np
from six.moves import xrange


class ReplayBuffer(object):
  """Replay buffer with uniform sampling and random eviction."""

  def __init__(self, max_size):
    self.max_size = max_size
    self.cur_size = 0
    self.buffer = {}  # maps slot index -> episode
    self.init_length = 0

  def __len__(self):
    return self.cur_size

  def seed_buffer(self, episodes):
    """Seed the buffer with episodes that are never evicted."""
    self.init_length = len(episodes)
    self.add(episodes, np.ones(self.init_length))

  def add(self, episodes, *args):
    """Add episodes to buffer."""
    idx = 0
    # Fill empty slots first.
    while self.cur_size < self.max_size and idx < len(episodes):
      self.buffer[self.cur_size] = episodes[idx]
      self.cur_size += 1
      idx += 1

    # Once full, overwrite slots chosen by remove_n with the remaining episodes.
    if idx < len(episodes):
      remove_idxs = self.remove_n(len(episodes) - idx)
      for remove_idx in remove_idxs:
        self.buffer[remove_idx] = episodes[idx]
        idx += 1

    assert len(self.buffer) == self.cur_size

  def remove_n(self, n):
    """Get n items for removal."""
    # Random removal, restricted to the non-seeded slots.
    idxs = random.sample(xrange(self.init_length, self.cur_size), n)
    return idxs

  def get_batch(self, n):
    """Get batch of episodes to train on."""
    # Uniformly random batch; no sampling probabilities to report.
    idxs = random.sample(xrange(self.cur_size), n)
    return [self.buffer[idx] for idx in idxs], None

  def update_last_batch(self, delta):
    # No-op here; overridden by PrioritizedReplayBuffer.
    pass
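

# Usage sketch (illustrative, not part of the original module): episodes can be
# arbitrary Python objects. Seeded episodes occupy slots [0, init_length) and
# survive forever, since remove_n only draws from [init_length, cur_size).
#
#   buf = ReplayBuffer(max_size=100)
#   buf.seed_buffer(expert_episodes)     # `expert_episodes` is a hypothetical list
#   buf.add(fresh_episodes)              # overflow evicts non-seeded slots at random
#   episodes, probs = buf.get_batch(32)  # probs is None for the uniform buffer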


class PrioritizedReplayBuffer(ReplayBuffer):
  """Replay buffer that samples and (optionally) evicts by priority."""

  def __init__(self, max_size, alpha=0.2,
               eviction_strategy='rand'):
    self.max_size = max_size
    self.alpha = alpha
    self.eviction_strategy = eviction_strategy
    assert self.eviction_strategy in ['rand', 'fifo', 'rank']
    self.remove_idx = 0  # cursor for 'fifo' eviction
    self.cur_size = 0
    self.buffer = {}
    self.priorities = np.zeros(self.max_size)
    self.init_length = 0

  def __len__(self):
    return self.cur_size

  def add(self, episodes, priorities, new_idxs=None):
    """Add episodes to buffer."""
    if new_idxs is None:
      idx = 0
      new_idxs = []
      # Fill empty slots first.
      while self.cur_size < self.max_size and idx < len(episodes):
        self.buffer[self.cur_size] = episodes[idx]
        new_idxs.append(self.cur_size)
        self.cur_size += 1
        idx += 1

      # Once full, overwrite slots chosen by the eviction strategy.
      if idx < len(episodes):
        remove_idxs = self.remove_n(len(episodes) - idx)
        for remove_idx in remove_idxs:
          self.buffer[remove_idx] = episodes[idx]
          new_idxs.append(remove_idx)
          idx += 1
    else:
      # The caller dictated which slots to overwrite.
      assert len(new_idxs) == len(episodes)
      for new_idx, ep in zip(new_idxs, episodes):
        self.buffer[new_idx] = ep

    self.priorities[new_idxs] = priorities
    # Seeded episodes are pinned to the maximum non-seeded priority.
    self.priorities[0:self.init_length] = np.max(
        self.priorities[self.init_length:])

    assert len(self.buffer) == self.cur_size
    return new_idxs
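
  # Round-trip sketch (an assumption about intended use, not in the original
  # file): because add returns the slots it wrote, a caller can later overwrite
  # exactly those episodes by passing the indices back in:
  #
  #   idxs = buf.add(episodes, priorities)
  #   buf.add(rewritten_episodes, new_priorities, new_idxs=idxs)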

  def remove_n(self, n):
    """Get n items for removal."""
    assert self.init_length + n <= self.cur_size

    if self.eviction_strategy == 'rand':
      # Random removal among the non-seeded slots.
      idxs = random.sample(xrange(self.init_length, self.cur_size), n)
    elif self.eviction_strategy == 'fifo':
      # Overwrite elements in cyclical fashion, skipping the seeded slots.
      idxs = [
          self.init_length +
          (self.remove_idx + i) % (self.max_size - self.init_length)
          for i in xrange(n)]
      self.remove_idx = idxs[-1] + 1 - self.init_length
    elif self.eviction_strategy == 'rank':
      # np.argpartition places the n smallest priorities first; evict those.
      idxs = np.argpartition(self.priorities, n - 1)[:n]

    return idxs
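
  # Worked 'fifo' example (illustrative): with max_size=10, init_length=2, and
  # remove_idx=6, a request for n=4 walks the 8 non-seeded slots cyclically:
  #   idxs = [2 + (6 + i) % 8 for i in range(4)] == [8, 9, 2, 3]
  # and remove_idx then advances to idxs[-1] + 1 - init_length == 2.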

  def sampling_distribution(self):
    p = self.priorities[:self.cur_size]
    # Softmax over alpha-scaled priorities; subtracting the max keeps exp stable.
    p = np.exp(self.alpha * (p - np.max(p)))
    norm = np.sum(p)
    if norm > 0:
      uniform = 0.0
      p = p / norm * (1 - uniform) + 1.0 / self.cur_size * uniform
    else:
      # Degenerate case: fall back to a uniform distribution.
      p = np.ones(self.cur_size) / self.cur_size
    return p
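
  # In symbols (matching the code above): episode i is drawn with probability
  #   P(i) = exp(alpha * p_i) / sum_k exp(alpha * p_k),
  # a temperature-scaled softmax of the priorities p_i. With `uniform`
  # hardcoded to 0.0, the uniform-mixing term is currently inert.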

  def get_batch(self, n):
    """Get batch of episodes to train on."""
    p = self.sampling_distribution()
    idxs = np.random.choice(self.cur_size, size=int(n), replace=False, p=p)
    # Remember the sampled slots so update_last_batch can reprioritize them.
    self.last_batch = idxs
    return [self.buffer[idx] for idx in idxs], p[idxs]

  def update_last_batch(self, delta):
    """Update last batch idxs with new priority."""
    self.priorities[self.last_batch] = np.abs(delta)
    self.priorities[0:self.init_length] = np.max(
        self.priorities[self.init_length:])
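

# Smoke test (illustrative, not in the original file): integer "episodes"
# stand in for real trajectory objects.
if __name__ == '__main__':
  buf = PrioritizedReplayBuffer(max_size=8, alpha=0.5, eviction_strategy='rank')
  buf.seed_buffer(list(range(3)))         # slots 0-2, never evicted by 'rand'/'fifo'
  buf.add(list(range(3, 8)), np.ones(5))  # fill the remaining five slots
  episodes, probs = buf.get_batch(4)
  buf.update_last_batch(np.random.rand(4))  # e.g. new |TD-error| priorities
  print(len(buf), episodes, probs)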