diff --git a/examples/agents/composite_ppo.py b/examples/agents/composite_ppo.py
index d75ce3218b3..501dceb651d 100644
--- a/examples/agents/composite_ppo.py
+++ b/examples/agents/composite_ppo.py
@@ -4,28 +4,71 @@
 # LICENSE file in the root directory of this source tree.

 """
-Multi-head agent and PPO loss
+Multi-head Agent and PPO Loss
 =============================
-
 This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
 (Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.

-The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict.
-It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution
-object containing the three distributions.
-
-The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters,
-creates a distribution from these parameters, and samples from the distribution to output multiple actions.
-
-The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss.
-
-Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
-fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
-argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False`
-if not specified.
-
-In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in
-the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used.
+Step-by-step Explanation
+------------------------
+
+1. **Setting Composite Log-Probabilities**:
+   - To use composite (i.e., multi-head) distributions with PPO (or any other algorithm that relies on probability
+     distributions, such as SAC or A2C), you must call `set_composite_lp_aggregate(False).set()`. Omitting this call
+     will result in errors during execution of your script.
+   - From torchrl and tensordict v0.9, this will be the default behavior. Until then, omitting the call will result in
+     `CompositeDistribution` aggregating the log-probs, which may lead to incorrect log-probabilities.
+   - Note that `set_composite_lp_aggregate(False).set()` will cause the sample log-probabilities to be written under
+     keys suffixed with `_log_prob` for any probability distribution, not just composite ones. For regular, single-head
+     policies, for instance, the log-probability will be named `"action_log_prob"`.
+     Previously, log-prob keys defaulted to `sample_log_prob`.
+2. **Action Grouping**:
+   - Actions can be grouped or not; PPO does not require them to be grouped.
+   - If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and
+     log-probability, e.g., `agent0`, `agent0_log_prob`, etc.
+
+       ... [...]
+       ...     action: TensorDict(
+       ...         fields={
+       ...             agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+       ...             agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+       ...             agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+       ...             agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+       ...             agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+       ...             agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+       ...         batch_size=torch.Size([4]),
+       ...         device=None,
+       ...         is_shared=False),
+
+   - If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields.
+
+       ... [...]
+       ...     agent0: TensorDict(
+       ...         fields={
+       ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+       ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+       ...         batch_size=torch.Size([4]),
+       ...         device=None,
+       ...         is_shared=False),
+       ...     agent1: TensorDict(
+       ...         fields={
+       ...             action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+       ...             action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
+       ...         batch_size=torch.Size([4]),
+       ...         device=None,
+       ...         is_shared=False),
+       ...     agent2: TensorDict(
+       ...         fields={
+       ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+       ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+       ...         batch_size=torch.Size([4]),
+       ...         device=None,
+       ...         is_shared=False),
+
+3. **PPO Loss Calculation**:
+   - Under the hood, `ClipPPOLoss` clips the individual importance weights (not their aggregate) and multiplies them by the advantage.
+
+The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses.

 """

@@ -38,6 +81,7 @@
     InteractionType,
     ProbabilisticTensorDictModule as Prob,
     ProbabilisticTensorDictSequential as ProbSeq,
+    set_composite_lp_aggregate,
     TensorDictModule as Mod,
     TensorDictSequential as Seq,
     WrapModule as Wrap,
@@ -45,6 +89,10 @@
 from torch import distributions as d
 from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

+set_composite_lp_aggregate(False).set()
+
+GROUPED_ACTIONS = False
+
 make_params = Mod(
     lambda: (
         torch.ones(4),
@@ -74,8 +122,18 @@ def mixture_constructor(logits, loc, scale):
     )


-# =============================================================================
-# Example 0: aggregate_probabilities=None (default) ===========================
+if GROUPED_ACTIONS:
+    name_map = {
+        "gamma": ("action", "agent0"),
+        "Kumaraswamy": ("action", "agent1"),
+        "mixture": ("action", "agent2"),
+    }
+else:
+    name_map = {
+        "gamma": ("agent0", "action"),
+        "Kumaraswamy": ("agent1", "action"),
+        "mixture": ("agent2", "action"),
+    }

 dist_constructor = functools.partial(
     CompositeDistribution,
@@ -84,12 +142,7 @@ def mixture_constructor(logits, loc, scale):
         "Kumaraswamy": d.Kumaraswamy,
         "mixture": mixture_constructor,
     },
-    name_map={
-        "gamma": ("agent0", "action"),
-        "Kumaraswamy": ("agent1", "action"),
-        "mixture": ("agent2", "action"),
-    },
-    aggregate_probabilities=None,
+    name_map=name_map,
 )


@@ -97,7 +150,7 @@ def mixture_constructor(logits, loc, scale):
     make_params,
     Prob(
         in_keys=["params"],
-        out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
+        out_keys=list(name_map.values()),
         distribution_class=dist_constructor,
         return_log_prob=True,
         default_interaction_type=InteractionType.RANDOM,
@@ -105,19 +158,11 @@
 )

 td = policy(TensorDict(batch_size=[4]))
-print("0. result of policy call", td)
+print("Result of policy call", td)

 dist = policy.get_dist(td)
-log_prob = dist.log_prob(
-    td, aggregate_probabilities=False, inplace=False, include_sum=False
-)
-print("0. non-aggregated log-prob")
-
-# We can also get the log-prob from the policy directly
-log_prob = policy.log_prob(
-    td, aggregate_probabilities=False, inplace=False, include_sum=False
-)
-print("0. non-aggregated log-prob (from policy)")
+log_prob = dist.log_prob(td)
+print("Composite log-prob", log_prob)

 # Build a dummy value operator
 value_operator = Seq(
@@ -134,70 +179,12 @@ def mixture_constructor(logits, loc, scale):
     TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
 )

-# Instantiate the loss
+# Instantiate the loss - test the 3 different PPO losses
 for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
+    # PPO sets the keys automatically by looking at the policy
     ppo = loss_cls(policy, value_operator)
-
-    # Keys are not the default ones - there is more than one action
-    ppo.set_keys(
-        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
-        sample_log_prob=[
-            ("agent0", "action_log_prob"),
-            ("agent1", "action_log_prob"),
-            ("agent2", "action_log_prob"),
-        ],
-    )
-
-    # Get the loss values
-    loss_vals = ppo(data)
-    print("0. ", loss_cls, loss_vals)
-
-
-# ===================================================================
-# Example 1: aggregate_probabilities=True ===========================
-
-dist_constructor.keywords["aggregate_probabilities"] = True
-
-td = policy(TensorDict(batch_size=[4]))
-print("1. result of policy call", td)
-
-# Instantiate the loss
-for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
-    ppo = loss_cls(policy, value_operator)
-
-    # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since
-    # there is only one.
-    ppo.set_keys(
-        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")]
-    )
-
-    # Get the loss values
-    loss_vals = ppo(data)
-    print("1. ", loss_cls, loss_vals)
-
-
-# ===================================================================
-# Example 2: aggregate_probabilities=False ===========================
-
-dist_constructor.keywords["aggregate_probabilities"] = False
-
-td = policy(TensorDict(batch_size=[4]))
-print("2. result of policy call", td)
-
-# Instantiate the loss
-for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
-    ppo = loss_cls(policy, value_operator)
-
-    # Keys are not the default ones - there is more than one action
-    ppo.set_keys(
-        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
-        sample_log_prob=[
-            ("agent0", "action_log_prob"),
-            ("agent1", "action_log_prob"),
-            ("agent2", "action_log_prob"),
-        ],
-    )
+    print("tensor keys", ppo.tensor_keys)

     # Get the loss values
     loss_vals = ppo(data)
-    print("2. ", loss_cls, loss_vals)
+    print("Loss result:", loss_cls, loss_vals)
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 9b41afd9afa..b8425e085b1 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -959,6 +959,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
         # of the weights.
         lw = log_weight.squeeze()
+        if not isinstance(lw, torch.Tensor):
+            lw = _sum_td_features(lw)
         ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
         batch = log_weight.shape[0]
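
Note (illustration only, not part of the patch): the log-prob key naming described in the new docstring can be checked with a minimal single-head sketch. The snippet below is an assumption-based example that uses only the public tensordict API already imported by the example (`ProbabilisticTensorDictModule`, `set_composite_lp_aggregate`); the `loc`/`scale` keys and the Normal head are arbitrary choices made for this note, not code from this PR.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import (
        InteractionType,
        ProbabilisticTensorDictModule as Prob,
        set_composite_lp_aggregate,
    )
    from torch import distributions as d

    # Disable log-prob aggregation globally, as the example script does.
    set_composite_lp_aggregate(False).set()

    # A single-head policy: reads loc/scale and samples an "action".
    policy = Prob(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=d.Normal,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    )

    td = policy(TensorDict(loc=torch.zeros(4), scale=torch.ones(4), batch_size=[4]))
    # With aggregation disabled, the log-prob key follows the sample key
    # ("action" -> "action_log_prob") rather than the legacy "sample_log_prob".
    print(sorted(td.keys()))  # expected: ['action', 'action_log_prob', 'loc', 'scale']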
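
The ppo.py hunk guards the effective-sample-size (ESS) diagnostic against log-weights that now arrive as a TensorDict of per-head entries instead of a single tensor. A rough, self-contained sketch of that idea follows; `sum_td_features` is a hypothetical stand-in written for this note (torchrl's `_sum_td_features` helper is private and may behave differently), and the keys and shapes are illustrative only.

    import torch
    from tensordict import TensorDict

    def sum_td_features(td: TensorDict) -> torch.Tensor:
        # Hypothetical stand-in: reduce per-head log-weights to a single tensor.
        return sum(td.values(include_nested=True, leaves_only=True))

    # With non-aggregated composite log-probs, the importance log-weight can be
    # a TensorDict with one leaf per head rather than a single tensor.
    log_weight = TensorDict(
        {
            ("agent0", "action_log_prob"): torch.randn(4),
            ("agent2", "action_log_prob"): torch.randn(4),
        },
        batch_size=[4],
    )

    lw = log_weight if isinstance(log_weight, torch.Tensor) else sum_td_features(log_weight)
    # Effective sample size of the importance weights, as computed in the PPO loss.
    ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
    print(ess)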