Skip to content

Commit

Permalink
[BugFix] Fix collector tests where device ordinal is needed (#2240)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 20, 2024
1 parent eb35793 commit 9b1ebb2
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@
from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential

from torch import nn
from torchrl._utils import _replace_last, logger as torchrl_logger, prod, seed_generator
from torchrl._utils import (
_make_ordinal_device,
_replace_last,
logger as torchrl_logger,
prod,
seed_generator,
)
from torchrl.collectors import aSyncDataCollector, SyncDataCollector
from torchrl.collectors.collectors import (
_Interruptor,
Expand Down Expand Up @@ -285,15 +291,19 @@ def __init__(self, default_device):
self.action_spec = UnboundedContinuousTensorSpec(
(), device=self.default_device
)
assert self.device == torch.device(self.default_device)
assert self.device == _make_ordinal_device(
torch.device(self.default_device)
)
assert self.full_observation_spec is not None
assert self.full_done_spec is not None
assert self.full_state_spec is not None
assert self.full_action_spec is not None
assert self.full_reward_spec is not None

def _step(self, tensordict):
assert tensordict.device == torch.device(self.default_device)
assert tensordict.device == _make_ordinal_device(
torch.device(self.default_device)
)
with torch.device(self.default_device):
return TensorDict(
{
Expand Down Expand Up @@ -339,7 +349,9 @@ class PolicyWithDevice(TensorDictModuleBase):
default_device = "cuda:0" if torch.cuda.device_count() else "cpu"

def forward(self, tensordict):
assert tensordict.device == torch.device(self.default_device)
assert tensordict.device == _make_ordinal_device(
torch.device(self.default_device)
)
return tensordict.set("action", torch.zeros((), device=self.default_device))

@pytest.mark.parametrize("main_device", get_default_devices())
Expand Down Expand Up @@ -1436,7 +1448,7 @@ def env_fn(seed):
)
assert collector._use_buffers
batch = next(collector.iterator())
assert batch.device == torch.device(storing_device)
assert batch.device == _make_ordinal_device(torch.device(storing_device))
collector.shutdown()

collector = MultiSyncDataCollector(
Expand All @@ -1459,7 +1471,7 @@ def env_fn(seed):
cat_results="stack",
)
batch = next(collector.iterator())
assert batch.device == torch.device(storing_device)
assert batch.device == _make_ordinal_device(torch.device(storing_device))
collector.shutdown()

collector = MultiaSyncDataCollector(
Expand All @@ -1481,7 +1493,7 @@ def env_fn(seed):
],
)
batch = next(collector.iterator())
assert batch.device == torch.device(storing_device)
assert batch.device == _make_ordinal_device(torch.device(storing_device))
collector.shutdown()
del collector

Expand Down

0 comments on commit 9b1ebb2

Please sign in to comment.