[BugFix] Fix DeviceCastTransform #2471

Merged: 1 commit, Oct 9, 2024

71 changes: 45 additions & 26 deletions test/test_transforms.py
@@ -5114,7 +5114,9 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
pass

@pytest.mark.parametrize("has_in_keys,", [True, False])
@pytest.mark.parametrize("reset_keys,", [None, ["_reset"] * 3])
@pytest.mark.parametrize(
"reset_keys,", [[("some", "nested", "reset")], ["_reset"] * 3, None]
)
def test_trans_multi_key(
self, has_in_keys, reset_keys, n_workers=2, batch_size=(3, 2), max_steps=5
):
@@ -5136,9 +5138,9 @@ def test_trans_multi_key(
)
with pytest.raises(
ValueError, match="Could not match the env reset_keys"
) if reset_keys is None else contextlib.nullcontext():
) if reset_keys == [("some", "nested", "reset")] else contextlib.nullcontext():
check_env_specs(env)
if reset_keys is not None:
if reset_keys != [("some", "nested", "reset")]:
td = env.rollout(max_steps, policy=policy)
for reward_key in env.reward_keys:
reward_key = _unravel_key_to_tuple(reward_key)
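The two hunks above flip the expectation in test_trans_multi_key: with the reset-key fix in transforms.py (last hunk of this diff), reset_keys=None now resolves against the parent env and passes check_env_specs, while a nested key that matches nothing still raises ValueError. A minimal sketch of that pattern, with assert_spec_check and expect_failure as assumed placeholders:

    import contextlib

    import pytest
    from torchrl.envs.utils import check_env_specs

    def assert_spec_check(env, expect_failure):
        # Mirrors the test: only an unmatchable reset key should fail the check.
        ctx = (
            pytest.raises(ValueError, match="Could not match the env reset_keys")
            if expect_failure
            else contextlib.nullcontext()
        )
        with ctx:
            check_env_specs(env)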
@@ -9955,16 +9957,27 @@ def test_transform_inverse(self):


class TestDeviceCastTransformPart(TransformBase):
@pytest.fixture(scope="class")
def _cast_device(self):
if torch.cuda.is_available():
yield torch.device("cuda:0")
elif torch.backends.mps.is_available():
yield torch.device("mps:0")
else:
yield torch.device("cpu:1")

@pytest.mark.parametrize("in_keys", ["observation"])
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def test_single_trans_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
env,
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -9978,12 +9991,14 @@ def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_i
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_serial_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def test_serial_trans_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(device="cpu:0"),
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10000,13 +10015,13 @@ def make_env():
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_parallel_trans_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return TransformedEnv(
ContinuousActionVecMockEnv(device="cpu:0"),
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10032,14 +10047,16 @@ def make_env():
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_trans_serial_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def test_trans_serial_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return ContinuousActionVecMockEnv(device="cpu:0")

env = TransformedEnv(
SerialEnv(2, make_env),
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10054,7 +10071,7 @@ def make_env():
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_trans_parallel_env_check(
self, in_keys, out_keys, in_keys_inv, out_keys_inv
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
def make_env():
return ContinuousActionVecMockEnv(device="cpu:0")
@@ -10066,7 +10083,7 @@ def make_env():
mp_start_method=mp_ctx if not torch.cuda.is_available() else "spawn",
),
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
@@ -10082,8 +10099,8 @@ def make_env():
except RuntimeError:
pass

def test_transform_no_env(self):
t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"])
def test_transform_no_env(self, _cast_device):
t = DeviceCastTransform(_cast_device, "cpu:0", in_keys=["a"], out_keys=["b"])
td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0")
tdt = t._call(td)
assert tdt.device is None
@@ -10092,26 +10109,28 @@ def test_transform_no_env(self):
@pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
@pytest.mark.parametrize("in_keys_inv", ["action"])
@pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
def test_transform_env(
self, in_keys, out_keys, in_keys_inv, out_keys_inv, _cast_device
):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(
env,
DeviceCastTransform(
"cpu:1",
_cast_device,
in_keys=in_keys,
out_keys=out_keys,
in_keys_inv=in_keys_inv,
out_keys_inv=out_keys_inv,
),
)
assert env.device is None
assert env.transform.device == torch.device("cpu:1")
assert env.transform.device == _cast_device
assert env.transform.orig_device == torch.device("cpu:0")

def test_transform_compose(self):
def test_transform_compose(self, _cast_device):
t = Compose(
DeviceCastTransform(
"cpu:1",
_cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
@@ -10123,7 +10142,7 @@ def test_transform_compose(self):
td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
"c": torch.randn((), device="cpu:1"),
"c": torch.randn((), device=_cast_device),
},
[],
device="cpu:0",
@@ -10134,11 +10153,11 @@ def make_env():
assert tdt.device is None
assert tdit.device is None

def test_transform_model(self):
def test_transform_model(self, _cast_device):
t = nn.Sequential(
Compose(
DeviceCastTransform(
"cpu:1",
_cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
@@ -10161,11 +10180,11 @@ def test_transform_model(self):

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
@pytest.mark.parametrize("storage", [LazyTensorStorage])
def test_transform_rb(self, rbclass, storage):
def test_transform_rb(self, rbclass, storage, _cast_device):
# we don't test casting to cuda on Memmap tensor storage since it's discouraged
t = Compose(
DeviceCastTransform(
"cpu:1",
_cast_device,
"cpu:0",
in_keys=["a"],
out_keys=["b"],
@@ -10178,7 +10197,7 @@ def test_transform_rb(self, rbclass, storage):
td = TensorDict(
{
"a": torch.randn((), device="cpu:0"),
"c": torch.randn((), device="cpu:1"),
"c": torch.randn((), device=_cast_device),
},
[],
device="cpu:0",
51 changes: 36 additions & 15 deletions torchrl/envs/transforms/transforms.py
@@ -3893,6 +3893,19 @@ class DeviceCastTransform(Transform):
a parent environment exists, it is retrieved from it. In all other cases,
it remains unspecified.

Keyword Args:
in_keys (list of NestedKey): the list of entries to map to a different device.
Defaults to ``None``.
out_keys (list of NestedKey): the output names of the entries mapped onto a device.
Defaults to the values of ``in_keys``.
in_keys_inv (list of NestedKey): the list of entries to map to a different device.
``in_keys_inv`` are the names expected by the base environment.
Defaults to ``None``.
out_keys_inv (list of NestedKey): the output names of the entries mapped onto a device.
``out_keys_inv`` are the names of the keys as seen from outside the transformed env.
Defaults to the values of ``in_keys_inv``.


Examples:
>>> td = TensorDict(
... {'obs': torch.ones(1, dtype=torch.double),
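For reference, a usage sketch of the keyword arguments documented above. It mirrors the tests' use of the internal _call method and assumes a CUDA device is available; the key names are made up:

    import torch
    from tensordict import TensorDict
    from torchrl.envs import DeviceCastTransform

    if torch.cuda.is_available():
        t = DeviceCastTransform("cuda:0", "cpu", in_keys=["obs"], out_keys=["obs_cuda"])
        td = TensorDict({"obs": torch.ones(1)}, [], device="cpu")
        td = t._call(td)
        assert td["obs_cuda"].device.type == "cuda"  # only the listed entry is cast
        assert td["obs"].device.type == "cpu"        # the source entry stays behind
        assert td.device is None                     # mixed devices: container device is cleared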
@@ -3920,6 +3933,10 @@ def __init__(
self.orig_device = (
torch.device(orig_device) if orig_device is not None else orig_device
)
if out_keys is None:
out_keys = copy(in_keys)
if out_keys_inv is None:
out_keys_inv = copy(in_keys_inv)
super().__init__(
in_keys=in_keys,
out_keys=out_keys,
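The __init__ change above fills in out_keys and out_keys_inv with copies of in_keys and in_keys_inv before calling the base constructor, so omitting them behaves like repeating the input keys. A short sketch of the intended default, using the same cpu:0/cpu:1 pair as the original tests:

    from torchrl.envs import DeviceCastTransform

    t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["obs"], in_keys_inv=["action"])
    assert t.out_keys == ["obs"]         # falls back to a copy of in_keys
    assert t.out_keys_inv == ["action"]  # falls back to a copy of in_keys_inv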
@@ -4043,52 +4060,54 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
if self._map_env_device:
return input_spec.to(self.device)
else:
input_spec.clear_device_()
return super().transform_input_spec(input_spec)

def transform_action_spec(self, full_action_spec: Composite) -> Composite:
full_action_spec = full_action_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key not in full_action_spec.keys(True, True):
continue
full_action_spec[out_key] = full_action_spec[in_key].to(self.device)
local_action_spec = full_action_spec.get(in_key, None)
if local_action_spec is not None:
full_action_spec[out_key] = local_action_spec.to(self.device)
return full_action_spec

def transform_state_spec(self, full_state_spec: Composite) -> Composite:
full_state_spec = full_state_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key not in full_state_spec.keys(True, True):
continue
full_state_spec[out_key] = full_state_spec[in_key].to(self.device)
local_state_spec = full_state_spec.get(in_key, None)
if local_state_spec is not None:
full_state_spec[out_key] = local_state_spec.to(self.device)
return full_state_spec

def transform_output_spec(self, output_spec: Composite) -> Composite:
if self._map_env_device:
return output_spec.to(self.device)
else:
output_spec.clear_device_()
return super().transform_output_spec(output_spec)

def transform_observation_spec(self, observation_spec: Composite) -> Composite:
observation_spec = observation_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in observation_spec.keys(True, True):
continue
observation_spec[out_key] = observation_spec[in_key].to(self.device)
local_obs_spec = observation_spec.get(in_key, None)
if local_obs_spec is not None:
observation_spec[out_key] = local_obs_spec.to(self.device)
return observation_spec

def transform_done_spec(self, full_done_spec: Composite) -> Composite:
full_done_spec = full_done_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in full_done_spec.keys(True, True):
continue
full_done_spec[out_key] = full_done_spec[in_key].to(self.device)
local_done_spec = full_done_spec.get(in_key, None)
if local_done_spec is not None:
full_done_spec[out_key] = local_done_spec.to(self.device)
return full_done_spec

def transform_reward_spec(self, full_reward_spec: Composite) -> Composite:
full_reward_spec = full_reward_spec.clear_device_()
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in full_reward_spec.keys(True, True):
continue
full_reward_spec[out_key] = full_reward_spec[in_key].to(self.device)
local_reward_spec = full_reward_spec.get(in_key, None)
if local_reward_spec is not None:
full_reward_spec[out_key] = local_reward_spec.to(self.device)
return full_reward_spec

def transform_env_device(self, device):
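The spec-transform hunks above swap the membership test on spec.keys(True, True) for a direct spec.get(in_key, None) lookup and move the clear_device_() calls into transform_input_spec / transform_output_spec. A sketch of the lookup pattern on a standalone Composite spec, with made-up key names:

    import torch
    from torchrl.data import Composite, Unbounded

    spec = Composite(obs=Unbounded(shape=(3,), device="cpu"))
    # get() returns the default for absent keys instead of requiring a full
    # enumeration of spec.keys(True, True).
    local_spec = spec.get("obs", None)
    if local_spec is not None:
        spec["obs"] = local_spec.to(torch.device("cpu:1"))
    assert spec.get("action", None) is None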
@@ -5494,6 +5513,8 @@ def reset_keys(self):
# We take the filtered reset keys, which are the only keys that really
# matter when calling reset, and check that they match the in_keys root.
reset_keys = parent._filtered_reset_keys
if len(reset_keys) == 1:
reset_keys = list(reset_keys) * len(self.in_keys)

def _check_match(reset_keys, in_keys):
# if this is called, the length of reset_keys and in_keys must match
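The final hunk carries the actual bug fix for shared reset keys: when the parent env exposes a single reset key but the transform tracks several in_keys (the multi-reward setup exercised by test_trans_multi_key), the key is now repeated so the length check in _check_match can succeed. The added logic in isolation, with assumed values:

    reset_keys = ["_reset"]                         # single reset key from the parent env
    in_keys = ["reward_1", "reward_2", "reward_3"]  # several tracked entries
    if len(reset_keys) == 1:
        reset_keys = list(reset_keys) * len(in_keys)
    assert len(reset_keys) == len(in_keys)          # lengths now line up for _check_match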