From eb3579381804bedb85cb3786c87b2fdb6d8607d5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 20 Jun 2024 13:28:37 +0100 Subject: [PATCH] [BugFix] Fix OOB sampling in PrioritizedSliceSampler (#2239) --- sota-implementations/cql/cql_offline.py | 3 ++- sota-implementations/cql/cql_online.py | 6 ++++-- torchrl/data/replay_buffers/samplers.py | 9 +++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index 35cb43a4fe7..d8185c8091c 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -162,7 +162,8 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.close() torchrl_logger.info(f"Training time: {time.time() - start_time}") - eval_env.close() + if not eval_env.is_closed: + eval_env.close() if __name__ == "__main__": diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 5e0cb633009..5f8f81357c8 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -227,8 +227,10 @@ def main(cfg: "DictConfig"): # noqa: F821 torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") collector.shutdown() - eval_env.close() - train_env.close() + if not eval_env.is_closed: + eval_env.close() + if not train_env.is_closed: + train_env.close() if __name__ == "__main__": diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 1d1499312b8..ca587014653 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -475,6 +475,15 @@ def sample(self, storage: Storage, batch_size: int) -> torch.Tensor: index = index.unsqueeze(0) index.clamp_max_(len(storage) - 1) weight = torch.as_tensor(self._sum_tree[index]) + # get indices where weight is 0 + zero_weight = weight == 0 + index = index + while zero_weight.any(): + index = torch.where(zero_weight, index - 1, index) + if (index < 0).any(): + raise RuntimeError("Failed to find a suitable index") + zero_weight = torch.as_tensor(self._sum_tree[index]) + zero_weight = weight == 0 # Importance sampling weight formula: # w_i = (p_i / sum(p) * N) ^ (-beta)