Skip to content

Commit

Permalink
Roll back the optimized version of jax.block_until_ready due to tes…
Browse files Browse the repository at this point in the history
…t breakage

Reverts 6cc6d09

PiperOrigin-RevId: 581577789
  • Loading branch information
junwhanahn authored and jax authors committed Nov 11, 2023
1 parent 7f39099 commit 55394a0
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 29 deletions.
24 changes: 1 addition & 23 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2940,29 +2940,7 @@ def try_to_block(x):
return x.block_until_ready()
except AttributeError:
return x

if xla_extension_version < 214:
return tree_map(try_to_block, x)

arrays = []
for leaf in tree_leaves(x):
if isinstance(leaf, array.ArrayImpl):
arrays.append(leaf)
else:
try_to_block(leaf)

if not arrays:
# `arrays` will be empty if tree_leaves(x) is empty or all leaves are not
# jax.Array.
pass
elif len(arrays) == 1:
# Fast path for single array.
try_to_block(arrays[0])
else:
# Optimized for multiple arrays.
xc.batched_block_until_ready(arrays)

return x
return tree_map(try_to_block, x)


def clear_backends():
Expand Down
6 changes: 0 additions & 6 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2409,12 +2409,6 @@ def test_block_until_ready_function(self):
self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False)
self.assertAllClose(pytree[1], np.ones(3), check_dtypes=False)

def test_block_until_ready_numpy_arrays(self):
pytree = (np.ones(1), np.ones(2))
pytree = jax.block_until_ready(pytree)
self.assertAllClose(pytree[0], np.ones(1), check_dtypes=False)
self.assertAllClose(pytree[1], np.ones(2), check_dtypes=False)

def test_devicearray_weakref_friendly(self):
x = device_put(1.)
y = weakref.ref(x)
Expand Down

0 comments on commit 55394a0

Please sign in to comment.