Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI: 01/14/25 upstream sync #202

Open
wants to merge 74 commits into
base: rocm-main
Choose a base branch
from
Open

Conversation

github-actions[bot]
Copy link

Daily sync with upstream

mattjj and others added 30 commits January 8, 2025 23:38
Boolean fields in the descriptor struct led to padding, which let random
bytes in the string representation of the struct and variance in HLO
from run to run.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
There were some confusion regarding how to properly add attributes to the op in jax-ml#25767.

PiperOrigin-RevId: 713726697
A pivoted QR factorization is possible in `scipy.linalg.qr`, thanks
to the `geqp3` routine of LAPACK. To provide the same functionality
in JAX, we implement a new primitive `geqp3_p` which calls the LAPACK
routine via the FFI on CPU devices.

Both `jax.scipy.linalg.qr` and `jax.lax.linalg.qr` now support the
use of column-pivoting on CPU devices.

To provide a GPU implementation of `geqp3` may require using MAGMA,
due to the lack of a `geqp3` implementation in `cuSolver` -  see
ccb3317 (`jax.lax.linalg.eig`) for
an example of using MAGMA in GPU lowerings. Such a GPU implementation
can be considered in the future.
Everything passes other than an io callback test due to the lowered `sdy.manual_computation` returning a token. Will be fixed in a follow-up.

PiperOrigin-RevId: 713780181
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
…structure mismatch

Fixes: jax-ml#25140

Previously, the following code:
```
def f(i, x):
  return lax.switch(i, [lambda x: dict(a=x),
                        lambda x: dict(a=(x, x))], x)
f(0, 42)
```

resulted in the error message:
```
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef({'a': *}) and PyTreeDef({'a': (*, *)}).
```

With this change the error message is more specific where the
difference is in the pytree structure:

```
TypeError: branch 0 output must have same type structure as branch 1 output, but there are differences:
    * at output['a'], branch 0 output has pytree leaf and branch 1 output has <class 'tuple'>, so their Python types differ
```
PiperOrigin-RevId: 713989952
…vered bugs

We previously weren't testing unsigned integer types.

PiperOrigin-RevId: 714002869
…reads enabled.

Add a new jtu.thread_unsafe_test_class() decorator to tag entire `TestCase` classes as thread-hostile.

PiperOrigin-RevId: 714037277
Google-ML-Automation and others added 28 commits January 10, 2025 12:12
…ass, skipping over elementwise and matmul op insertions and/or type compat casts.

PiperOrigin-RevId: 714132282
This is implemented by merging multiple indexers into one.

PiperOrigin-RevId: 714150733
…x` and `user_mode_ctx` as **private** APIs to make writing auto/user sharding in types code way easier and noise-free.

These can be made public in the future under different names.

PiperOrigin-RevId: 714169304
PiperOrigin-RevId: 714293671
…artially specified out_shardings (i.e. some out_sharding's are None and others are NamedShardings).

In this case, the returned out_shardings should all be NamedSharding (because of NamedSharding's presence in some out_sharding's).

PiperOrigin-RevId: 714681941
…SubelementMaskOp

PiperOrigin-RevId: 714795856
…er TPU gens

It is still not efficiently implemented, this is mostly to clean up some logic. We may be able to fuse the creation of masks for different tiles into the creation of a single one. But this is also a problem for the later gens.

This also cleans up an unreachable return statement.

PiperOrigin-RevId: 714847066
Use absl::call_once instead of a GIL-protected global initialization.

In passing, also remove an unused function.

PiperOrigin-RevId: 714892175
…he `mlir::tensor::TensorDialect`. This was causing the compiler to crash.

PiperOrigin-RevId: 714896947
…arget types

This effectively moves some of the Pallas logic to the layer below.

PiperOrigin-RevId: 714965374
PiperOrigin-RevId: 715008710
PiperOrigin-RevId: 715077057
PiperOrigin-RevId: 715242560
…on the current

value of the `use_shardy_partitioner` feature flag.

Before the way the API works depends on the value of the flag when the partitioning is defined. But we should allow this to be dynamically swapped in and out when the function is actually called. This change allows for that.

PiperOrigin-RevId: 715293018
https://opensource.google/documentation/reference/github/services#actions mandates using a specific commit for non-Google actions.

PiperOrigin-RevId: 715377970
…ypes by registering the sharding on the aval properly created by SDS in it's pytype_aval_mapping.

Also If we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.

PiperOrigin-RevId: 715383866
@github-actions github-actions bot enabled auto-merge January 14, 2025 16:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.