From f7ecad840f3c4be62b3deef33ab5d8a508df6326 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 12 Jul 2024 14:04:19 +0100 Subject: [PATCH] amend --- test/test_storage_map.py | 45 +++++++++---- torchrl/data/map/hash.py | 89 +++++++++++++----------- torchrl/data/map/query.py | 7 +- torchrl/data/map/tdstorage.py | 90 +++++++++++++++---------- torchrl/data/replay_buffers/storages.py | 23 ++++--- 5 files changed, 152 insertions(+), 102 deletions(-) diff --git a/test/test_storage_map.py b/test/test_storage_map.py index 9315c6eb6f6..4f59db03c92 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import argparse import importlib.util -from typing import cast import pytest @@ -12,7 +11,13 @@ from tensordict import TensorDict from torchrl.data import LazyTensorStorage, ListStorage -from torchrl.data.map import BinaryToDecimal, QueryModule, SipHash, TensorDictMap +from torchrl.data.map import ( + BinaryToDecimal, + QueryModule, + RandomProjectionHash, + SipHash, + TensorDictMap, +) from torchrl.envs import GymEnv _has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec( @@ -60,13 +65,29 @@ def test_binary_to_decimal(): assert (decimal == torch.Tensor([3, 2])).all() -def test_sip_hash(): - a = torch.rand((3, 2)) - b = a.clone() - hash_module = SipHash() - hash_a = cast(torch.Tensor, hash_module(a)) - hash_b = cast(torch.Tensor, hash_module(b)) - assert (hash_a == hash_b).all() +class TestHash: + def test_sip_hash(self): + a = torch.rand((3, 2)) + b = a.clone() + hash_module = SipHash(as_tensor=True) + hash_a = torch.tensor(hash_module(a)) + hash_b = torch.tensor(hash_module(b)) + assert (hash_a == hash_b).all() + + @pytest.mark.parametrize("n_components", [None, 14]) + @pytest.mark.parametrize("scale", [0.001, 0.01, 1, 100, 1000]) + def test_randomprojection_hash(self, n_components, scale): + torch.manual_seed(0) + r = RandomProjectionHash(n_components=n_components) + x = torch.randn(10000, 100).mul_(scale) + y = r(x) + if n_components is None: + assert r.n_components == r._N_COMPONENTS_DEFAULT + else: + assert r.n_components == n_components + + assert y.shape == (10000,) + assert y.unique().numel() == y.numel() def test_query(): @@ -104,7 +125,7 @@ def test_query_module(): tensor_dict_storage = TensorDictMap( query_module=query_module, - key_to_storage={"index": embedding_storage}, + storage=embedding_storage, ) index = TensorDict( @@ -140,7 +161,7 @@ def test_storage(): tensor_dict_storage = TensorDictMap( query_module=query_module, - key_to_storage={"index": embedding_storage}, + storage=embedding_storage, ) index = TensorDict( @@ -162,7 +183,7 @@ def test_storage(): new_index["key3"] = torch.Tensor([[4], [5], [6], [7]]) retrieve_value = tensor_dict_storage[new_index] - assert cast(torch.Tensor, retrieve_value["index"] == value["index"]).all() + assert (retrieve_value["index"] == value["index"]).all() @pytest.mark.skipif(not _has_gym, reason="gym not installed") diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index 90a49b6bebc..71acd1fd5af 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -85,16 +85,20 @@ class SipHash(torch.nn.Module): A hash function module based on SipHash implementation in python. - .. warning:: This module relies on the builtin ``hash`` function. 
-        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
-        variable must be set before the code is run (changing this value during code
-        execution is without effect).
+    Args:
+        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
+            through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.
+
+    .. warning:: This module relies on the builtin ``hash`` function.
+        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
+        variable must be set before the code is run (changing this value during code
+        execution is without effect).
 
     Examples:
         >>> # Assuming we set PYTHONHASHSEED=0 prior to running this code
         >>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
         >>> b = a.clone()
-        >>> hash_module = SipHash()
+        >>> hash_module = SipHash(as_tensor=True)
         >>> hash_a = hash_module(a)
         >>> hash_a
         tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521])
@@ -102,30 +106,41 @@ class SipHash(torch.nn.Module):
         >>> assert (hash_a == hash_b).all()
     """
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def __init__(self, as_tensor: bool = True):
+        super().__init__()
+        self.as_tensor = as_tensor
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]:
         hash_values = []
+        if x.dtype in (torch.bfloat16,):
+            x = x.to(torch.float16)
         for x_i in x.detach().cpu().numpy():
-            hash_value = hash(x_i.tobytes())
+            hash_value = x_i.tobytes()
             hash_values.append(hash_value)
-
-        return torch.tensor(hash_values, dtype=torch.int64)
+        if not self.as_tensor:
+            return hash_values
+        return torch.tensor([hash(x) for x in hash_values], dtype=torch.int64)
 
 
 class RandomProjectionHash(SipHash):
-    """A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through SipHash.
+    """A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through :class:`~.SipHash`.
 
     This module requires sklearn to be installed.
 
     Keyword Args:
         n_components (int, optional): the low-dimensional number of components of the projections.
            Defaults to 16.
-        projection_type (str, optional): the projection type to use.
-            Must be one of ``"gaussian"`` or ``"sparse_random"``. Defaults to "gaussian".
         dtype_cast (torch.dtype, optional): the dtype to cast the projection to.
-            Defaults to ``torch.float16``.
-        lazy (bool, optional): if ``True``, the random projection is fit on the first batch of data
-            received. Defaults to ``False``.
+            Defaults to ``torch.bfloat16``.
+        as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
+            through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.
+
+    .. warning:: This module relies on the builtin ``hash`` function.
+        To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
+        variable must be set before the code is run (changing this value during code
+        execution is without effect).
+        init_method (callable, optional): the initialization method for the random projection
+            matrix (the lazily materialized ``transform`` buffer). Defaults to :func:`torch.nn.init.normal_`.
     """
 
     _N_COMPONENTS_DEFAULT = 16
@@ -134,47 +149,43 @@ def __init__(
         self,
         *,
         n_components: int | None = None,
-        projection_type: str = "sparse_random",
-        dtype_cast=torch.float16,
-        lazy: bool = False,
+        dtype_cast=torch.bfloat16,
+        as_tensor: bool = True,
+        init_method: Callable[[torch.Tensor], torch.Tensor | None] | None = None,
         **kwargs,
     ):
         if n_components is None:
             n_components = self._N_COMPONENTS_DEFAULT
-        super().__init__()
-        from sklearn.random_projection import (
-            GaussianRandomProjection,
-            SparseRandomProjection,
-        )
+        super().__init__(as_tensor=as_tensor)
+        self.register_buffer("_n_components", torch.as_tensor(n_components))
 
-        self.lazy = lazy
-        self._init = not lazy
+        self._init = False
+        if init_method is None:
+            init_method = torch.nn.init.normal_
+        self.init_method = init_method
 
         self.dtype_cast = dtype_cast
-        if projection_type.lower() == "gaussian":
-            self.transform = GaussianRandomProjection(
-                n_components=n_components, **kwargs
-            )
-        elif projection_type.lower() in ("sparse_random", "sparse-random"):
-            self.transform = SparseRandomProjection(n_components=n_components, **kwargs)
-        else:
-            raise ValueError(
-                f"Only 'gaussian' and 'sparse_random' projections are supported. Got projection_type={projection_type}."
-            )
+        self.register_buffer("transform", torch.nn.UninitializedBuffer())
+
+    @property
+    def n_components(self):
+        return self._n_components.item()
 
     def fit(self, x):
         """Fits the random projection to the input data."""
-        self.transform.fit(x)
+        self.transform.materialize(
+            (x.shape[-1], self.n_components), dtype=self.dtype_cast, device=x.device
+        )
+        self.init_method(self.transform)
         self._init = True
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.lazy and not self._init:
+        if not self._init:
             self.fit(x)
         elif not self._init:
             raise RuntimeError(
                 f"The {type(self).__name__} has not been initialized. Call fit before calling this method."
             )
-        x = self.transform.transform(x)
-        x = torch.as_tensor(x, dtype=self.dtype_cast)
+        x = x.to(self.dtype_cast) @ self.transform
         return super().forward(x)
diff --git a/torchrl/data/map/query.py b/torchrl/data/map/query.py
index 17fcd4f9f22..9309b72afdd 100644
--- a/torchrl/data/map/query.py
+++ b/torchrl/data/map/query.py
@@ -3,17 +3,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import abc
-from abc import abstractmethod
-from typing import Any, Callable, Dict, Generic, List, Mapping, TypeVar
+from typing import Any, Callable, Dict, List, Mapping, TypeVar
 
 import torch
-
 import torch.nn as nn
-
 from tensordict import NestedKey, TensorDict, TensorDictBase
 from tensordict.nn.common import TensorDictModuleBase
-
 from torchrl.data.map import SipHash
 
 K = TypeVar("K")
diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py
index fbd632e55b3..5229e938a65 100644
--- a/torchrl/data/map/tdstorage.py
+++ b/torchrl/data/map/tdstorage.py
@@ -4,19 +4,15 @@
 # LICENSE file in the root directory of this source tree.
 
 import abc
+import functools
 from abc import abstractmethod
-from typing import Callable, Dict, Generic, List, TypeVar
+from typing import Any, Callable, Dict, Generic, List, TypeVar
 
 import torch
-
-import torch.nn as nn
-
-from tensordict import NestedKey, TensorDict, TensorDictBase
-from tensordict.nn import TensorDictModuleBase
+from tensordict import is_tensor_collection, NestedKey, TensorDictBase
 from tensordict.nn.common import TensorDictModuleBase
-
-from torchrl.data import LazyTensorStorage, Storage
 from torchrl.data.map import QueryModule, RandomProjectionHash
+from torchrl.data.replay_buffers.storages import _get_default_collate, LazyTensorStorage, TensorStorage
 
 K = TypeVar("K")
 V = TypeVar("V")
@@ -61,12 +57,16 @@ class TensorDictMap(
     returns another tensordict as output similar to TensorDictModuleBase. However, it provides
     additional functionality like python map:
 
-    Args:
+    Keyword Args:
         query_module (TensorDictModuleBase): a query module, typically an instance of
             :class:`~tensordict.nn.QueryModule`, used to map a set of tensordict
             entries to a hash key.
-        key_to_storage (Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]]):
+        storage (Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]]):
             a dictionary representing the map from an index key to a tensor storage.
