# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Replay buffer.

Implements replay buffer in Python.
"""

import random

import numpy as np

try:
  from six.moves import xrange  # pylint: disable=redefined-builtin
except ImportError:
  # six is only needed for Python 2 compatibility; on Python 3, xrange == range.
  xrange = range


class ReplayBuffer(object):
  """Fixed-capacity episode buffer with uniform sampling.

  Episodes are stored in a dict keyed by slot index. The first
  `init_length` slots (populated via `seed_buffer`) are protected from
  eviction: `remove_n` only ever evicts slots >= `init_length`.
  """

  def __init__(self, max_size):
    self.max_size = max_size
    self.cur_size = 0
    self.buffer = {}
    self.init_length = 0

  def __len__(self):
    return self.cur_size

  def seed_buffer(self, episodes):
    """Pre-populate the buffer with episodes protected from eviction."""
    self.init_length = len(episodes)
    self.add(episodes, np.ones(self.init_length))

  def add(self, episodes, *args):
    """Add episodes to buffer.

    Fills empty slots first; once full, overwrites slots chosen by
    `remove_n`. Extra positional args (e.g. priorities) are ignored here
    so that `seed_buffer` works for subclasses with richer signatures.
    """
    idx = 0
    while self.cur_size < self.max_size and idx < len(episodes):
      self.buffer[self.cur_size] = episodes[idx]
      self.cur_size += 1
      idx += 1
    if idx < len(episodes):
      remove_idxs = self.remove_n(len(episodes) - idx)
      for remove_idx in remove_idxs:
        self.buffer[remove_idx] = episodes[idx]
        idx += 1
    assert len(self.buffer) == self.cur_size

  def remove_n(self, n):
    """Get n distinct slot indices for removal (never seed slots)."""
    # random removal, restricted to non-seed slots
    idxs = random.sample(xrange(self.init_length, self.cur_size), n)
    return idxs

  def get_batch(self, n):
    """Get a batch of n episodes sampled uniformly without replacement.

    Returns:
      (episodes, probs) where probs is None for the uniform buffer.
    """
    idxs = random.sample(xrange(self.cur_size), n)
    return [self.buffer[idx] for idx in idxs], None

  def update_last_batch(self, delta):
    """No-op: the uniform buffer tracks no priorities."""
    pass


class PrioritizedReplayBuffer(ReplayBuffer):
  """Replay buffer with softmax prioritized sampling and pluggable eviction.

  Priorities of the seed slots `[0, init_length)` are pinned to the max
  priority of the non-seed slots, so seeds stay likely to be sampled but
  are never evicted.
  """

  def __init__(self, max_size, alpha=0.2, eviction_strategy='rand'):
    super(PrioritizedReplayBuffer, self).__init__(max_size)
    self.alpha = alpha
    self.eviction_strategy = eviction_strategy
    if self.eviction_strategy not in ['rand', 'fifo', 'rank']:
      # ValueError instead of assert: validation must survive `python -O`.
      raise ValueError(
          'Unknown eviction strategy: %s' % self.eviction_strategy)
    self.remove_idx = 0  # next cyclic overwrite position for 'fifo'
    self.priorities = np.zeros(self.max_size)

  def add(self, episodes, priorities, new_idxs=None):
    """Add episodes with priorities to the buffer.

    Args:
      episodes: list of episodes to store.
      priorities: per-episode priorities (array-like, same length).
      new_idxs: optional explicit slot indices to overwrite; if None,
        slots are chosen (empty slots first, then via `remove_n`).

    Returns:
      The list of slot indices the episodes were written to.
    """
    if new_idxs is None:
      idx = 0
      new_idxs = []
      while self.cur_size < self.max_size and idx < len(episodes):
        self.buffer[self.cur_size] = episodes[idx]
        new_idxs.append(self.cur_size)
        self.cur_size += 1
        idx += 1
      if idx < len(episodes):
        remove_idxs = self.remove_n(len(episodes) - idx)
        for remove_idx in remove_idxs:
          self.buffer[remove_idx] = episodes[idx]
          new_idxs.append(remove_idx)
          idx += 1
    else:
      assert len(new_idxs) == len(episodes)
      for new_idx, ep in zip(new_idxs, episodes):
        self.buffer[new_idx] = ep
    self.priorities[new_idxs] = priorities
    # Pin seed priorities to the max non-seed priority; guard the slice so
    # we neither compute a useless max (init_length == 0) nor call np.max
    # on an empty array (init_length == max_size).
    if 0 < self.init_length < self.max_size:
      self.priorities[0:self.init_length] = np.max(
          self.priorities[self.init_length:])
    assert len(self.buffer) == self.cur_size
    return new_idxs

  def remove_n(self, n):
    """Get n slot indices for removal, never touching seed slots."""
    assert self.init_length + n <= self.cur_size
    if self.eviction_strategy == 'rand':
      # random removal among non-seed slots
      idxs = random.sample(xrange(self.init_length, self.cur_size), n)
    elif self.eviction_strategy == 'fifo':
      # overwrite non-seed slots in cyclical fashion
      idxs = [
          self.init_length +
          (self.remove_idx + i) % (self.max_size - self.init_length)
          for i in xrange(n)]
      self.remove_idx = idxs[-1] + 1 - self.init_length
    elif self.eviction_strategy == 'rank':
      # Remove the lowest-priority NON-SEED slots. Partition only the
      # non-seed range so seed slots can never be evicted on priority
      # ties (rand/fifo already protect them).
      idxs = self.init_length + np.argpartition(
          self.priorities[self.init_length:self.cur_size], n - 1)[:n]
    return idxs

  def sampling_distribution(self):
    """Softmax (temperature 1/alpha) over current priorities.

    Returns:
      A probability vector of length cur_size; uniform if all mass is 0.
    """
    p = self.priorities[:self.cur_size]
    p = np.exp(self.alpha * (p - np.max(p)))  # shift by max for stability
    norm = np.sum(p)
    if norm > 0:
      # uniform=0.0 keeps the mixing term inert; retained as a tunable knob.
      uniform = 0.0
      p = p / norm * (1 - uniform) + 1.0 / self.cur_size * uniform
    else:
      p = np.ones(self.cur_size) / self.cur_size
    return p

  def get_batch(self, n):
    """Sample n episodes without replacement from the priority softmax.

    Records the sampled indices in `last_batch` so `update_last_batch`
    can later adjust their priorities.

    Returns:
      (episodes, probs) where probs are the sampling probabilities used.
    """
    p = self.sampling_distribution()
    idxs = np.random.choice(self.cur_size, size=int(n), replace=False, p=p)
    self.last_batch = idxs
    return [self.buffer[idx] for idx in idxs], p[idxs]

  def update_last_batch(self, delta):
    """Update last sampled batch's slots with new priority |delta|."""
    self.priorities[self.last_batch] = np.abs(delta)
    # Re-pin seed priorities to the new max non-seed priority (same
    # guards as in `add`).
    if 0 < self.init_length < self.max_size:
      self.priorities[0:self.init_length] = np.max(
          self.priorities[self.init_length:])