
Commit

Update

[ghstack-poisoned]

vmoens committed Jan 16, 2025
2 parents 38b5a95 + 4dc12b8 commit aa372c3
Showing 2 changed files with 89 additions and 100 deletions.
187 changes: 87 additions & 100 deletions examples/agents/composite_ppo.py
@@ -4,28 +4,71 @@
# LICENSE file in the root directory of this source tree.

"""
Multi-head Agent and PPO Loss
=============================
This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.
The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict.
It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution
object containing the three distributions.
The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters,
creates a distribution from these parameters, and samples from the distribution to output multiple actions.
The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss.
Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False`
if not specified.
In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in
the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used.
Step-by-step Explanation
------------------------
1. **Setting Composite Log-Probabilities**:
- To use composite (i.e., multi-head) distributions with PPO (or any other algorithm that relies on probability
distributions, such as SAC or A2C), you must call `set_composite_lp_aggregate(False).set()`. Omitting this call will
result in errors during the execution of your script (a minimal sketch is given right after this docstring).
- From torchrl and tensordict v0.9 onwards, this will be the default behavior. Without it, `CompositeDistribution`
aggregates the log-probs, which may lead to incorrect log-probabilities.
- Note that `set_composite_lp_aggregate(False).set()` causes the sample log-probabilities to be named
`<action_key>_log_prob` for any probability distribution, not just composite ones. For regular, single-head policies,
for instance, the log-probability will be named `"action_log_prob"`. Previously, log-prob keys defaulted to
`sample_log_prob`.
2. **Action Grouping**:
- Actions can be grouped or not; PPO doesn't require them to be grouped.
- If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and
log-probability, e.g., `agent0`, `agent0_log_prob`, etc.
... [...]
... action: TensorDict(
...     fields={
...         agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
...         agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
...         agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
...         agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
...         agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
...         agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
...     batch_size=torch.Size([4]),
...     device=None,
...     is_shared=False),
- If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields.
... [...]
... agent0: TensorDict(
...     fields={
...         action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
...         action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
...     batch_size=torch.Size([4]),
...     device=None,
...     is_shared=False),
... agent1: TensorDict(
...     fields={
...         action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
...         action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
...     batch_size=torch.Size([4]),
...     device=None,
...     is_shared=False),
... agent2: TensorDict(
...     fields={
...         action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
...         action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
...     batch_size=torch.Size([4]),
...     device=None,
...     is_shared=False),
3. **PPO Loss Calculation**:
- Under the hood, `ClipPPOLoss` clips the individual importance weights (not their aggregate) and multiplies them by the advantage.
The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses.
"""

@@ -38,13 +81,18 @@
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
    set_composite_lp_aggregate,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

set_composite_lp_aggregate(False).set()

GROUPED_ACTIONS = False

make_params = Mod(
    lambda: (
        torch.ones(4),
@@ -74,8 +122,18 @@ def mixture_constructor(logits, loc, scale):
)


# =============================================================================
# Example 0: aggregate_probabilities=None (default) ===========================
if GROUPED_ACTIONS:
    name_map = {
        "gamma": ("action", "agent0"),
        "Kumaraswamy": ("action", "agent1"),
        "mixture": ("action", "agent2"),
    }
else:
    name_map = {
        "gamma": ("agent0", "action"),
        "Kumaraswamy": ("agent1", "action"),
        "mixture": ("agent2", "action"),
    }

dist_constructor = functools.partial(
    CompositeDistribution,
@@ -84,40 +142,27 @@ def mixture_constructor(logits, loc, scale):
"Kumaraswamy": d.Kumaraswamy,
"mixture": mixture_constructor,
},
name_map={
"gamma": ("agent0", "action"),
"Kumaraswamy": ("agent1", "action"),
"mixture": ("agent2", "action"),
},
aggregate_probabilities=None,
name_map=name_map,
)


policy = ProbSeq(
    make_params,
    Prob(
        in_keys=["params"],
        out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        out_keys=list(name_map.values()),
        distribution_class=dist_constructor,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = policy(TensorDict(batch_size=[4]))
print("0. result of policy call", td)
print("Result of policy call", td)

dist = policy.get_dist(td)
log_prob = dist.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob")

# We can also get the log-prob from the policy directly
log_prob = policy.log_prob(
    td, aggregate_probabilities=False, inplace=False, include_sum=False
)
print("0. non-aggregated log-prob (from policy)")
log_prob = dist.log_prob(td)
print("Composite log-prob", log_prob)

# Build a dummy value operator
value_operator = Seq(
@@ -134,70 +179,12 @@ def mixture_constructor(logits, loc, scale):
    TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

# Instantiate the loss
# Instantiate the loss - test the 3 different PPO losses
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    # PPO sets the keys automatically by looking at the policy
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("0. ", loss_cls, loss_vals)


# ===================================================================
# Example 1: aggregate_probabilities=True ===========================

dist_constructor.keywords["aggregate_probabilities"] = True

td = policy(TensorDict(batch_size=[4]))
print("1. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since
    # there is only one.
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")]
    )

    # Get the loss values
    loss_vals = ppo(data)
    print("1. ", loss_cls, loss_vals)


# ===================================================================
# Example 2: aggregate_probabilities=False ===========================

dist_constructor.keywords["aggregate_probabilities"] = False

td = policy(TensorDict(batch_size=[4]))
print("2. result of policy call", td)

# Instantiate the loss
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
    ppo = loss_cls(policy, value_operator)

    # Keys are not the default ones - there is more than one action
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )
    print("tensor keys", ppo.tensor_keys)

    # Get the loss values
    loss_vals = ppo(data)
    print("2. ", loss_cls, loss_vals)
    print("Loss result:", loss_cls, loss_vals)
2 changes: 2 additions & 0 deletions torchrl/objectives/ppo.py
@@ -991,6 +991,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        # to different, unrelated trajectories, which is not standard. Still it can give an idea of the dispersion
        # of the weights.
        lw = log_weight.squeeze()
        if not isinstance(lw, torch.Tensor):
            lw = _sum_td_features(lw)
        ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
        batch = log_weight.shape[0]

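The two added lines guard the effective-sample-size (ESS) diagnostic: when composite log-probs are not aggregated, the importance log-weight is a tensordict of per-head entries rather than a single tensor, so its leaves are reduced first. Below is a standalone sketch of that computation with a local stand-in for the internal `_sum_td_features` helper (assumed here to sum each leaf over its feature dimensions); keys and shapes are illustrative.

import torch
from tensordict import TensorDict


def sum_leaves(td):
    # Local stand-in for torchrl's internal _sum_td_features: reduce every leaf
    # over its feature dims and add the results into one tensor per sample.
    return sum(leaf.reshape(*td.batch_size, -1).sum(-1) for leaf in td.values(True, True))


# Per-head log importance weights, as produced when composite log-probs are kept separate
log_weight = TensorDict(
    {
        "agent0": {"action_log_prob": torch.randn(8)},
        "agent1": {"action_log_prob": torch.randn(8, 2)},
    },
    batch_size=[8],
)
lw = log_weight if isinstance(log_weight, torch.Tensor) else sum_leaves(log_weight)
ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()  # effective sample size
print(ess)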
