Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 17, 2024
2 parents 9004b20 + bdce685 commit 862e41f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
8 changes: 7 additions & 1 deletion benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
import torch
from packaging import version

from tensordict import TensorDict
from tensordict.nn import (
Expand Down Expand Up @@ -43,6 +44,11 @@
vec_td_lambda_return_estimate,
)

TORCH_VERSION = torch.__version__
FULLGRAPH = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse(
"2.5.0"
) # Anything from 2.5, incl. nightlies, allows for fullgraph


@pytest.fixture(scope="module")
def set_default_device():
Expand Down Expand Up @@ -147,7 +153,7 @@ def test_gae_speed(benchmark, gae_fn, gamma_tensor, batches, timesteps):
)


def _maybe_compile(fn, compile, td, fullgraph=True, warmup=3):
def _maybe_compile(fn, compile, td, fullgraph=FULLGRAPH, warmup=3):
if compile:
if isinstance(compile, str):
fn = torch.compile(fn, mode=compile, fullgraph=fullgraph)
Expand Down
14 changes: 14 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,20 @@ def test_errors(self, storage_type):
):
storage_type(data, max_size=4)

def test_existsok_lazymemmap(self, tmpdir):
storage0 = LazyMemmapStorage(10, scratch_dir=tmpdir)
rb = ReplayBuffer(storage=storage0)
rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))

storage1 = LazyMemmapStorage(10, scratch_dir=tmpdir)
rb = ReplayBuffer(storage=storage1)
with pytest.raises(RuntimeError, match="existsok"):
rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))

storage2 = LazyMemmapStorage(10, scratch_dir=tmpdir, existsok=True)
rb = ReplayBuffer(storage=storage2)
rb.extend(TensorDict(a=torch.randn(3), batch_size=[3]))

@pytest.mark.parametrize(
"data_type", ["tensor", "tensordict", "tensorclass", "pytree"]
)
Expand Down
9 changes: 8 additions & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,8 @@ class LazyMemmapStorage(LazyTensorStorage):
Args:
max_size (int): size of the storage, i.e. maximum number of elements stored
in the buffer.
Keyword Args:
scratch_dir (str or path): directory where memmap-tensors will be written.
device (torch.device, optional): device where the sampled tensors will be
stored and sent. Default is :obj:`torch.device("cpu")`.
Expand All @@ -950,6 +952,9 @@ class LazyMemmapStorage(LazyTensorStorage):
measuring the storage size. For instance, a storage of shape ``[3, 4]``
has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
Defaults to ``1``.
existsok (bool, optional): whether an error should be raised if any of the
tensors already exists on disk. Defaults to ``True``. If ``False``, the
tensor will be opened as is, not overewritten.
.. note:: When checkpointing a ``LazyMemmapStorage``, one can provide a path identical to where the storage is
already stored to avoid executing long copies of data that is already stored on disk.
Expand Down Expand Up @@ -1026,10 +1031,12 @@ def __init__(
scratch_dir=None,
device: torch.device = "cpu",
ndim: int = 1,
existsok: bool = False,
):
super().__init__(max_size, ndim=ndim)
self.initialized = False
self.scratch_dir = None
self.existsok = existsok
if scratch_dir is not None:
self.scratch_dir = str(scratch_dir)
if self.scratch_dir[-1] != "/":
Expand Down Expand Up @@ -1125,7 +1132,7 @@ def max_size_along_dim0(data_shape):
if is_tensor_collection(data):
out = data.clone().to(self.device)
out = out.expand(max_size_along_dim0(data.shape))
out = out.memmap_like(prefix=self.scratch_dir)
out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
for key, tensor in sorted(
out.items(include_nested=True, leaves_only=True), key=str
):
Expand Down

0 comments on commit 862e41f

Please sign in to comment.