Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI: 01/07/25 upstream sync #194

Merged
merged 32 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
cb4d97a
Move jex.ffi to jax.ffi.
dfm Dec 20, 2024
20b75ab
Update package indentation fix
Ruturaj4 Jan 2, 2025
4de7794
Merge pull request #25715 from ROCm:ci_build_code_fixes-upstream
Google-ML-Automation Jan 6, 2025
512d545
Temporarily allow deprecation warnings for `scipy.special.lpmn` and `…
dfm Jan 6, 2025
f2e210b
Disable `avxvnniint8` when building with Clang version < 19, or GCC <…
belitskiy Jan 6, 2025
e87a2a5
[shape_poly] Remove old non_negative support.
gnecula Jan 6, 2025
3ff000e
fix the degenerated case
yliu120 Dec 31, 2024
6c87bf3
Fixes tril/triu comments (they were flipped)
marksandler2 Jan 6, 2025
245a13a
Deprecate scipy.special.lpmn & lpmn_values
jakevdp Jan 6, 2025
18b193c
Update XLA dependency to use revision
Google-ML-Automation Jan 6, 2025
74be8bd
Merge pull request #25675 from jakevdp:dep-lpmn
Google-ML-Automation Jan 6, 2025
c39e38f
bazel: export serialization.fbs for downstream usage
maxwillzq Jan 6, 2025
2f7204f
jnp.einsum: default to optimize='auto'
jakevdp Jan 6, 2025
634b45b
Merge pull request #25699 from yliu120:fix_iota
Google-ML-Automation Jan 6, 2025
9f84290
[Mosaic TPU] Validate inserted layout in relayout-insertion pass.
bythew3i Jan 6, 2025
52cc5c7
Merge pull request #25214 from jakevdp:einsum-optimize
Google-ML-Automation Jan 6, 2025
61dd041
Suppress MSAN warnings from SVD that are showing up in CI.
hawkinsp Jan 6, 2025
90d8f37
Rename pybind_extension to nanobind_extension.
hawkinsp Jan 6, 2025
77c6947
fix the doc error: module 'scipy.misc' has no attribute 'face'
zhenying-liu Jan 6, 2025
b49ba65
Remove the need for check_rep for with_sharding_constraint.
pschuh Jan 6, 2025
4caa263
[Mosaic TPU] Add some elementwise canonicalizations
sharadmv Jan 6, 2025
c7b0d68
Remove deprecated jax.experimental.array_api
jakevdp Jan 6, 2025
b304b9e
Merge pull request #25740 from jakevdp:remove-array-api
Google-ML-Automation Jan 7, 2025
23eaf21
Make inspect_array_sharding work without mesh context manager too.
yashk2810 Jan 7, 2025
7997f08
Merge pull request #25728 from zhenying-liu:scipy.misc
Google-ML-Automation Jan 7, 2025
bc3306c
[shape_poly] Improve threefry with symbolic shapes
gnecula Dec 21, 2024
7fb68ca
Fix type signature for __divmod__
shoyer Jan 7, 2025
712bece
Merge pull request #25731 from gnecula:poly_random_even
Google-ML-Automation Jan 7, 2025
56f0f95
Merge pull request #25633 from dfm:move-ffi
Google-ML-Automation Jan 7, 2025
853af56
Merge pull request #25748 from shoyer:divmod
Google-ML-Automation Jan 7, 2025
a7f384c
Add a register_custom_type_id function to the GPU plugins.
dfm Jan 7, 2025
a94ee1f
Unskip unit tests that are now fixed
charleshofer Jan 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ jobs:
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api --ignore=jax/lib/xla_extension.py
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py
documentation_render:
Expand Down
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,37 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`#25214`).

* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Support added for user defined state in the FFI via the new
{func}`jax.ffi.register_ffi_type_id` function.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.
* The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the
previous import path is deprecated.

* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
is on by default.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.
* The `jax.experimental.array_api` module has been removed after being
deprecated in JAX v0.4.32. Since that release, {mod}`jax.numpy` supports
the array API directly.

## jax 0.4.38 (Dec 17, 2024)

Expand Down
7 changes: 7 additions & 0 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,9 @@ async def main():
# Enable clang settings that are needed for the build to work with newer
# versions of Clang.
wheel_build_command_base.append("--config=clang")
if clang_major_version < 19:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")

else:
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
logging.debug(
Expand All @@ -477,6 +480,10 @@ async def main():
wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")

gcc_major_version = utils.get_gcc_major_version(gcc_path)
if gcc_major_version < 13:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")

if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
if target_cpu == "aarch64":
Expand Down
5 changes: 2 additions & 3 deletions build/rocm/tools/get_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ def install_packages(self, package_specs):
env = dict(os.environ)
if self.pkgbin == "apt":
env["DEBIAN_FRONTEND"] = "noninteractive"

# Update indexes.
subprocess.check_call(["apt-get", "update"])
# Update indexes.
subprocess.check_call(["apt-get", "update"])

LOG.info("Running %r" % cmd)
subprocess.check_call(cmd, env=env)
Expand Down
12 changes: 12 additions & 0 deletions build/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,18 @@ def get_clang_major_version(clang_path):

return major_version

def get_gcc_major_version(gcc_path: str):
gcc_version_proc = subprocess.run(
[gcc_path, "-dumpversion"],
check=True,
capture_output=True,
text=True,
)
major_version = int(gcc_version_proc.stdout)

return major_version


def get_jax_configure_bazel_options(bazel_command: list[str]):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,4 +362,5 @@ def linkcode_resolve(domain, info):
'jax-101/index.rst': 'tutorials.rst',
'notebooks/external_callbacks.md': 'external-callbacks.md',
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
'jax.extend.ffi.rst': 'jax.ffi.rst',
}
2 changes: 1 addition & 1 deletion docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ You can find the up-to-date command to run doctests in
E.g., you can run:

```
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
```

Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
Expand Down
Loading
Loading