Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update docs of callbacks #25982

Merged
merged 1 commit into from
Jan 21, 2025
Merged

Conversation

roth-jakob
Copy link
Contributor

Callback functions should not call into JAX. This information was missing in the docs of the callbacks. This commit adds this information to the docs.

See: #25861, #24255

Callback functions should not call into JAX. This information was
missing in the docs of the callbacks. This commit adds this information
to the docs.

See: jax-ml#25861, jax-ml#24255
@Edenhofer
Copy link
Contributor

Thank you @roth-jakob for hunting down the cause of the hangs in NIFTy with recent JAX versions!! Personally, I think this information is worth a big warning box in the doc-string but I might be alone in this view. I think the innocent looking a == b case with a and b being JAX arrays is also worth mentioning explicitly in the docs. It is obvious that this calls back into JAX but I think it is very easy to overlook. Maybe it is even worth suggesting to users to convert JAX arrays to numpy arrays if they intend to do any math and/or comparisons on arguments to a callback. WDYT?

@dfm dfm self-assigned this Jan 20, 2025
Copy link

@stephen-huan stephen-huan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR but on the topic of callback documentation changes, shouldn't jax.debug.callback have guaranteed execution? It doesn't make sense otherwise. (That is, the last column of the jax.debug.callback row should have a ✅ instead of a ❌.)

|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|-------------------------------------|----|----|----|----|----|----|
|{func}`jax.pure_callback` |||| ❌¹ |||
|{func}`jax.experimental.io_callback` ||| ✅/❌² || ✅³ ||
|{func}`jax.debug.callback` |||||||

I can open another PR if need be, sorry for hijacking this one.

Copy link
Collaborator

@dfm dfm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this @roth-jakob! I think this is a good place to start, but hopefully we can improve the default behavior soon.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Jan 21, 2025
@dfm
Copy link
Collaborator

dfm commented Jan 21, 2025

@stephen-huan — You might look at this discussion for a conversation about your question. @sharadmv's comment there says that JAX has permission to eliminate calls to debug.callback in some cases. I must admit that I'm not totally sure what those cases are, but it might be better to revive that thread or open a new issue for this discsusion.

@stephen-huan
Copy link

@dfm I see, thanks for answering my question. Perhaps that should be another documentation fix, in another PR.

Copy link

@stephen-huan stephen-huan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what it's worth, I find the wording "must not include any calls back into JAX" a bit confusing, especially when talking about callbacks ("callbacks" vs "calls back"). I think something like "must not involve any computation with JAX arrays" a bit more clear (with some intentional ambiguity with what "computation" means exactly). My mental model is anything that one can access statically (e.g. .shape, .dtype, etc.) is ok but computation that would be staged out when jit'd is dangerous. I haven't actually tested this intuition.

@copybara-service copybara-service bot merged commit 3b5b981 into jax-ml:main Jan 21, 2025
25 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants