From c7084c8c0b006f699069290db859d61e4f930cfe Mon Sep 17 00:00:00 2001 From: flodorner <38560862+flodorner@users.noreply.github.com> Date: Sat, 4 Jan 2020 12:46:07 +0100 Subject: [PATCH] Parallelized sampling from the replay buffer and building the segment tree (#608) * Parallelized sampling from the replay buffer and building the segment tree. --- docs/misc/changelog.rst | 2 + stable_baselines/common/segment_tree.py | 84 +++++++++++++++++-------- stable_baselines/deepq/replay_buffer.py | 35 +++++------ tests/test_segment_tree.py | 76 ++++++++++++++++++++++ 4 files changed, 152 insertions(+), 45 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 8d1d369cfb..9b6153e004 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,6 +14,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Parallelized updating and sampling from the replay buffer in DQN. (@flodorner) - Docker build script, `scripts/build_docker.sh`, can push images automatically. 
@@ -597,3 +598,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching +@flodorner diff --git a/stable_baselines/common/segment_tree.py b/stable_baselines/common/segment_tree.py index 1a22d8eed0..6b184fa4e5 100644 --- a/stable_baselines/common/segment_tree.py +++ b/stable_baselines/common/segment_tree.py @@ -1,4 +1,18 @@ -import operator +import numpy as np + + +def unique(sorted_array): + """ + More efficient implementation of np.unique for sorted arrays + :param sorted_array: (np.ndarray) + :return: (np.ndarray) sorted_array without duplicate elements + """ + if len(sorted_array) == 1: + return sorted_array + left = sorted_array[:-1] + right = sorted_array[1:] + uniques = np.append(right != left, True) + return sorted_array[uniques] class SegmentTree(object): @@ -8,7 +22,7 @@ def __init__(self, capacity, operation, neutral_element): https://en.wikipedia.org/wiki/Segment_tree - Can be used as regular array, but with two + Can be used as regular array that supports index arrays, but with two important differences: a) setting item's value is slightly slower. 
@@ -26,6 +40,7 @@ def __init__(self, capacity, operation, neutral_element): self._capacity = capacity self._value = [neutral_element for _ in range(2 * capacity)] self._operation = operation + self.neutral_element = neutral_element def _reduce_helper(self, start, end, node, node_start, node_end): if start == node_start and end == node_end: @@ -61,19 +76,25 @@ def reduce(self, start=0, end=None): return self._reduce_helper(start, end, 1, 0, self._capacity - 1) def __setitem__(self, idx, val): - # index of the leaf - idx += self._capacity - self._value[idx] = val - idx //= 2 - while idx >= 1: - self._value[idx] = self._operation( - self._value[2 * idx], - self._value[2 * idx + 1] + # indexes of the leaf + idxs = idx + self._capacity + self._value[idxs] = val + if isinstance(idxs, int): + idxs = np.array([idxs]) + # go up one level in the tree and remove duplicate indexes + idxs = unique(idxs // 2) + while len(idxs) > 1 or idxs[0] > 0: + # as long as there are non-zero indexes, update the corresponding values + self._value[idxs] = self._operation( + self._value[2 * idxs], + self._value[2 * idxs + 1] ) - idx //= 2 + # go up one level in the tree and remove duplicate indexes + idxs = unique(idxs // 2) def __getitem__(self, idx): - assert 0 <= idx < self._capacity + assert np.max(idx) < self._capacity + assert 0 <= np.min(idx) return self._value[self._capacity + idx] @@ -81,9 +102,10 @@ class SumSegmentTree(SegmentTree): def __init__(self, capacity): super(SumSegmentTree, self).__init__( capacity=capacity, - operation=operator.add, + operation=np.add, neutral_element=0.0 ) + self._value = np.array(self._value) def sum(self, start=0, end=None): """ @@ -98,23 +120,34 @@ def sum(self, start=0, end=None): def find_prefixsum_idx(self, prefixsum): """ Find the highest index `i` in the array such that - sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum + sum(arr[0] + arr[1] + ... 
+ arr[i - 1]) <= prefixsum for each entry in prefixsum if array values are probabilities, this function allows to sample indexes according to the discrete probability efficiently. - :param prefixsum: (float) upperbound on the sum of array prefix - :return: (int) highest index satisfying the prefixsum constraint + :param prefixsum: (np.ndarray) float upper bounds on the sum of array prefix + :return: (np.ndarray) highest indexes satisfying the prefixsum constraint """ - assert 0 <= prefixsum <= self.sum() + 1e-5 - idx = 1 - while idx < self._capacity: # while non-leaf - if self._value[2 * idx] > prefixsum: - idx = 2 * idx - else: - prefixsum -= self._value[2 * idx] - idx = 2 * idx + 1 + if isinstance(prefixsum, float): + prefixsum = np.array([prefixsum]) + assert 0 <= np.min(prefixsum) + assert np.max(prefixsum) <= self.sum() + 1e-5 + assert isinstance(prefixsum[0], float) + + idx = np.ones(len(prefixsum), dtype=int) + cont = np.ones(len(prefixsum), dtype=bool) + + while np.any(cont): # while not all nodes are leaves + idx[cont] = 2 * idx[cont] + prefixsum_new = np.where(self._value[idx] <= prefixsum, prefixsum - self._value[idx], prefixsum) + # prepare update of prefixsum for all right children + idx = np.where(np.logical_or(self._value[idx] > prefixsum, np.logical_not(cont)), idx, idx + 1) + # Select child node for non-leaf nodes + prefixsum = prefixsum_new + # update prefixsum + cont = idx < self._capacity + # collect leaves return idx - self._capacity @@ -122,9 +155,10 @@ class MinSegmentTree(SegmentTree): def __init__(self, capacity): super(MinSegmentTree, self).__init__( capacity=capacity, - operation=min, + operation=np.minimum, neutral_element=float('inf') ) + self._value = np.array(self._value) def min(self, start=0, end=None): """ diff --git a/stable_baselines/deepq/replay_buffer.py b/stable_baselines/deepq/replay_buffer.py index b274e51597..a7230989bd 100644 --- a/stable_baselines/deepq/replay_buffer.py +++ b/stable_baselines/deepq/replay_buffer.py @@ 
-134,13 +134,12 @@ def add(self, obs_t, action, reward, obs_tp1, done): self._it_min[idx] = self._max_priority ** self._alpha def _sample_proportional(self, batch_size): - res = [] - for _ in range(batch_size): - # TODO(szymon): should we ensure no repeats? - mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) - idx = self._it_sum.find_prefixsum_idx(mass) - res.append(idx) - return res + mass = [] + total = self._it_sum.sum(0, len(self._storage) - 1) + # TODO(szymon): should we ensure no repeats? + mass = np.random.random(size=batch_size) * total + idx = self._it_sum.find_prefixsum_idx(mass) + return idx def sample(self, batch_size, beta=0): """ @@ -166,16 +165,11 @@ def sample(self, batch_size, beta=0): assert beta > 0 idxes = self._sample_proportional(batch_size) - weights = [] p_min = self._it_min.min() / self._it_sum.sum() max_weight = (p_min * len(self._storage)) ** (-beta) - - for idx in idxes: - p_sample = self._it_sum[idx] / self._it_sum.sum() - weight = (p_sample * len(self._storage)) ** (-beta) - weights.append(weight / max_weight) - weights = np.array(weights) + p_sample = self._it_sum[idxes] / self._it_sum.sum() + weights = (p_sample * len(self._storage)) ** (-beta) / max_weight encoded_sample = self._encode_sample(idxes) return tuple(list(encoded_sample) + [weights, idxes]) @@ -191,10 +185,11 @@ def update_priorities(self, idxes, priorities): denoted by variable `idxes`. 
""" assert len(idxes) == len(priorities) - for idx, priority in zip(idxes, priorities): - assert priority > 0 - assert 0 <= idx < len(self._storage) - self._it_sum[idx] = priority ** self._alpha - self._it_min[idx] = priority ** self._alpha + assert np.min(priorities) > 0 + assert np.min(idxes) >= 0 + assert np.max(idxes) < len(self.storage) + self._it_sum[idxes] = priorities ** self._alpha + self._it_min[idxes] = priorities ** self._alpha + + self._max_priority = max(self._max_priority, np.max(priorities)) - self._max_priority = max(self._max_priority, priority) diff --git a/tests/test_segment_tree.py b/tests/test_segment_tree.py index 4719d67d37..db0c5bc8d8 100644 --- a/tests/test_segment_tree.py +++ b/tests/test_segment_tree.py @@ -9,6 +9,16 @@ def test_tree_set(): """ tree = SumSegmentTree(4) + tree[np.array([2, 3])] = [1.0, 3.0] + + assert np.isclose(tree.sum(), 4.0) + assert np.isclose(tree.sum(0, 2), 0.0) + assert np.isclose(tree.sum(0, 3), 1.0) + assert np.isclose(tree.sum(2, 3), 1.0) + assert np.isclose(tree.sum(2, -1), 1.0) + assert np.isclose(tree.sum(2, 4), 4.0) + + tree = SumSegmentTree(4) tree[2] = 1.0 tree[3] = 3.0 @@ -26,6 +36,17 @@ def test_tree_set_overlap(): """ tree = SumSegmentTree(4) + tree[np.array([2])] = 1.0 + tree[np.array([2])] = 3.0 + + assert np.isclose(tree.sum(), 3.0) + assert np.isclose(tree.sum(2, 3), 3.0) + assert np.isclose(tree.sum(2, -1), 3.0) + assert np.isclose(tree.sum(2, 4), 3.0) + assert np.isclose(tree.sum(1, 2), 0.0) + + tree = SumSegmentTree(4) + tree[2] = 1.0 tree[2] = 3.0 @@ -51,6 +72,19 @@ def test_prefixsum_idx(): assert tree.find_prefixsum_idx(1.01) == 3 assert tree.find_prefixsum_idx(3.00) == 3 assert tree.find_prefixsum_idx(4.00) == 3 + assert np.all(tree.find_prefixsum_idx([0.0, 0.5, 0.99, 1.01, 3.00, 4.00]) == [2, 2, 2, 3, 3, 3]) + + tree = SumSegmentTree(4) + + tree[np.array([2, 3])] = [1.0, 3.0] + + assert tree.find_prefixsum_idx(0.0) == 2 + assert tree.find_prefixsum_idx(0.5) == 2 + assert 
tree.find_prefixsum_idx(0.99) == 2 + assert tree.find_prefixsum_idx(1.01) == 3 + assert tree.find_prefixsum_idx(3.00) == 3 + assert tree.find_prefixsum_idx(4.00) == 3 + assert np.all(tree.find_prefixsum_idx([0.0, 0.5, 0.99, 1.01, 3.00, 4.00]) == [2, 2, 2, 3, 3, 3]) def test_prefixsum_idx2(): @@ -59,6 +93,17 @@ def test_prefixsum_idx2(): """ tree = SumSegmentTree(4) + tree[np.array([0, 1, 2, 3])] = [0.5, 1.0, 1.0, 3.0] + + assert tree.find_prefixsum_idx(0.00) == 0 + assert tree.find_prefixsum_idx(0.55) == 1 + assert tree.find_prefixsum_idx(0.99) == 1 + assert tree.find_prefixsum_idx(1.51) == 2 + assert tree.find_prefixsum_idx(3.00) == 3 + assert tree.find_prefixsum_idx(5.50) == 3 + + tree = SumSegmentTree(4) + tree[0] = 0.5 tree[1] = 1.0 tree[2] = 1.0 @@ -109,6 +154,37 @@ def test_max_interval_tree(): assert np.isclose(tree.min(2, -1), 4.0) assert np.isclose(tree.min(3, 4), 3.0) + tree = MinSegmentTree(4) + + tree[np.array([0, 2, 3])] = [1.0, 0.5, 3.0] + + assert np.isclose(tree.min(), 0.5) + assert np.isclose(tree.min(0, 2), 1.0) + assert np.isclose(tree.min(0, 3), 0.5) + assert np.isclose(tree.min(0, -1), 0.5) + assert np.isclose(tree.min(2, 4), 0.5) + assert np.isclose(tree.min(3, 4), 3.0) + + tree[np.array([2])] = 0.7 + + assert np.isclose(tree.min(), 0.7) + assert np.isclose(tree.min(0, 2), 1.0) + assert np.isclose(tree.min(0, 3), 0.7) + assert np.isclose(tree.min(0, -1), 0.7) + assert np.isclose(tree.min(2, 4), 0.7) + assert np.isclose(tree.min(3, 4), 3.0) + + tree[np.array([2])] = 4.0 + + assert np.isclose(tree.min(), 1.0) + assert np.isclose(tree.min(0, 2), 1.0) + assert np.isclose(tree.min(0, 3), 1.0) + assert np.isclose(tree.min(0, -1), 1.0) + assert np.isclose(tree.min(2, 4), 3.0) + assert np.isclose(tree.min(2, 3), 4.0) + assert np.isclose(tree.min(2, -1), 4.0) + assert np.isclose(tree.min(3, 4), 3.0) + if __name__ == '__main__': test_tree_set()