Commit 522111a

Update

[ghstack-poisoned]

vmoens committed Sep 26, 2024
1 parent b114f64 commit 522111a

Showing 4 changed files with 41 additions and 32 deletions.
17 changes: 13 additions & 4 deletions test/test_collector.py

@@ -1833,10 +1833,15 @@ def test_set_truncated(collector_cls):
         NestedCountingEnv(), InitTracker()
     ).add_truncated_keys()
     env = env_fn()
-    policy = env.rand_action
+    policy = CloudpickleWrapper(env.rand_action)
     if collector_cls == SyncDataCollector:
         collector = collector_cls(
-            env, policy=policy, frames_per_batch=20, total_frames=-1, set_truncated=True
+            env,
+            policy=policy,
+            frames_per_batch=20,
+            total_frames=-1,
+            set_truncated=True,
+            trust_policy=True,
         )
     else:
         collector = collector_cls(
@@ -1846,6 +1851,7 @@ def test_set_truncated(collector_cls):
             total_frames=-1,
             cat_results="stack",
             set_truncated=True,
+            trust_policy=True,
         )
     try:
         for data in collector:
@@ -2153,7 +2159,10 @@ def test_multi_collector_consistency(
     assert_allclose_td(c2.unsqueeze(0), d2)


-@pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="No casting if no cuda")
+@pytest.mark.skipif(
+    not torch.cuda.is_available() and not torch.mps.is_available(),
+    reason="No casting if no cuda",
+)
 class TestUpdateParams:
     class DummyEnv(EnvBase):
         def __init__(self, device, batch_size=[]):  # noqa: B006
@@ -2952,7 +2961,7 @@ def make_policy(device=None, nn_module=True):
                 out_keys=["action"],
             )
         policy = make_policy(device=device)
-        return CloudpickleWrapper(lambda tensordict: policy(tensordict))
+        return CloudpickleWrapper(policy)

     def make_and_test_policy(
         policy,
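Note: the pattern these tests now exercise is a plain callable (here the bound method `env.rand_action`) wrapped in `CloudpickleWrapper` and handed to a collector together with the new `trust_policy=True` flag, which lets the collector skip its policy-compatibility inspection. A minimal usage sketch, assuming this commit plus a Gym backend (the env name is an arbitrary choice, not part of the diff):

    from torchrl.collectors import SyncDataCollector
    from torchrl.data.utils import CloudpickleWrapper
    from torchrl.envs import GymEnv

    env = GymEnv("Pendulum-v1")
    # A bound method is not an nn.Module; wrapping it makes it picklable
    # for worker processes and gives it function-like metadata.
    policy = CloudpickleWrapper(env.rand_action)

    collector = SyncDataCollector(
        env,
        policy=policy,
        frames_per_batch=20,
        total_frames=100,
        trust_policy=True,  # bypass the compatibility inspection
    )
    for batch in collector:
        print(batch.batch_size)  # torch.Size([20])
    collector.shutdown()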
7 changes: 4 additions & 3 deletions torchrl/collectors/collectors.py

@@ -179,6 +179,7 @@ def _get_policy_and_device(

         i = -1
         for p in param_and_buf.values(True, True):
+            i += 1
             if p.device != policy_device:
                 # Then we need casting
                 break
@@ -1772,15 +1773,12 @@ def update_policy_weights_(self, policy_weights=None) -> None:
             if policy_weights is not None:
                 self._policy_weights_dict[_device].data.update_(policy_weights)
             elif self._get_weights_fn_dict[_device] is not None:
-                print(1, self._get_weights_fn_dict[_device])
                 original_weights = self._get_weights_fn_dict[_device]()
-                print(2, original_weights)
                 if original_weights is None:
                     # if the weights match in identity, we can spare a call to update_
                     continue
                 if isinstance(original_weights, TensorDictParams):
                     original_weights = original_weights.data
-                print(3, 'self._policy_weights_dict[_device]', self._policy_weights_dict[_device])
                 self._policy_weights_dict[_device].data.update_(original_weights)

     @property
@@ -1842,6 +1840,7 @@ def _run_processes(self) -> None:
                 "replay_buffer": self.replay_buffer,
                 "replay_buffer_chunk": self.replay_buffer_chunk,
                 "traj_pool": self._traj_pool,
+                "trust_policy": self.trust_policy,
             }
             proc = _ProcessNoWarn(
                 target=_main_async_collector,
@@ -2857,6 +2856,7 @@ def _main_async_collector(
     replay_buffer: ReplayBuffer | None = None,
     replay_buffer_chunk: bool = True,
     traj_pool: _TrajectoryPool = None,
+    trust_policy: bool = False,
 ) -> None:
     pipe_parent.close()
     # init variables that will be cleared when closing
@@ -2883,6 +2883,7 @@
         use_buffers=use_buffers,
         replay_buffer=replay_buffer if replay_buffer_chunk else None,
         traj_pool=traj_pool,
+        trust_policy=trust_policy,
     )
     use_buffers = inner_collector._use_buffers
     if verbose:
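Note: the `i += 1` above is a genuine fix, not cosmetics: `i` starts at -1 and was never incremented, so a later check of the form `if i == -1` (presumably flagging a parameter-less policy) would fire even when parameters existed. A stripped-down illustration of the corrected pattern, with hypothetical names:

    import torch

    def needs_casting(params, policy_device):
        """Return True if any parameter lives off `policy_device`."""
        i = -1
        for p in params:
            i += 1  # the fix: actually count visited parameters
            if p.device != policy_device:
                return True
        if i == -1:
            # Only reachable when the iterable was empty.
            raise ValueError("policy has no parameters or buffers")
        return False

    assert not needs_casting([torch.zeros(3)], torch.device("cpu"))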
2 changes: 1 addition & 1 deletion torchrl/data/utils.py

@@ -236,7 +236,7 @@ def __init__(self, fn: Callable, **kwargs):
         self.fn = fn
         self.kwargs = kwargs

-        functools.update_wrapper(self, fn)
+        functools.update_wrapper(self, getattr(fn, "forward", fn))

     def __getstate__(self):
         import cloudpickle
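Note: `functools.update_wrapper(self, fn)` copies `fn`'s metadata onto the wrapper, including merging `fn.__dict__` into the wrapper's. When `fn` is an `nn.Module`, that drags the module's internals (`_parameters`, `_buffers`, submodules, ...) onto the `CloudpickleWrapper`; taking the metadata from `fn.forward` instead yields a clean name, docstring, and signature. A small self-contained demonstration of the difference (this `Wrapper` class is a hypothetical stand-in, not TorchRL's):

    import functools
    import torch.nn as nn

    class Wrapper:
        def __init__(self, fn):
            self.fn = fn
            # As in the patched __init__: prefer the module's forward method.
            functools.update_wrapper(self, getattr(fn, "forward", fn))

        def __call__(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    w = Wrapper(nn.Linear(3, 4))
    print(w.__name__)                # 'forward'
    print("_parameters" in vars(w))  # False: module internals not copied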
47 changes: 23 additions & 24 deletions torchrl/envs/utils.py

@@ -52,7 +52,7 @@
     TensorSpec,
     Unbounded,
 )
-from torchrl.data.utils import check_no_exclusive_keys
+from torchrl.data.utils import check_no_exclusive_keys, CloudpickleWrapper

 __all__ = [
     "exploration_mode",
@@ -1464,7 +1464,9 @@ def _make_compatible_policy(
         policy = RandomPolicy(input_spec)

     # make sure policy is an nn.Module - this will return the same policy if conditions are met
-    policy = _NonParametricPolicyWrapper(policy)
+    # policy = CloudpickleWrapper(policy)
+
+    caller = getattr(policy, "forward", policy)

     if not _policy_is_tensordict_compatible(policy):
         if observation_spec is None:
@@ -1496,9 +1498,9 @@
             )

         try:
-            sig = policy.forward.__signature__
+            sig = caller.__signature__
         except AttributeError:
-            sig = inspect.signature(policy.forward)
+            sig = inspect.signature(caller)
         # we check if all the mandatory params are there
         params = list(sig.parameters.keys())
         if (
@@ -1527,7 +1529,7 @@
             out_keys = ["action"]
         else:
             out_keys = list(env.action_keys)
-        for p in policy.parameters():
+        for p in getattr(policy, "parameters", list)():
             policy_device = p.device
             break
         else:
@@ -1559,15 +1561,21 @@


 def _policy_is_tensordict_compatible(policy: nn.Module):
-    if isinstance(policy, _NonParametricPolicyWrapper) and isinstance(
-        policy.policy, RandomPolicy
-    ):
-        return True

-    if isinstance(policy, TensorDictModuleBase):
+    def is_compatible(policy):
+        return isinstance(policy, (RandomPolicy, TensorDictModuleBase))
+
+    if (
+        is_compatible(policy)
+        or (
+            isinstance(policy, _NonParametricPolicyWrapper)
+            and is_compatible(policy.policy)
+        )
+        or (isinstance(policy, CloudpickleWrapper) and is_compatible(policy.fn))
+    ):
         return True

-    sig = inspect.signature(policy.forward)
+    sig = inspect.signature(getattr(policy, "forward", policy))

     if (
         len(sig.parameters) == 1
@@ -1640,19 +1648,10 @@ class _NonParametricPolicyWrapper(nn.Module, metaclass=_PolicyMetaClass):

     def __init__(self, policy):
         super().__init__()
-        self.policy = policy
-
-    @property
-    def forward(self):
-        forward = self.__dict__.get("_forward", None)
-        if forward is None:
-
-            @functools.wraps(self.policy)
-            def forward(*input, **kwargs):
-                return self.policy.__call__(*input, **kwargs)
-
-            self.__dict__["_forward"] = forward
-        return forward
+        functools.update_wrapper(self, policy)
+        self.policy = CloudpickleWrapper(policy)
+        if hasattr(policy, "forward"):
+            self.forward = self.policy.forward

     def __getattr__(self, attr: str) -> Any:
         if attr in self.__dir__():
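Note: the recurring idiom in this file is `getattr(obj, attr, fallback)` so that plain callables survive code paths written for `nn.Module`. `getattr(policy, "parameters", list)()` is the compact version of that trick: on a module it returns the parameter iterator; on a bare function the fallback `list` is called with no arguments, yielding an empty iterable, so the device-detection loop simply falls through to its `else` branch. A short demonstration with hypothetical names:

    import torch
    import torch.nn as nn

    def first_param_device(policy):
        # nn.Module: iterate real parameters; plain callable: getattr falls
        # back to `list`, and list() is just an empty iterable.
        for p in getattr(policy, "parameters", list)():
            return p.device
        return None  # parameter-less policy: no device preference

    print(first_param_device(nn.Linear(2, 2)))  # cpu
    print(first_param_device(lambda td: td))    # None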
