Skip to content

Commit

Permalink
Fix cache init when JAX Array is created early (#25768)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stella-S-Yan committed Jan 16, 2025
1 parent d1810b4 commit 6aef390
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
10 changes: 6 additions & 4 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def _initialize_cache() -> None:
with _cache_initialized_mutex:
if _cache_initialized:
return

path: str | None = config.compilation_cache_dir.value
# If the path is not set, the cache will not be built.
if not path:
return
_cache_initialized = True

# Nothing to do if the cache is disabled.
Expand All @@ -146,10 +151,7 @@ def _initialize_cache() -> None:

global _cache
assert _cache is None, "The cache has already been initialized!"
path: str | None = config.compilation_cache_dir.value
# If the path is not set, the cache will not be enabled.
if not path:
return


cache_and_path = get_file_cache(path)
if cache_and_path is None:
Expand Down
1 change: 1 addition & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ def validator(new_val):
update_thread_local_hook=update_thread_local_hook,
validator=validator)


def string_or_object_state(
name: str,
default: Any,
Expand Down
28 changes: 28 additions & 0 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,30 @@ def test_jit(self):
f(1.0)
self.assertEqual(count_cache_items(), 2)


def test_change_compilation_cache_dir_reset_cache(self):
original_value = config.compilation_cache_dir.value
try:
cc.reset_cache()
self.assertFalse(cc._cache_initialized)
self.assertFalse(cc.is_persistent_cache_enabled())

a = jnp.zeros((2,3))
self.assertFalse(cc._cache_initialized)
self.assertFalse(cc.is_persistent_cache_enabled())

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
@jax.jit
def f(x):
return x * 10
f(a)
self.assertTrue(cc._cache_initialized)
self.assertTrue(cc.is_persistent_cache_enabled())

finally:
config.update("jax_compilation_cache_dir", original_value)


def test_xla_autofdo_profile_version(self):
original_profile_version = config.jax_xla_profile_version.value
with config.jax_xla_profile_version(original_profile_version + 1):
Expand Down Expand Up @@ -568,6 +592,9 @@ def test_persistent_cache_enable_xla_caches(self):
self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir")
self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE)




@jtu.with_config(
jax_enable_compilation_cache=False,
jax_persistent_cache_min_compile_time_secs=0,
Expand Down Expand Up @@ -614,5 +641,6 @@ def test_persistent_cache_enable_xla_caches_disabled(self):
self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "")
self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 6aef390

Please sign in to comment.