diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index b914c52b338..fb684e7c043 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -24,6 +24,7 @@ TensorDict, TensorDictBase, ) +from tensordict.base import _NESTED_TENSORS_AS_LISTS from tensordict.memmap import MemoryMappedTensor from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -1120,7 +1121,12 @@ def max_size_along_dim0(data_shape): out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok) if torchrl_logger.isEnabledFor(logging.DEBUG): for key, tensor in sorted( - out.items(include_nested=True, leaves_only=True), key=str + out.items( + include_nested=True, + leaves_only=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + ), + key=str, ): try: filesize = os.path.getsize(tensor.filename) / 1024 / 1024