# Source: NCTC / models/research/pcl_rl/replay_buffer.py
# Uploaded by NCTCMumbai ("Upload 2571 files", commit 0b8359d, 4.93 kB).
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Replay buffer.
Implements replay buffer in Python.
"""
import random
import numpy as np
from six.moves import xrange
class ReplayBuffer(object):
    """Fixed-capacity episode store with uniform sampling and random eviction.

    Episodes are kept in a dict keyed by slot index 0..cur_size-1. The first
    `init_length` slots hold the episodes passed to `seed_buffer`; `remove_n`
    never selects them, so `add` never overwrites them.
    """

    def __init__(self, max_size):
        """Create an empty buffer holding at most `max_size` episodes."""
        self.max_size = max_size
        self.cur_size = 0
        self.buffer = {}
        self.init_length = 0

    def __len__(self):
        """Number of episodes currently stored."""
        return self.cur_size

    def seed_buffer(self, episodes):
        """Insert `episodes` and mark the first len(episodes) slots as protected.

        NOTE(review): the protected range is [0, len(episodes)), which only
        matches where the episodes land if the buffer is empty when this is
        called.
        """
        count = len(episodes)
        self.init_length = count
        self.add(episodes, np.ones(count))

    def add(self, episodes, *args):
        """Store `episodes`, overwriting slots chosen by `remove_n` once full.

        Extra positional arguments (e.g. priorities) are accepted so the
        signature matches prioritized buffers; they are ignored here.
        """
        total = len(episodes)
        pos = 0
        # Use up any empty slots first.
        while pos < total and self.cur_size < self.max_size:
            self.buffer[self.cur_size] = episodes[pos]
            self.cur_size += 1
            pos += 1
        leftover = total - pos
        if leftover > 0:
            # remove_n runs after the fill above, so slots filled earlier in
            # this same call are also candidates for overwriting.
            for offset, slot in enumerate(self.remove_n(leftover)):
                self.buffer[slot] = episodes[pos + offset]
        assert len(self.buffer) == self.cur_size

    def remove_n(self, n):
        """Pick `n` distinct non-seed slots, uniformly at random, to overwrite."""
        evictable = xrange(self.init_length, self.cur_size)
        return random.sample(evictable, n)

    def get_batch(self, n):
        """Sample `n` distinct stored episodes uniformly; weights are None."""
        chosen = random.sample(xrange(self.cur_size), n)
        episodes = [self.buffer[slot] for slot in chosen]
        return episodes, None

    def update_last_batch(self, delta):
        """No-op: this buffer keeps no per-episode priorities to update."""
        del delta  # Unused; present for interface parity with subclasses.
class PrioritizedReplayBuffer(ReplayBuffer):
    """Replay buffer that samples episodes by a softmax over their priorities.

    Slots [0, init_length) hold the episodes inserted via `seed_buffer`. They
    are never returned by `remove_n`, and their priority is kept equal to the
    largest priority among the stored non-seed episodes.
    """

    def __init__(self, max_size, alpha=0.2,
                 eviction_strategy='rand'):
        """Create an empty prioritized buffer.

        Args:
            max_size: Maximum number of episodes stored.
            alpha: Scale applied to priorities inside the sampling softmax.
            eviction_strategy: How `remove_n` chooses slots to overwrite once
                the buffer is full: 'rand' (uniformly at random), 'fifo'
                (cycle through the non-seed slots in order) or 'rank'
                (lowest priority first).
        """
        super(PrioritizedReplayBuffer, self).__init__(max_size)
        self.alpha = alpha
        self.eviction_strategy = eviction_strategy
        assert self.eviction_strategy in ['rand', 'fifo', 'rank']
        # Offset (relative to init_length) of the next slot 'fifo' overwrites.
        self.remove_idx = 0
        # One priority per slot; slots that were never filled stay at 0.
        self.priorities = np.zeros(self.max_size)

    def add(self, episodes, priorities, new_idxs=None):
        """Store episodes together with their priorities.

        Args:
            episodes: Sequence of episodes to store.
            priorities: Priorities assigned via
                `self.priorities[new_idxs] = priorities` (a scalar or one
                value per episode).
            new_idxs: Optional slot indices to write to, one per episode.
                When None, empty slots are filled first and the remaining
                episodes overwrite slots chosen by `remove_n`.

        Returns:
            The list of slot indices the episodes were written to.
        """
        if new_idxs is None:
            new_idxs = []
            idx = 0
            # Fill empty slots first.
            while self.cur_size < self.max_size and idx < len(episodes):
                self.buffer[self.cur_size] = episodes[idx]
                new_idxs.append(self.cur_size)
                self.cur_size += 1
                idx += 1
            # Buffer is full: overwrite slots picked by the eviction strategy.
            # NOTE(review): remove_n runs after the fill loop (and before the
            # new priorities are written), so slots filled earlier in this
            # same call are eligible for eviction as well.
            if idx < len(episodes):
                for remove_idx in self.remove_n(len(episodes) - idx):
                    self.buffer[remove_idx] = episodes[idx]
                    new_idxs.append(remove_idx)
                    idx += 1
        else:
            assert len(new_idxs) == len(episodes)
            for new_idx, ep in zip(new_idxs, episodes):
                self.buffer[new_idx] = ep
        self.priorities[new_idxs] = priorities
        self._sync_seed_priorities()
        assert len(self.buffer) == self.cur_size
        return new_idxs

    def _sync_seed_priorities(self):
        """Set seed-slot priorities to the max priority of stored non-seed slots."""
        # Only filled non-seed slots count: unfilled slots still hold 0 and
        # would otherwise dominate the max when real priorities are negative,
        # and an empty slice would make np.max raise.
        live = self.priorities[self.init_length:self.cur_size]
        if live.size:
            self.priorities[:self.init_length] = np.max(live)

    def remove_n(self, n):
        """Return `n` distinct slot indices to overwrite.

        Candidates are the occupied non-seed slots [init_length, cur_size), so
        seed episodes are never selected.
        """
        assert self.init_length + n <= self.cur_size
        if n == 0:
            return []
        lo, hi = self.init_length, self.cur_size
        if self.eviction_strategy == 'rand':
            # random removal
            idxs = random.sample(xrange(lo, hi), n)
        elif self.eviction_strategy == 'fifo':
            # Overwrite non-seed slots cyclically, resuming where the previous
            # call stopped.
            span = hi - lo
            idxs = [lo + (self.remove_idx + i) % span for i in xrange(n)]
            self.remove_idx = idxs[-1] + 1 - lo
        elif self.eviction_strategy == 'rank':
            # Lowest-priority non-seed slots. Partitioning only [lo, hi) keeps
            # seed slots out of the candidates; over the whole array, ties
            # (seeds share the max priority) could select protected episodes.
            idxs = lo + np.argpartition(self.priorities[lo:hi], n - 1)[:n]
        return idxs

    def sampling_distribution(self):
        """Return sampling probabilities for the `cur_size` stored episodes.

        The distribution is a softmax of `alpha * priority`; the maximum is
        subtracted first so `np.exp` cannot overflow.
        """
        p = self.priorities[:self.cur_size]
        p = np.exp(self.alpha * (p - np.max(p)))
        norm = np.sum(p)
        if norm > 0:
            # Weight of a uniform component mixed into the softmax; 0 disables it.
            uniform = 0.0
            p = p / norm * (1 - uniform) + 1.0 / self.cur_size * uniform
        else:
            p = np.ones(self.cur_size) / self.cur_size
        return p

    def get_batch(self, n):
        """Sample `n` distinct episodes according to `sampling_distribution`.

        The sampled slot indices are remembered for `update_last_batch`.

        Returns:
            (episodes, probs) where probs[i] is the sampling probability of
            episodes[i].
        """
        p = self.sampling_distribution()
        idxs = np.random.choice(self.cur_size, size=int(n), replace=False, p=p)
        self.last_batch = idxs
        return [self.buffer[idx] for idx in idxs], p[idxs]

    def update_last_batch(self, delta):
        """Set the priorities of the last sampled batch to `abs(delta)`.

        Must be called after `get_batch`, which is where `last_batch` is set.
        """
        self.priorities[self.last_batch] = np.abs(delta)
        self._sync_seed_priorities()