Skip to content

Commit

Permalink
Fix cache init when JAX Array is created early (#25768)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stella-S-Yan committed Jan 21, 2025
1 parent d1810b4 commit 3dc8dac
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
12 changes: 7 additions & 5 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,19 @@ def _initialize_cache() -> None:
with _cache_initialized_mutex:
if _cache_initialized:
return
_cache_initialized = True

path: str | None = config.compilation_cache_dir.value
# If the path is not set, the cache will not be built.
if not path:
return

# Nothing to do if the cache is disabled.
if not _is_cache_enabled():
logger.debug("_initialize_cache: cache is disabled!")
return

_cache_initialized = True

# Set the minimum cache size entry only if the flag
# --jax_persistent_cache_min_entry_size_bytes has not been set.
if config.persistent_cache_min_entry_size_bytes.value == 0:
Expand All @@ -146,10 +152,6 @@ def _initialize_cache() -> None:

global _cache
assert _cache is None, "The cache has already been initialized!"
path: str | None = config.compilation_cache_dir.value
# If the path is not set, the cache will not be enabled.
if not path:
return

cache_and_path = get_file_cache(path)
if cache_and_path is None:
Expand Down
37 changes: 37 additions & 0 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,23 @@ def test_jit(self):
f(1.0)
self.assertEqual(count_cache_items(), 2)

def test_set_compilation_cache_dir_reset_cache(self):
with config.compilation_cache_dir(None):
cc.reset_cache()
backend = xla_bridge.get_backend()

a = jnp.zeros((2,3))
self.assertFalse(cc.is_persistent_cache_enabled())
cache = cc._get_cache(backend)
self.assertIsNone(cache) # Not able to create cache

with config.compilation_cache_dir("jax-cache"):
f = jit(lambda x: x + 1)
f(a) # Compile and cache
self.assertTrue(cc.is_persistent_cache_enabled())
cache = cc._get_cache(backend)
self.assertIsNotNone(cache) # Cache is created

def test_xla_autofdo_profile_version(self):
original_profile_version = config.jax_xla_profile_version.value
with config.jax_xla_profile_version(original_profile_version + 1):
Expand Down Expand Up @@ -587,6 +604,26 @@ def test_jit(self):
f(1)
self.assertEqual(count_cache_items(), 0)

def test_enable_compilation_cache(self):
with (
config.enable_compilation_cache(False),
config.compilation_cache_dir("jax-cache")
):
cc.reset_cache() # reset cache before testing
backend = xla_bridge.get_backend()
f = jit(lambda x: x + 1)
f(1) # Compile and cache
cache = cc._get_cache(backend)
self.assertIsNone(cache) # Cache should not exist

with config.enable_compilation_cache(True):
cc.reset_cache()
backend = xla_bridge.get_backend()
g = jit(lambda x: x * 3)
g(2)
cache = cc._get_cache(backend)
self.assertIsNotNone(cache) # Cache should be initalized

def test_tasks_disable_cache_metric(self):
with config.enable_compilation_cache(False):
count_before_first_use = _counts[
Expand Down

0 comments on commit 3dc8dac

Please sign in to comment.