Skip to content

Commit

Permalink
Fix cache init when JAX Array is created early (#25768)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stella-S-Yan committed Jan 16, 2025
1 parent d1810b4 commit 46d542f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
9 changes: 5 additions & 4 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def _initialize_cache() -> None:
with _cache_initialized_mutex:
if _cache_initialized:
return

path: str | None = config.compilation_cache_dir.value
# If the path is not set, the cache will not be built.
if not path:
return
_cache_initialized = True

# Nothing to do if the cache is disabled.
Expand All @@ -146,10 +151,6 @@ def _initialize_cache() -> None:

global _cache
assert _cache is None, "The cache has already been initialized!"
path: str | None = config.compilation_cache_dir.value
# If the path is not set, the cache will not be enabled.
if not path:
return

cache_and_path = get_file_cache(path)
if cache_and_path is None:
Expand Down
24 changes: 24 additions & 0 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,30 @@ def test_jit(self):
f(1.0)
self.assertEqual(count_cache_items(), 2)


def test_change_compilation_cache_dir_reset_cache(self):
original_value = config.compilation_cache_dir.value
try:
cc.reset_cache()
self.assertFalse(cc._cache_initialized)
self.assertFalse(cc.is_persistent_cache_enabled())

a = jnp.zeros((2,3))
self.assertFalse(cc._cache_initialized)
self.assertFalse(cc.is_persistent_cache_enabled())

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
@jax.jit
def f(x):
return x * 10
f(a)
self.assertTrue(cc._cache_initialized)
self.assertTrue(cc.is_persistent_cache_enabled())

finally:
config.update("jax_compilation_cache_dir", original_value)


def test_xla_autofdo_profile_version(self):
original_profile_version = config.jax_xla_profile_version.value
with config.jax_xla_profile_version(original_profile_version + 1):
Expand Down

0 comments on commit 46d542f

Please sign in to comment.