
Commit

Update
[ghstack-poisoned]
vmoens committed Sep 25, 2024
1 parent 1bc83b5 commit b114f64
Showing 6 changed files with 13 additions and 8 deletions.
.github/unittest/linux/scripts/run_all.sh: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ export DISPLAY=:0
export SDL_VIDEODRIVER=dummy

# legacy from bash scripts: remove?
-conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False
+conda env config vars set MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:0 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG

pip3 install pip --upgrade
pip install virtualenv
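The new RL_LOGGING_LEVEL=DEBUG variable is presumably read by torchrl to raise its logger verbosity in CI. A minimal sketch, assuming the variable name above and a logger registered as "torchrl"; the actual wiring in torchrl may differ:

import logging
import os

# Hedged sketch: map the RL_LOGGING_LEVEL environment variable onto the torchrl logger.
level_name = os.environ.get("RL_LOGGING_LEVEL", "INFO").upper()
logging.getLogger("torchrl").setLevel(getattr(logging, level_name, logging.INFO))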
test/_utils_internal.py: 2 additions & 0 deletions
@@ -155,6 +155,8 @@ def get_default_devices():
return [torch.device("cpu")]
elif num_cuda == 1:
return [torch.device("cuda:0")]
+elif torch.mps.is_available():
+    return [torch.device("mps:0")]
else:
# then run on all devices
return get_available_devices()
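torch.mps.is_available() reports whether PyTorch's Metal backend can be used, so on an Apple-silicon machine without CUDA the helper above now returns the MPS device instead of falling back to every available device. A small usage sketch, assuming a PyTorch build that exposes torch.mps:

import torch

if torch.mps.is_available():
    device = torch.device("mps:0")
    x = torch.ones(3, device=device)  # tensors created by the tests would land on MPS
    print(x.device)  # mps:0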
test/test_collector.py: 4 additions & 6 deletions
@@ -2153,7 +2153,7 @@ def test_multi_collector_consistency(
assert_allclose_td(c2.unsqueeze(0), d2)


-@pytest.mark.skipif(not torch.cuda.device_count(), reason="No casting if no cuda")
+@pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="No casting if no cuda")
class TestUpdateParams:
class DummyEnv(EnvBase):
def __init__(self, device, batch_size=[]): # noqa: B006
@@ -2211,8 +2211,8 @@ def forward(self, td):
@pytest.mark.parametrize(
"policy_device,env_device",
[
["cpu", "cuda"],
["cuda", "cpu"],
["cpu", get_default_devices()[0]],
[get_default_devices()[0], "cpu"],
# ["cpu", "cuda:0"], # 1226: faster execution
# ["cuda:0", "cpu"],
# ["cuda", "cuda:0"],
@@ -2236,9 +2236,7 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
policy.param.data += 1
policy.buf.data += 2
if give_weights:
-d = dict(policy.named_parameters())
-d.update(policy.named_buffers())
-p_w = TensorDict(d, [])
+p_w = TensorDict.from_module(policy)
else:
p_w = None
col.update_policy_weights_(p_w)
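TensorDict.from_module collects a module's parameters and buffers into a single TensorDict, which is what the removed three-line construction did by hand. A standalone illustration, using a plain nn.Linear in place of the test's policy:

import torch.nn as nn
from tensordict import TensorDict

module = nn.Linear(4, 2)
weights = TensorDict.from_module(module)  # holds the module's "weight" and "bias"
print(sorted(weights.keys()))             # ['bias', 'weight']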
torchrl/_utils.py: 1 addition & 1 deletion
@@ -46,7 +46,7 @@
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

-VERBOSE = strtobool(os.environ.get("VERBOSE", "0"))
+VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))
_os_is_windows = sys.platform == "win32"
RL_WARNINGS = strtobool(os.environ.get("RL_WARNINGS", "1"))
if RL_WARNINGS:
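With this change, VERBOSE defaults to whatever the torchrl logger implies (verbose on when DEBUG logging is enabled), while an explicitly set VERBOSE environment variable still takes precedence. A rough sketch of the resulting behaviour, assuming the logger is the one named "torchrl" and that strtobool treats "True"/"False" case-insensitively:

import logging
import os

logger = logging.getLogger("torchrl")
logger.setLevel(logging.DEBUG)  # e.g. the effect of RL_LOGGING_LEVEL=DEBUG in the CI script above
default = str(logger.isEnabledFor(logging.DEBUG))  # "True"
verbose = os.environ.get("VERBOSE", default).lower() in ("1", "true", "t", "y", "yes", "on")
print(verbose)  # True unless VERBOSE is explicitly set to a falsy value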
torchrl/collectors/collectors.py: 3 additions & 0 deletions
@@ -1772,12 +1772,15 @@ def update_policy_weights_(self, policy_weights=None) -> None:
if policy_weights is not None:
self._policy_weights_dict[_device].data.update_(policy_weights)
elif self._get_weights_fn_dict[_device] is not None:
+print(1, self._get_weights_fn_dict[_device])
original_weights = self._get_weights_fn_dict[_device]()
+print(2, original_weights)
if original_weights is None:
# if the weights match in identity, we can spare a call to update_
continue
if isinstance(original_weights, TensorDictParams):
original_weights = original_weights.data
+print(3, 'self._policy_weights_dict[_device]', self._policy_weights_dict[_device])
self._policy_weights_dict[_device].data.update_(original_weights)

@property
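For context, self._policy_weights_dict[_device].data.update_(...) copies the fetched weights into the existing per-device tensors in place, so storage already handed to the workers keeps pointing at the refreshed values. A minimal illustration of TensorDict.update_ semantics:

import torch
from tensordict import TensorDict

dst = TensorDict({"param": torch.zeros(2)}, [])
src = TensorDict({"param": torch.ones(2)}, [])
dst.update_(src)                  # in-place copy into dst's existing tensors
assert (dst["param"] == 1).all()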
torchrl/collectors/distributed/generic.py: 2 additions & 0 deletions
@@ -813,6 +813,8 @@ def _iterator_dist(self):

for i in range(self.num_workers):
rank = i + 1
+if self._VERBOSE:
+    torchrl_logger.info(f"shutting down rank {rank}.")
self._store.set(f"NODE_{rank}_in", b"shutdown")

def _next_sync(self, total_frames):
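The loop above signals each worker to stop by writing b"shutdown" into the distributed key-value store under that worker's NODE_{rank}_in key, which the worker side presumably polls. A hedged sketch of the handshake, with an in-process HashStore standing in for the collector's actual store:

import torch.distributed as dist

store = dist.HashStore()                      # stand-in for the collector's store
rank = 1
store.set(f"NODE_{rank}_in", b"shutdown")     # main node requests shutdown
assert store.get(f"NODE_{rank}_in") == b"shutdown"  # what a polling worker would read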
