Skip to content

Commit

Permalink
Merge branch 'main' into fix_segfault
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Oct 26, 2023
2 parents d478cda + a9d2ec9 commit 61ae101
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions envpool/python/xla_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np
from jax import core, dtypes
from jax import numpy as jnp
from jax.abstract_arrays import ShapedArray
from jax.core import ShapedArray
from jax.interpreters import xla
from jax.lib import xla_client

Expand Down Expand Up @@ -52,9 +52,10 @@ def _make_xla_function(
in_specs = _normalize_specs(in_specs)
out_specs = _normalize_specs(out_specs)
cpu_capsule, gpu_capsule = capsules
xla_client.register_cpu_custom_call_target(
xla_client.register_custom_call_target(
f"{type(obj).__name__}_{id(obj)}_{name}_cpu".encode(),
cpu_capsule,
platform="cpu"
)
xla_client.register_custom_call_target(
f"{type(obj).__name__}_{id(obj)}_{name}_gpu".encode(),
Expand Down

0 comments on commit 61ae101

Please sign in to comment.