Skip to content

Commit

Permalink
Update xla to use mlir rather than backend-specific-translations
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Aug 12, 2024
1 parent f411fc2 commit b7b8a49
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions envpool/python/xla_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax.core import ShapedArray
from jax.interpreters import xla
from jax.lib import xla_client
from jax import interpreters


def _shape_with_layout(
Expand Down Expand Up @@ -91,10 +92,10 @@ def translation(c: Any, *args: Any, platform: str = "cpu") -> Any:
prim.multiple_results = (len(out_specs) > 1)
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_abstract_eval(abstract)
xla.backend_specific_translations["cpu"][prim] = partial(
interpreters.mlir["cpu"][prim] = partial(
translation, platform="cpu"
)
xla.backend_specific_translations["gpu"][prim] = partial(
interpreters.mlir["gpu"][prim] = partial(
translation, platform="gpu"
)

Expand Down

0 comments on commit b7b8a49

Please sign in to comment.