Commit: amend
vmoens committed Jul 12, 2024
1 parent 02b0aef commit f7ecad8
Showing 5 changed files with 152 additions and 102 deletions.
45 changes: 33 additions & 12 deletions test/test_storage_map.py
@@ -4,15 +4,20 @@
# LICENSE file in the root directory of this source tree.
import argparse
import importlib.util
from typing import cast

import pytest

import torch

from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ListStorage
from torchrl.data.map import BinaryToDecimal, QueryModule, SipHash, TensorDictMap
from torchrl.data.map import (
BinaryToDecimal,
QueryModule,
RandomProjectionHash,
SipHash,
TensorDictMap,
)
from torchrl.envs import GymEnv

_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
@@ -60,13 +65,29 @@ def test_binary_to_decimal():
assert (decimal == torch.Tensor([3, 2])).all()


def test_sip_hash():
a = torch.rand((3, 2))
b = a.clone()
hash_module = SipHash()
hash_a = cast(torch.Tensor, hash_module(a))
hash_b = cast(torch.Tensor, hash_module(b))
assert (hash_a == hash_b).all()
class TestHash:
def test_sip_hash(self):
a = torch.rand((3, 2))
b = a.clone()
hash_module = SipHash(as_tensor=True)
hash_a = torch.tensor(hash_module(a))
hash_b = torch.tensor(hash_module(b))
assert (hash_a == hash_b).all()

@pytest.mark.parametrize("n_components", [None, 14])
@pytest.mark.parametrize("scale", [0.001, 0.01, 1, 100, 1000])
def test_randomprojection_hash(self, n_components, scale):
torch.manual_seed(0)
r = RandomProjectionHash(n_components=n_components)
x = torch.randn(10000, 100).mul_(scale)
y = r(x)
if n_components is None:
assert r.n_components == r._N_COMPONENTS_DEFAULT
else:
assert r.n_components == n_components

assert y.shape == (10000,)
assert y.unique().numel() == y.numel()


def test_query():
Expand Down Expand Up @@ -104,7 +125,7 @@ def test_query_module():

tensor_dict_storage = TensorDictMap(
query_module=query_module,
key_to_storage={"index": embedding_storage},
storage=embedding_storage,
)

index = TensorDict(
Expand Down Expand Up @@ -140,7 +161,7 @@ def test_storage():

tensor_dict_storage = TensorDictMap(
query_module=query_module,
key_to_storage={"index": embedding_storage},
storage=embedding_storage,
)

index = TensorDict(
@@ -162,7 +183,7 @@ def test_storage():
new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
retrieve_value = tensor_dict_storage[new_index]

assert cast(torch.Tensor, retrieve_value["index"] == value["index"]).all()
assert (retrieve_value["index"] == value["index"]).all()


@pytest.mark.skipif(not _has_gym, reason="gym not installed")
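The test changes above track an API change in TensorDictMap: the old key_to_storage dict is replaced by a single storage argument. A minimal sketch of the new constructor and the write/read pattern used in the tests; the key names, sizes, and QueryModule defaults below are illustrative assumptions, not taken from the truncated test bodies.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage
from torchrl.data.map import QueryModule, TensorDictMap

# Hypothetical keys for illustration; QueryModule's default hash/aggregation is assumed.
query_module = QueryModule(in_keys=["key1", "key2"], index_key="index")
embedding_storage = LazyTensorStorage(1000)

tensor_dict_storage = TensorDictMap(
    query_module=query_module,
    storage=embedding_storage,  # previously: key_to_storage={"index": embedding_storage}
)

# Write/read pattern mirroring the (truncated) tests above.
index = TensorDict({"key1": torch.ones(2, 1), "key2": torch.zeros(2, 1)}, batch_size=[2])
value = TensorDict({"index": torch.tensor([[10], [20]])}, batch_size=[2])
tensor_dict_storage[index] = value
assert (tensor_dict_storage[index]["index"] == value["index"]).all()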
89 changes: 50 additions & 39 deletions torchrl/data/map/hash.py
@@ -85,47 +85,62 @@ class SipHash(torch.nn.Module):
A hash function module based on the SipHash implementation in Python.
.. warning:: This module relies on the builtin ``hash`` function.
To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
variable must be set before the code is run (changing this value during code
execution is without effect).
Args:
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.
.. warning:: This module relies on the builtin ``hash`` function.
To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
variable must be set before the code is run (changing this value during code
execution is without effect).
Examples:
>>> # Assuming we set PYTHONHASHSEED=0 prior to running this code
>>> a = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
>>> b = a.clone()
>>> hash_module = SipHash()
>>> hash_module = SipHash(as_tensor=True)
>>> hash_a = hash_module(a)
>>> hash_a
tensor([-4669941682990263259, -3778166555168484291, -9122128731510687521])
>>> hash_b = hash_module(b)
>>> assert (hash_a == hash_b).all()
"""

def forward(self, x: torch.Tensor) -> torch.Tensor:
def __init__(self, as_tensor: bool = True):
super().__init__()
self.as_tensor = as_tensor

def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]:
hash_values = []
if x.dtype in (torch.bfloat16,):
x = x.to(torch.float16)
for x_i in x.detach().cpu().numpy():
hash_value = hash(x_i.tobytes())
hash_value = x_i.tobytes()
hash_values.append(hash_value)

return torch.tensor(hash_values, dtype=torch.int64)
if not self.as_tensor:
return hash_values
return torch.tensor([hash(x) for x in hash_values], dtype=torch.int64)
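The as_tensor flag added above switches the return type of SipHash: with as_tensor=True each row's bytes are folded through the builtin hash into an int64 tensor, while as_tensor=False returns the raw byte strings. A minimal sketch, assuming the defaults shown above (hash values depend on PYTHONHASHSEED):

import torch
from torchrl.data.map import SipHash

x = torch.tensor([[0.1, 0.2], [0.3, 0.4]])

hash_int = SipHash(as_tensor=True)(x)     # torch.int64 tensor, one entry per row
hash_bytes = SipHash(as_tensor=False)(x)  # list of raw byte strings, one per row

assert hash_int.dtype == torch.int64
assert isinstance(hash_bytes, list) and isinstance(hash_bytes[0], bytes)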


class RandomProjectionHash(SipHash):
"""A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through SipHash.
"""A module that combines random projections with SipHash to get a low-dimensional tensor, easier to embed through :class:`~.SipHash`.
This module requires sklearn to be installed.
Keyword Args:
n_components (int, optional): the low-dimensional number of components of the projections.
Defaults to 16.
projection_type (str, optional): the projection type to use.
Must be one of ``"gaussian"`` or ``"sparse_random"``. Defaults to "gaussian".
dtype_cast (torch.dtype, optional): the dtype to cast the projection to.
Defaults to ``torch.float16``.
lazy (bool, optional): if ``True``, the random projection is fit on the first batch of data
received. Defaults to ``False``.
Defaults to ``torch.bfloat16``.
as_tensor (bool, optional): if ``True``, the bytes will be turned into integers
through the builtin ``hash`` function and mapped to a tensor. Default: ``True``.
.. warning:: This module relies on the builtin ``hash`` function.
To get reproducible results across runs, the ``PYTHONHASHSEED`` environment
variable must be set before the code is run (changing this value during code
execution is without effect).
init_method (Callable[[torch.Tensor], torch.Tensor | None], optional): the in-place initializer applied to the random projection buffer once it is materialized. Defaults to ``torch.nn.init.normal_``.
"""

_N_COMPONENTS_DEFAULT = 16
@@ -134,47 +149,43 @@ def __init__(
self,
*,
n_components: int | None = None,
projection_type: str = "sparse_random",
dtype_cast=torch.float16,
lazy: bool = False,
dtype_cast=torch.bfloat16,
as_tensor: bool = True,
init_method: Callable[[torch.Tensor], torch.Tensor | None] | None = None,
**kwargs,
):
if n_components is None:
n_components = self._N_COMPONENTS_DEFAULT

super().__init__()
from sklearn.random_projection import (
GaussianRandomProjection,
SparseRandomProjection,
)
super().__init__(as_tensor=as_tensor)
self.register_buffer("_n_components", torch.as_tensor(n_components))

self.lazy = lazy
self._init = not lazy
self._init = False
if init_method is None:
init_method = torch.nn.init.normal_
self.init_method = init_method

self.dtype_cast = dtype_cast
if projection_type.lower() == "gaussian":
self.transform = GaussianRandomProjection(
n_components=n_components, **kwargs
)
elif projection_type.lower() in ("sparse_random", "sparse-random"):
self.transform = SparseRandomProjection(n_components=n_components, **kwargs)
else:
raise ValueError(
f"Only 'gaussian' and 'sparse_random' projections are supported. Got projection_type={projection_type}."
)
self.register_buffer("transform", torch.nn.UninitializedBuffer())

@property
def n_components(self):
return self._n_components.item()

def fit(self, x):
"""Fits the random projection to the input data."""
self.transform.fit(x)
self.transform.materialize(
(x.shape[-1], self.n_components), dtype=self.dtype_cast, device=x.device
)
self.init_method(self.transform)
self._init = True

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.lazy and not self._init:
if not self._init:
self.fit(x)
elif not self._init:
raise RuntimeError(
f"The {type(self).__name__} has not been initialized. Call fit before calling this method."
)
x = self.transform.transform(x)
x = torch.as_tensor(x, dtype=self.dtype_cast)
x = x.to(self.dtype_cast) @ self.transform
return super().forward(x)
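Putting the pieces together, the reworked RandomProjectionHash materializes an (n_features, n_components) buffer on the first batch it sees, initializes it with init_method, projects inputs with a plain matrix multiply, and hashes the projected rows through SipHash. A hedged usage sketch; the shapes and seed are illustrative:

import torch
from torchrl.data.map import RandomProjectionHash

torch.manual_seed(0)
proj_hash = RandomProjectionHash(n_components=16)  # 16 is also the class default

x = torch.randn(128, 100)   # 128 items with 100 features each
codes = proj_hash(x)        # lazy fit on first call: project 100 -> 16, then SipHash per row

assert codes.shape == (128,)
assert codes.dtype == torch.int64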
7 changes: 1 addition & 6 deletions torchrl/data/map/query.py
@@ -3,17 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import abc
from abc import abstractmethod
from typing import Any, Callable, Dict, Generic, List, Mapping, TypeVar
from typing import Any, Callable, Dict, List, Mapping, TypeVar

import torch

import torch.nn as nn

from tensordict import NestedKey, TensorDict, TensorDictBase
from tensordict.nn.common import TensorDictModuleBase

from torchrl.data.map import SipHash

K = TypeVar("K")