From 12dd746518a917024e37aff5626cab7d160e512d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 4 Oct 2024 13:46:06 +0100 Subject: [PATCH] lint --- test/test_transforms.py | 14 ++++++++++++-- torchrl/collectors/collectors.py | 29 +++++++++++++++-------------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index d0bc785f5fe..589c32809cc 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -8678,8 +8678,18 @@ def test_compose_indexing(self): def test_compose_action_spec(self): # Create a Compose transform that renames "action" to "action_1" and then to "action_2" c = Compose( - RenameTransform(in_keys=(), out_keys=(), in_keys_inv=("action",), out_keys_inv=("action_1",)), - RenameTransform(in_keys=(), out_keys=(), in_keys_inv=("action_1",), out_keys_inv=("action_2",)), + RenameTransform( + in_keys=(), + out_keys=(), + in_keys_inv=("action",), + out_keys_inv=("action_1",), + ), + RenameTransform( + in_keys=(), + out_keys=(), + in_keys_inv=("action_1",), + out_keys_inv=("action_2",), + ), ) base_env = ContinuousActionVecMockEnv() env = TransformedEnv(base_env, c) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b5a1764dc2d..9ccd2e2aa80 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -1675,21 +1675,22 @@ def __init__( self.cat_results = cat_results def _check_replay_buffer_init(self): - try: - if not self.replay_buffer._storage.initialized: - if isinstance(self.create_env_fn, EnvCreator): - fake_td = self.create_env_fn.tensordict - elif isinstance(self.create_env_fn, EnvBase): - fake_td = self.create_env_fn.fake_tensordict() - else: - fake_td = self.create_env_fn[0]( - **self.create_env_kwargs[0] - ).fake_tensordict() - fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long) + is_init = getattr(self.replay_buffer._storage, "initialized", True) + if not is_init: + if isinstance(self.create_env_fn[0], EnvCreator): + fake_td = self.create_env_fn[0].tensordict + elif isinstance(self.create_env_fn[0], EnvBase): + fake_td = self.create_env_fn[0].fake_tensordict() + else: + fake_td = self.create_env_fn[0]( + **self.create_env_kwargs[0] + ).fake_tensordict() + fake_td["collector", "traj_ids"] = torch.zeros( + fake_td.shape, dtype=torch.long + ) - self.replay_buffer._storage._init(fake_td) - except AttributeError: - pass + self.replay_buffer.add(fake_td) + self.replay_buffer.empty() @classmethod def _total_workers_from_env(cls, env_creators):