diff --git a/a2c_ppo_acktr/algo/ppo.py b/a2c_ppo_acktr/algo/ppo.py index 08534ba82..61d5fdde1 100644 --- a/a2c_ppo_acktr/algo/ppo.py +++ b/a2c_ppo_acktr/algo/ppo.py @@ -71,10 +71,9 @@ def update(self, rollouts): value_losses = (values - return_batch).pow(2) value_losses_clipped = ( value_pred_clipped - return_batch).pow(2) - value_loss = 0.5 * torch.max(value_losses, - value_losses_clipped).mean() + value_loss = torch.max(value_losses, value_losses_clipped).mean() else: - value_loss = 0.5 * (return_batch - values).pow(2).mean() + value_loss = (return_batch - values).pow(2).mean() self.optimizer.zero_grad() (value_loss * self.value_loss_coef + action_loss -