Commit

Update

[ghstack-poisoned]

vmoens committed Jan 15, 2025
2 parents ca4bd41 + d05004b commit 7f1d695
Showing 8 changed files with 629 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/envs.rst
@@ -827,6 +827,7 @@ to be able to create this other composition:
FlattenObservation
FrameSkipTransform
GrayScale
Hash
InitTracker
KLRewardTransform
LineariseRewards
@@ -853,6 +854,7 @@ to be able to create this other composition:
TimeMaxPool
ToTensorImage
TrajCounter
UnaryTransform
UnsqueezeTransform
VC1Transform
VIPRewardTransform
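
For context, a minimal usage sketch of the new Hash transform, following the constructor exercised in the tests below (in_keys, out_keys, hash_fn); the Gym env name is illustrative only:

    from torchrl.envs import GymEnv, Hash, TransformedEnv

    # Appends a "hashing" entry computed from "observation" at reset and step time.
    t = Hash(in_keys=["observation"], out_keys=["hashing"], hash_fn=hash)
    env = TransformedEnv(GymEnv("Pendulum-v1"), t)
    td = env.reset()  # td["hashing"] == hash(td["observation"])
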
30 changes: 30 additions & 0 deletions test/mocking_classes.py
@@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import random
import string
from typing import Dict, List, Optional

import torch
@@ -1066,6 +1068,34 @@ def _step(
return tensordict


class CountingEnvWithString(CountingEnv):
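    """A CountingEnv variant that adds a random lowercase-string observation (under the "string" key) at every reset and step."""
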
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.observation_spec.set(
"string",
NonTensor(
shape=self.batch_size,
device=self.device,
),
)

def get_random_string(self):
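        """Return a random lowercase ASCII string of 4 to 30 characters."""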
size = random.randint(4, 30)
return "".join(random.choice(string.ascii_lowercase) for _ in range(size))

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
res = super()._reset(tensordict, **kwargs)
random_string = self.get_random_string()
res["string"] = random_string
return res

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
res = super()._step(tensordict)
random_string = self.get_random_string()
res["string"] = random_string
return res


class MultiAgentCountingEnv(EnvBase):
"""A multi-agent env that is done after a given number of steps.
245 changes: 244 additions & 1 deletion test/test_transforms.py
@@ -39,6 +39,7 @@
CountingBatchedEnv,
CountingEnv,
CountingEnvCountPolicy,
CountingEnvWithString,
DiscreteActionConvMockEnv,
DiscreteActionConvMockEnvNumpy,
EnvWithScalarAction,
@@ -66,6 +67,7 @@
CountingBatchedEnv,
CountingEnv,
CountingEnvCountPolicy,
CountingEnvWithString,
DiscreteActionConvMockEnv,
DiscreteActionConvMockEnvNumpy,
EnvWithScalarAction,
@@ -77,7 +79,7 @@
MultiKeyCountingEnvPolicy,
NestedCountingEnv,
)
- from tensordict import TensorDict, TensorDictBase, unravel_key
+ from tensordict import NonTensorData, TensorDict, TensorDictBase, unravel_key
from tensordict.nn import TensorDictSequential
from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
from torch import multiprocessing as mp, nn, Tensor
@@ -118,6 +120,7 @@
FrameSkipTransform,
GrayScale,
gSDENoise,
Hash,
InitTracker,
LineariseRewards,
MultiStepTransform,
@@ -2180,6 +2183,246 @@ def test_transform_no_env(self, device, batch):
pytest.skip("TrajCounter cannot be called without env")


class TestHash(TransformBase):
@pytest.mark.parametrize("datatype", ["tensor", "str", "NonTensorStack"])
def test_transform_no_env(self, datatype):
if datatype == "tensor":
obs = torch.tensor(10)
hash_fn = hash
elif datatype == "str":
obs = "abcdefg"
hash_fn = Hash.reproducible_hash
elif datatype == "NonTensorStack":
obs = torch.stack(
[
NonTensorData(data="abcde"),
NonTensorData(data="fghij"),
NonTensorData(data="klmno"),
]
)

def fn0(x):
return torch.stack([Hash.reproducible_hash(x_) for x_ in x])

hash_fn = fn0
else:
raise RuntimeError(f"please add a test case for datatype {datatype}")

td = TensorDict(
{
"observation": obs,
}
)

t = Hash(in_keys=["observation"], out_keys=["hashing"], hash_fn=hash_fn)
td_hashed = t(td)

assert td_hashed.get("observation") is td.get("observation")

if datatype == "NonTensorStack":
assert (
td_hashed["hashing"] == hash_fn(td.get("observation").tolist())
).all()
elif datatype == "str":
assert all(td_hashed["hashing"] == hash_fn(td["observation"]))
else:
assert td_hashed["hashing"] == hash_fn(td["observation"])

@pytest.mark.parametrize("datatype", ["tensor", "str"])
def test_single_trans_env_check(self, datatype):
if datatype == "tensor":
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
)
base_env = CountingEnv()
elif datatype == "str":
t = Hash(
in_keys=["string"],
out_keys=["hashing"],
)
base_env = CountingEnvWithString()
env = TransformedEnv(base_env, t)
check_env_specs(env)

@pytest.mark.parametrize("datatype", ["tensor", "str"])
def test_serial_trans_env_check(self, datatype):
def make_env():
if datatype == "tensor":
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
)
base_env = CountingEnv()

elif datatype == "str":
t = Hash(
in_keys=["string"],
out_keys=["hashing"],
)
base_env = CountingEnvWithString()

return TransformedEnv(base_env, t)

env = SerialEnv(2, make_env)
check_env_specs(env)

@pytest.mark.parametrize("datatype", ["tensor", "str"])
def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
def make_env():
if datatype == "tensor":
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
)
base_env = CountingEnv()
elif datatype == "str":
t = Hash(
in_keys=["string"],
out_keys=["hashing"],
)
base_env = CountingEnvWithString()
return TransformedEnv(base_env, t)

env = maybe_fork_ParallelEnv(2, make_env)
try:
check_env_specs(env)
finally:
try:
env.close()
except RuntimeError:
pass

@pytest.mark.parametrize("datatype", ["tensor", "str"])
def test_trans_serial_env_check(self, datatype):
if datatype == "tensor":
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
)
base_env = CountingEnv
elif datatype == "str":
t = Hash(
in_keys=["string"],
out_keys=["hashing"],
hash_fn=lambda x: torch.stack([Hash.reproducible_hash(x_) for x_ in x]),
)
base_env = CountingEnvWithString

env = TransformedEnv(SerialEnv(2, base_env), t)
check_env_specs(env)

@pytest.mark.parametrize("datatype", ["tensor", "str"])
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
if datatype == "tensor":
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
)
base_env = CountingEnv
elif datatype == "str":
t = Hash(
in_keys=["string"],
out_keys=["hashing"],
hash_fn=lambda x: torch.stack([Hash.reproducible_hash(x_) for x_ in x]),
)
base_env = CountingEnvWithString

env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
try:
check_env_specs(env)
finally:
try:
env.close()
except RuntimeError:
pass

@pytest.mark.parametrize("datatype", ["tensor", "str"])
def test_transform_compose(self, datatype):
if datatype == "tensor":
obs = torch.tensor(10)
elif datatype == "str":
obs = "abcdefg"

td = TensorDict(
{
"observation": obs,
}
)
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
)
t = Compose(t)
td_hashed = t(td)

assert td_hashed["observation"] is td["observation"]
assert td_hashed["hashing"] == hash(td["observation"])

def test_transform_model(self):
t = Hash(
in_keys=[("next", "observation"), ("observation",)],
out_keys=[("next", "hashing"), ("hashing",)],
hash_fn=hash,
)
model = nn.Sequential(t, nn.Identity())
td = TensorDict(
{("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
)
td_out = model(td)
assert ("next", "hashing") in td_out.keys(True)
assert ("hashing",) in td_out.keys(True)
assert td_out["next", "hashing"] == hash(td["next", "observation"])
assert td_out["hashing"] == hash(td["observation"])

@pytest.mark.skipif(not _has_gym, reason="Gym not found")
def test_transform_env(self):
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
)
env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
assert env.observation_spec["hashing"]
assert "observation" in env.observation_spec
assert "observation" in env.base_env.observation_spec
check_env_specs(env)

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, rbclass):
t = Hash(
in_keys=[("next", "observation"), ("observation",)],
out_keys=[("next", "hashing"), ("hashing",)],
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
)
rb = rbclass(storage=LazyTensorStorage(10))
rb.append_transform(t)
td = TensorDict(
{
"observation": torch.randn(3, 4),
"next": TensorDict(
{"observation": torch.randn(3, 4)},
[],
),
},
[],
).expand(10)
rb.extend(td)
td = rb.sample(2)
assert "hashing" in td.keys()
assert "observation" in td.keys()
assert ("next", "observation") in td.keys(True)

def test_transform_inverse(self):
raise pytest.skip("No inverse for Hash")


class TestStack(TransformBase):
def test_single_trans_env_check(self):
t = Stack(
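
A companion sketch for string observations, assuming the default hash_fn the string tests above rely on (Hash.reproducible_hash) and reusing the test-only CountingEnvWithString env added in test/mocking_classes.py:

    from torchrl.envs import Hash, TransformedEnv
    from mocking_classes import CountingEnvWithString  # test helper, not part of the library

    # Without an explicit hash_fn, the tests indicate Hash falls back to a
    # reproducible hash that maps each string to a tensor.
    t = Hash(in_keys=["string"], out_keys=["hashing"])
    env = TransformedEnv(CountingEnvWithString(), t)
    rollout = env.rollout(3)  # each step carries a "hashing" entry derived from "string"
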
2 changes: 2 additions & 0 deletions torchrl/envs/__init__.py
@@ -67,6 +67,7 @@
FrameSkipTransform,
GrayScale,
gSDENoise,
Hash,
InitTracker,
KLRewardTransform,
LineariseRewards,
@@ -97,6 +98,7 @@
TrajCounter,
Transform,
TransformedEnv,
UnaryTransform,
UnsqueezeTransform,
VC1Transform,
VecGymEnvTransform,
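
With this change, both new transforms are importable from the top level of the envs package:

    from torchrl.envs import Hash, UnaryTransform
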
