Skip to content

Commit

Permalink
Use jax.sharding instead of deprecated backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanluoyc committed Feb 22, 2024
1 parent fa68a2f commit 197acb7
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions corax/agents/jax/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,14 @@ def __init__(

# Unpack ActorCore, jitting if requested.
if jit:
self._init = jax.jit(actor.init, backend=backend)
self._policy = jax.jit(actor.select_action, backend=backend)
if backend is not None:
sharding = jax.sharding.SingleDeviceSharding(
jax.local_devices(backend=backend)[0]
)
else:
sharding = None
self._init = jax.jit(actor.init, sharding) # type: ignore
self._policy = jax.jit(actor.select_action, sharding) # type: ignore
else:
self._init = actor.init
self._policy = actor.select_action
Expand Down

0 comments on commit 197acb7

Please sign in to comment.