-
Notifications
You must be signed in to change notification settings - Fork 686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SAC jax #300
base: master
Are you sure you want to change the base?
SAC jax #300
Conversation
The latest updates on your projects. Learn more about Vercel for Git ↗︎
|
@vwxyzjn tests fails because |
It turns out the culprit is the following changes -classic_control = ["pygame (==2.1.0)"]
+classic-control = ["pygame (==2.1.0)"] We install pygame by |
@vwxyzjn I think I'm done for the implementation, I added support for constant entropy coeff and for deterministic eval. |
Perhaps it's because in #217 I implemented my own normal distribution I am trying to do the same for SAC... However if I replaced
with the log probability taken from https://github.com/openai/baselines/blob/9b68103b737ac46bc201dfb3121cfa5df2127e53/baselines/common/distributions.py#L238-L241
things kind of fall catastrophically... I felt that maybe implementing our own would bring greater transparency but maybe not be necessary... |
Aha! I got it, it's supposed to be the following
Interestingly, the paper seems to say our implementation should have been the following (with the summation)
but empirically, it doesn't perform as well... @dosssman any thoughts? |
Not sure to follow the difference... You can take a look at how we do it in SB3, I think it is what is described: |
I tried to implement the probability distribution ourselves 0cf0e9e, but hit a performance regression. Looking into the issue deeper, I couldn't quite understand how
I am quite puzzled. |
@vwxyzjn run the code with I guess the answer to your question is called numerical precision ;). EDIT: the code from tf distribution is here: https://github.com/tensorflow/probability/blob/bcdf53024ef9f35d81be063093ccfb3a762dab3f/tensorflow_probability/python/bijectors/tanh.py#L70-L81 # We implicitly rely on _forward_log_det_jacobian rather than explicitly
# implement _inverse_log_det_jacobian since directly using
# `-tf.math.log1p(-tf.square(y))` has lower numerical precision.
def _forward_log_det_jacobian(self, x):
# This formula is mathematically equivalent to
# `tf.log1p(-tf.square(tf.tanh(x)))`, however this code is more numerically
# stable.
# Derivation:
# log(1 - tanh(x)^2)
# = log(sech(x)^2)
# = 2 * log(sech(x))
# = 2 * log(2e^-x / (e^-2x + 1))
# = 2 * (log(2) - x - log(e^-2x + 1))
# = 2 * (log(2) - x - softplus(-2x))
return 2. * (np.log(2.) - x - tf.math.softplus(-2. * x)) |
@vwxyzjn as a follow up, if you remove the EDIT: I don't know why precommit fails, it does work locally |
@araffin |
Hi, is there any update/blocking thing on this? |
@vwxyzjn I would need your help again to update the lockfile, I tried to do it locally and poetry destroyed my conda env... |
Description
Missing: benchmark and doc
Adapted from https://github.com/araffin/sbx
Report (3 seeds on 3 MuJoCo envs): https://wandb.ai/openrlbenchmark/cleanrl/reports/SAC-jax---VmlldzoyODM4MjU0
Types of changes
Checklist:
pre-commit run --all-files
passes (required).mkdocs serve
.If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See #137 as an example PR.
--capture-video
flag toggled on (required).mkdocs serve
.width=500
andheight=300
).