From 9b1ebb2f63438def890cfb083cf02cdd79daac50 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 20 Jun 2024 13:38:24 +0100 Subject: [PATCH] [BugFix] Fix collector tests where device ordinal is needed (#2240) --- test/test_collector.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 7c54d279b63..12ec490e7e2 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -53,7 +53,13 @@ from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential from torch import nn -from torchrl._utils import _replace_last, logger as torchrl_logger, prod, seed_generator +from torchrl._utils import ( + _make_ordinal_device, + _replace_last, + logger as torchrl_logger, + prod, + seed_generator, +) from torchrl.collectors import aSyncDataCollector, SyncDataCollector from torchrl.collectors.collectors import ( _Interruptor, @@ -285,7 +291,9 @@ def __init__(self, default_device): self.action_spec = UnboundedContinuousTensorSpec( (), device=self.default_device ) - assert self.device == torch.device(self.default_device) + assert self.device == _make_ordinal_device( + torch.device(self.default_device) + ) assert self.full_observation_spec is not None assert self.full_done_spec is not None assert self.full_state_spec is not None @@ -293,7 +301,9 @@ def __init__(self, default_device): assert self.full_reward_spec is not None def _step(self, tensordict): - assert tensordict.device == torch.device(self.default_device) + assert tensordict.device == _make_ordinal_device( + torch.device(self.default_device) + ) with torch.device(self.default_device): return TensorDict( { @@ -339,7 +349,9 @@ class PolicyWithDevice(TensorDictModuleBase): default_device = "cuda:0" if torch.cuda.device_count() else "cpu" def forward(self, tensordict): - assert tensordict.device == torch.device(self.default_device) + assert tensordict.device == _make_ordinal_device( + torch.device(self.default_device) + ) return tensordict.set("action", torch.zeros((), device=self.default_device)) @pytest.mark.parametrize("main_device", get_default_devices()) @@ -1436,7 +1448,7 @@ def env_fn(seed): ) assert collector._use_buffers batch = next(collector.iterator()) - assert batch.device == torch.device(storing_device) + assert batch.device == _make_ordinal_device(torch.device(storing_device)) collector.shutdown() collector = MultiSyncDataCollector( @@ -1459,7 +1471,7 @@ def env_fn(seed): cat_results="stack", ) batch = next(collector.iterator()) - assert batch.device == torch.device(storing_device) + assert batch.device == _make_ordinal_device(torch.device(storing_device)) collector.shutdown() collector = MultiaSyncDataCollector( @@ -1481,7 +1493,7 @@ def env_fn(seed): ], ) batch = next(collector.iterator()) - assert batch.device == torch.device(storing_device) + assert batch.device == _make_ordinal_device(torch.device(storing_device)) collector.shutdown() del collector