Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Oct 4, 2024
1 parent 2d622bd commit 12dd746
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
14 changes: 12 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8678,8 +8678,18 @@ def test_compose_indexing(self):
def test_compose_action_spec(self):
# Create a Compose transform that renames "action" to "action_1" and then to "action_2"
c = Compose(
RenameTransform(in_keys=(), out_keys=(), in_keys_inv=("action",), out_keys_inv=("action_1",)),
RenameTransform(in_keys=(), out_keys=(), in_keys_inv=("action_1",), out_keys_inv=("action_2",)),
RenameTransform(
in_keys=(),
out_keys=(),
in_keys_inv=("action",),
out_keys_inv=("action_1",),
),
RenameTransform(
in_keys=(),
out_keys=(),
in_keys_inv=("action_1",),
out_keys_inv=("action_2",),
),
)
base_env = ContinuousActionVecMockEnv()
env = TransformedEnv(base_env, c)
Expand Down
29 changes: 15 additions & 14 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,21 +1675,22 @@ def __init__(
self.cat_results = cat_results

def _check_replay_buffer_init(self):
try:
if not self.replay_buffer._storage.initialized:
if isinstance(self.create_env_fn, EnvCreator):
fake_td = self.create_env_fn.tensordict
elif isinstance(self.create_env_fn, EnvBase):
fake_td = self.create_env_fn.fake_tensordict()
else:
fake_td = self.create_env_fn[0](
**self.create_env_kwargs[0]
).fake_tensordict()
fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)
is_init = getattr(self.replay_buffer._storage, "initialized", True)
if not is_init:
if isinstance(self.create_env_fn[0], EnvCreator):
fake_td = self.create_env_fn[0].tensordict
elif isinstance(self.create_env_fn[0], EnvBase):
fake_td = self.create_env_fn[0].fake_tensordict()
else:
fake_td = self.create_env_fn[0](
**self.create_env_kwargs[0]
).fake_tensordict()
fake_td["collector", "traj_ids"] = torch.zeros(
fake_td.shape, dtype=torch.long
)

self.replay_buffer._storage._init(fake_td)
except AttributeError:
pass
self.replay_buffer.add(fake_td)
self.replay_buffer.empty()

@classmethod
def _total_workers_from_env(cls, env_creators):
Expand Down

0 comments on commit 12dd746

Please sign in to comment.