diff --git a/algo/ppo.py b/algo/ppo.py index f4873d4a7..a93295e84 100644 --- a/algo/ppo.py +++ b/algo/ppo.py @@ -71,7 +71,7 @@ def update(self, rollouts): value_losses_clipped = (value_pred_clipped - return_batch).pow(2) value_loss = .5 * torch.max(value_losses, value_losses_clipped).mean() else: - value_loss = F.mse_loss(return_batch, values) + value_loss = 0.5 * F.mse_loss(return_batch, values) self.optimizer.zero_grad() (value_loss * self.value_loss_coef + action_loss -