
Commit

Remove device_context from trace_context because we don't need it there.

We can get compilation cache misses (and tracing/lowering cache hits) naturally without putting concrete devices into trace_context.

PiperOrigin-RevId: 718113413
yashk2810 authored and Google-ML-Automation committed Jan 22, 2025
1 parent 051861b commit 3aa5599
Showing 4 changed files with 36 additions and 5 deletions.
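For context, here is a minimal sketch (not part of this commit, and assuming a runtime with at least four devices) of the behavior the commit message describes: the same jitted function is applied to arrays laid out on the same four devices in two different orders, so the abstract signature is identical and only the concrete device assignment differs. After this change, the second call reuses the trace and the lowering but triggers a fresh compilation.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

@jax.jit
def identity(x):
  return x

# Two meshes over the same four devices, in opposite orders.
devs = jax.devices()[:4]
mesh_a = Mesh(np.asarray(devs).reshape(2, 2), ('x', 'y'))
mesh_b = Mesh(np.asarray(devs[::-1]).reshape(2, 2), ('x', 'y'))

np_inp = np.arange(16).reshape(8, 2)
arr_a = jax.device_put(np_inp, NamedSharding(mesh_a, P('x', 'y')))
arr_b = jax.device_put(np_inp, NamedSharding(mesh_b, P('x', 'y')))

# One trace and one lowering serve both calls; each concrete device
# assignment gets its own executable, so each output keeps the device
# order of its own mesh.
out_a = identity(arr_a)
out_b = identity(arr_b)
print(out_a.sharding._device_assignment)
print(out_b.sharding._device_assignment)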
1 change: 0 additions & 1 deletion jax/_src/config.py
@@ -202,7 +202,6 @@ def trace_context():
   return (axis_env_state.value, mesh_context_manager.value,
           xla_metadata_context_manager.value,
           abstract_mesh_context_manager.value,
-          device_context.value,
           compute_on_context_manager.value, enable_x64.value,
           numpy_rank_promotion.value, default_matmul_precision.value,
           dynamic_shapes.value,
3 changes: 1 addition & 2 deletions jax/_src/linear_util.py
@@ -347,8 +347,7 @@ def cache(call: Callable, *, explain: Callable | None = None):
 
   def memoized_fun(fun: WrappedFun, *args):
     cache = fun_caches.setdefault(fun.f, new_cache := {})  # type: ignore
-    key = (fun.transforms, fun.params, fun.in_type, args, config.enable_x64.value,
-           config.default_device.value, config.trace_context())
+    key = (fun.transforms, fun.params, fun.in_type, args, config.trace_context())
     result = cache.get(key, None)
     if result is not None:
       ans, stores = result
3 changes: 1 addition & 2 deletions jax/_src/util.py
@@ -328,8 +328,7 @@ def weakref_lru_cache(call: Callable, maxsize=2048,
   """
   global _weakref_lru_caches
   cached_call = xc.weakref_lru_cache(
-      config.trace_context if trace_context_in_key else _ignore,
-      call, maxsize)
+      config.trace_context if trace_context_in_key else _ignore, call, maxsize)
   _weakref_lru_caches.add(cached_call)
   return cached_call
 
34 changes: 34 additions & 0 deletions tests/pjit_test.py
@@ -6336,6 +6336,40 @@ def f(x):
     out = hf(arr)  # doesn't crash
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))
 
+  def test_compilation_cache_miss_when_devices_change(self):
+    mesh1 = jtu.create_mesh((2, 2), ('x', 'y'))
+    devs = jax.devices()[:4]
+    mesh2 = Mesh(np.asarray(devs[::-1]).reshape(2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+
+    with jax.sharding.use_mesh(mesh1):
+      arr1 = jax.device_put(np_inp, NamedSharding(mesh1, P('x', 'y')))
+    with jax.sharding.use_mesh(mesh2):
+      arr2 = jax.device_put(np_inp, NamedSharding(mesh2, P('x', 'y')))
+
+    @jax.jit
+    def f(x):
+      return x
+
+    with (jtu.count_jit_tracing_cache_miss() as tracing_count,
+          jtu.count_jit_and_pmap_lowerings() as lowering_count,
+          jtu.count_jit_compilation_cache_miss() as compilation_count,
+          jtu.count_pjit_cpp_cache_miss() as cpp_cache_miss_count):
+      with jax.sharding.use_mesh(mesh1):
+        out1 = f(arr1)
+      with jax.sharding.use_mesh(mesh2):
+        out2 = f(arr2)
+
+    self.assertEqual(tracing_count(), 1)
+    self.assertEqual(lowering_count(), 1)
+    self.assertEqual(compilation_count(), 2)
+    self.assertEqual(cpp_cache_miss_count(), 2)
+
+    self.assertTupleEqual(out1.sharding._device_assignment,
+                          tuple(mesh1.devices.flat))
+    self.assertTupleEqual(out2.sharding._device_assignment,
+                          tuple(mesh2.devices.flat))
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
