-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update docs of callbacks #25982
Update docs of callbacks #25982
Conversation
Callback functions should not call into JAX. This information was missing in the docs of the callbacks. This commit adds this information to the docs. See: jax-ml#25861, jax-ml#24255
Thank you @roth-jakob for hunting down the cause of the hangs in NIFTy with recent JAX versions!! Personally, I think this information is worth a big warning box in the doc-string but I might be alone in this view. I think the innocent looking |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated to this PR but on the topic of callback documentation changes, shouldn't jax.debug.callback
have guaranteed execution? It doesn't make sense otherwise. (That is, the last column of the jax.debug.callback
row should have a ✅ instead of a ❌.)
jax/docs/external-callbacks.md
Lines 76 to 80 in 4f8699c
|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution | | |
|-------------------------------------|----|----|----|----|----|----| | |
|{func}`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ | | |
|{func}`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ | | |
|{func}`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | |
I can open another PR if need be, sorry for hijacking this one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this @roth-jakob! I think this is a good place to start, but hopefully we can improve the default behavior soon.
@stephen-huan — You might look at this discussion for a conversation about your question. @sharadmv's comment there says that JAX has permission to eliminate calls to |
@dfm I see, thanks for answering my question. Perhaps that should be another documentation fix, in another PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For what it's worth, I find the wording "must not include any calls back into JAX" a bit confusing, especially when talking about callbacks ("callbacks" vs "calls back"). I think something like "must not involve any computation with JAX arrays" a bit more clear (with some intentional ambiguity with what "computation" means exactly). My mental model is anything that one can access statically (e.g. .shape
, .dtype
, etc.) is ok but computation that would be staged out when jit
'd is dangerous. I haven't actually tested this intuition.
Callback functions should not call into JAX. This information was missing in the docs of the callbacks. This commit adds this information to the docs.
See: #25861, #24255