Skip to content

Commit

Permalink
Fix deprecated usage of JAX symbols. (#282)
Browse files Browse the repository at this point in the history
## Description

Remove deprecated JAX symbols.

## Motivation and Context

* `jax.abstract_array.ShapedArray` is removed for jax>=0.4.16. See


https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0416-sept-18-2023

* `xla_client.register_cpu_custom_call_target` is removed in


openxla/xla@0ab5486

close #281

- [x] I have raised an issue to propose this change
([required](https://envpool.readthedocs.io/en/latest/pages/contributing.html)
for new features and bug fixes)

## Types of changes

What types of changes does your code introduce? Put an `x` in all the
boxes that apply:

- [x] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds core functionality)
- [x] New environment (non-breaking change which adds 3rd-party
environment)
- [x] Breaking change (fix or feature that would cause existing
functionality to change)
- [x] Documentation (update in the documentation)
- [x] Example (update in the folder of example)

## Implemented Tasks

- [x] Fix deprecated symbol

## Checklist

Go over all the following points, and put an `x` in all the boxes that
apply.
If you are unsure about any of these, don't hesitate to ask. We are here
to help!

- [x] I have read the
[CONTRIBUTION](https://envpool.readthedocs.io/en/latest/pages/contributing.html)
guide (**required**)
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a
new feature*).
- [ ] I have updated the documentation accordingly.
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make lint` (**required**)
- [x] I have ensured `make bazel-test` pass. (**required**)
  • Loading branch information
ethanluoyc authored Oct 26, 2023
1 parent 47ad258 commit a9d2ec9
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions envpool/python/xla_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
from jax import core, dtypes
from jax import numpy as jnp
from jax.abstract_arrays import ShapedArray
from jax.core import ShapedArray
from jax.interpreters import xla
from jax.lib import xla_client

Expand Down Expand Up @@ -52,9 +52,10 @@ def _make_xla_function(
in_specs = _normalize_specs(in_specs)
out_specs = _normalize_specs(out_specs)
cpu_capsule, gpu_capsule = capsules
xla_client.register_cpu_custom_call_target(
xla_client.register_custom_call_target(
f"{type(obj).__name__}_{id(obj)}_{name}_cpu".encode(),
cpu_capsule,
platform="cpu"
)
xla_client.register_custom_call_target(
f"{type(obj).__name__}_{id(obj)}_{name}_gpu".encode(),
Expand Down

0 comments on commit a9d2ec9

Please sign in to comment.