+        collate_fn (callable, optional): a function to use to collate samples from the
+            storage. Defaults to a custom value for each known storage type (stack for
+            :class:`~torchrl.data.ListStorage`, identity for :class:`~torchrl.data.TensorStorage`
+            subtypes and others).
 
     Examples:
         >>> import torch
@@ -80,7 +80,7 @@ class TensorDictMap(
         >>> embedding_storage = LazyTensorStorage(1000)
         >>> tensor_dict_storage = TensorDictMap(
         ...     query_module=query_module,
-        ...     key_to_storage={"out": embedding_storage},
+        ...     storage=embedding_storage,
         ... )
         >>> index = TensorDict(
         ...     {
@@ -111,17 +111,43 @@ def __init__(
         self,
         *,
         query_module: QueryModule,
-        key_to_storage: Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]],
+        storage: Dict[NestedKey, TensorMap[torch.Tensor, torch.Tensor]],
+        collate_fn: Callable[[Any], Any] | None = None,
     ):
         self.in_keys = query_module.in_keys
-        self.out_keys = list(key_to_storage.keys())
 
         super().__init__()
         self.query_module = query_module
         self.index_key = query_module.index_key
-        self.key_to_storage = key_to_storage
+        self.storage = storage
         self.batch_added = False
+        if collate_fn is None:
+            collate_fn = _get_default_collate(self.storage)
+        self.collate_fn = collate_fn
+
+    @property
+    def out_keys(self) -> List[NestedKey]:
+        out_keys = self.__dict__.get("_out_keys")
+        if out_keys is not None:
+            return out_keys[0]
+        storage = self.storage
+        if isinstance(storage, TensorStorage) and is_tensor_collection(
+            storage._storage
+        ):
+            out_keys = list(storage._storage.keys(True, True))
+            self._out_keys = (out_keys, True)
+            return self.out_keys
+        raise AttributeError(
+            f"No out-keys found in the storage of type {type(storage)}"
+        )
+
+    @out_keys.setter
+    def out_keys(self, value):
+        self._out_keys = (value, False)
+
+    def _has_lazy_out_keys(self):
+        return self._out_keys[1]
 
     @classmethod
     def from_tensordict_pair(
@@ -130,7 +156,7 @@ def from_tensordict_pair(
         dest,
         in_keys: List[NestedKey],
         out_keys: List[NestedKey] | None = None,
-        storage_type: type = lambda: LazyTensorStorage(1000),
+        storage_constructor: type = functools.partial(LazyTensorStorage, 1000),
         hash_module: Callable | None = None,
     ):
         """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
@@ -142,7 +168,7 @@ def from_tensordict_pair(
         out_keys (List[NestedKey]): a list of keys to return in the output tensordict.
             All keys absent from out_keys, even if present in ``dest``, will not be stored
             in the storage. Defaults to ``None`` (all keys are registered).
-        storage_type (type, optional): a type of tensor storage.
+        storage_constructor (type, optional): a type of tensor storage.
             Defaults to :class:`~tensordict.nn.storage.LazyDynamicStorage`. Other options
             include :class:`~tensordict.nn.storage.FixedStorage`.
         hash_module (Callable, optional): a hash function to use in the :class:`~tensordict.nn.storage.QueryModule`.
@@ -190,15 +216,14 @@ def from_tensordict_pair(
         query_module = QueryModule(in_keys, hash_module=hash_module)
 
         # Build key_to_storage
-        if out_keys is None:
-            out_keys = list(dest.keys(True, True))
-        key_to_storage = {}
-        for key in out_keys:
-            key_to_storage[key] = storage_type()
-        return cls(query_module=query_module, key_to_storage=key_to_storage)
+        storage = storage_constructor()
+        result = cls(query_module=query_module, storage=storage)
+        if out_keys is not None:
+            result.out_keys = out_keys
+        return result
 
     def clear(self) -> None:
-        for mem in self.key_to_storage.values():
+        for mem in self.storage.values():
             mem.clear()
 
     def _to_index(self, item: TensorDictBase, extend: bool) -> torch.Tensor:
@@ -228,29 +253,26 @@ def __getitem__(self, item: TensorDictBase) -> TensorDictBase:
 
         index = self._to_index(item, extend=False)
 
-        res = TensorDict({}, batch_size=item.batch_size)
-        for k in self.out_keys:
-            storage: Storage = self.key_to_storage[k]
-            res.set(k, storage[index])
-
+        res = self.storage[index]
+        res = self.collate_fn(res)
         res = self._maybe_remove_batch(res)
         return res
 
     def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
+        if not self._has_lazy_out_keys():
+            # TODO: make this work with pytrees and avoid calling select if keys match
+            value = value.select(*self.out_keys)
         item, value = self._maybe_add_batch(item, value)
-
         index = self._to_index(item, extend=True)
-        for k in self.out_keys:
-            storage: Storage = self.key_to_storage[k]
-            storage.set(index, value[k])
+        self.storage.set(index, value)
 
     def __len__(self):
-        return len(next(iter(self.key_to_storage.values())))
+        return len(self.storage)
 
     def contains(self, item: TensorDictBase) -> torch.Tensor:
         item, _ = self._maybe_add_batch(item, None)
         index = self._to_index(item, extend=False)
-        res = next(iter(self.key_to_storage.values())).contains(index)
+        res = self.storage.contains(index)
         res = self._maybe_remove_batch(res)
         return res
 
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 9d1722b3db7..efff5da204b 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -808,10 +808,10 @@ def repr_item(x):
 
     def contains(self, item):
         if isinstance(item, int):
-            if index < 0:
-                index += self._len_along_dim0
+            if item < 0:
+                item += self._len_along_dim0
 
-            return 0 <= index < self._len_along_dim0
+            return 0 <= item < self._len_along_dim0
         if isinstance(item, torch.Tensor):
 
             def _is_valid_index(idx):
@@ -1300,10 +1300,14 @@ def _collate_list_tensordict(x):
     return out
 
 
-def _stack_anything(x):
-    if is_tensor_collection(x[0]):
-        return LazyStackedTensorDict.maybe_dense_stack(x)
-    return torch.stack(x)
+def _stack_anything(data):
+    if is_tensor_collection(data[0]):
+        return LazyStackedTensorDict.maybe_dense_stack(data)
+    return torch.utils._pytree.tree_map(
+        lambda *x: torch.stack(x),
+        *data,
+        is_leaf=lambda 
x: isinstance(x, torch.Tensor) or is_tensor_collection(x), + ) def _collate_id(x): @@ -1312,10 +1316,7 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): if isinstance(storage, ListStorage): - if _is_tensordict: - return _collate_list_tensordict - else: - return torch.utils.data._utils.collate.default_collate + return _stack_anything elif isinstance(storage, TensorStorage): return _collate_id else: