forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CI: 01/14/25 upstream sync #202
Open
github-actions
wants to merge
74
commits into
rocm-main
Choose a base branch
from
ci-upstream-sync-87_1
base: rocm-main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Boolean fields in the descriptor struct led to padding, which let random bytes in the string representation of the struct and variance in HLO from run to run.
PiperOrigin-RevId: 713673355
… int4 PiperOrigin-RevId: 713675029
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure. This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
http://github.com/openxla/xla/commit/2f6eabb5a1d0a4ce5ba9eb0d52620463b3ece2c3. PiperOrigin-RevId: 713704549
PiperOrigin-RevId: 713704872
PiperOrigin-RevId: 713705552
PiperOrigin-RevId: 713717495
There were some confusion regarding how to properly add attributes to the op in jax-ml#25767. PiperOrigin-RevId: 713726697
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks to the `geqp3` routine of LAPACK. To provide the same functionality in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK routine via the FFI on CPU devices. Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the use of column-pivoting on CPU devices. To provide a GPU implementation of `geqp3` may require using MAGMA, due to the lack of a `geqp3` implementation in `cuSolver` - see ccb3317 (`jax.lax.linalg.eig`) for an example of using MAGMA in GPU lowerings. Such a GPU implementation can be considered in the future.
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up. PiperOrigin-RevId: 713780181
PiperOrigin-RevId: 713789106
PiperOrigin-RevId: 713830757
This exposes an easier way to get StableHLO and HLO with more debugging information (source locations for StableHLO and metadata for HLO).
…structure mismatch Fixes: jax-ml#25140 Previously, the following code: ``` def f(i, x): return lax.switch(i, [lambda x: dict(a=x), lambda x: dict(a=(x, x))], x) f(0, 42) ``` resulted in the error message: ``` TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}). ``` With this change the error message is more specific where the difference is in the pytree structure: ``` TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences: * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ ```
PiperOrigin-RevId: 713958983
PiperOrigin-RevId: 713962512
…vered bugs We previously weren't testing unsigned integer types. PiperOrigin-RevId: 714002869
PiperOrigin-RevId: 714005603
PiperOrigin-RevId: 714027519
…reads enabled. Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile. PiperOrigin-RevId: 714037277
PiperOrigin-RevId: 714048232
PiperOrigin-RevId: 714048619
PiperOrigin-RevId: 714053620
…ass, skipping over elementwise and matmul op insertions and/or type compat casts. PiperOrigin-RevId: 714132282
PiperOrigin-RevId: 714141077
This is implemented by merging multiple indexers into one. PiperOrigin-RevId: 714150733
…x` and `user_mode_ctx` as **private** APIs to make writing auto/user sharding in types code way easier and noise-free. These can be made public in the future under different names. PiperOrigin-RevId: 714169304
PiperOrigin-RevId: 714293671
http://github.com/openxla/xla/commit/53af2044b8881e67fbc38311f4a25997b9561ce4. PiperOrigin-RevId: 714455324
…artially specified out_shardings (i.e. some out_sharding's are None and others are NamedShardings). In this case, the returned out_shardings should all be NamedSharding (because of NamedSharding's presence in some out_sharding's). PiperOrigin-RevId: 714681941
http://github.com/openxla/xla/commit/0cb8b51b761603f83e91d61be1d513aff77f1f82. PiperOrigin-RevId: 714695840
…SubelementMaskOp PiperOrigin-RevId: 714795856
…er TPU gens It is still not efficiently implemented, this is mostly to clean up some logic. We may be able to fuse the creation of masks for different tiles into the creation of a single one. But this is also a problem for the later gens. This also cleans up an unreachable return statement. PiperOrigin-RevId: 714847066
Use absl::call_once instead of a GIL-protected global initialization. In passing, also remove an unused function. PiperOrigin-RevId: 714892175
…he `mlir::tensor::TensorDialect`. This was causing the compiler to crash. PiperOrigin-RevId: 714896947
…arget types This effectively moves some of the Pallas logic to the layer below. PiperOrigin-RevId: 714965374
PiperOrigin-RevId: 714975806
http://github.com/openxla/xla/commit/e27a06fe71dbecc79c2ee636abd4bf23755bd7d6. PiperOrigin-RevId: 715004907
PiperOrigin-RevId: 715008710
PiperOrigin-RevId: 715073762
PiperOrigin-RevId: 715077057
PiperOrigin-RevId: 715085454
PiperOrigin-RevId: 715242560
PiperOrigin-RevId: 715258789
…on the current value of the `use_shardy_partitioner` feature flag. Before the way the API works depends on the value of the flag when the partitioning is defined. But we should allow this to be dynamically swapped in and out when the function is actually called. This change allows for that. PiperOrigin-RevId: 715293018
PiperOrigin-RevId: 715364096
https://opensource.google/documentation/reference/github/services#actions mandates using a specific commit for non-Google actions. PiperOrigin-RevId: 715377970
…ypes by registering the sharding on the aval properly created by SDS in it's pytype_aval_mapping. Also If we are running under full auto mode, don't error out if primitives don't have a sharding rule registered. PiperOrigin-RevId: 715383866
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Daily sync with upstream