From 3dc8dacae3d5ed5ddb285f0eb68c76672daac31b Mon Sep 17 00:00:00 2001 From: Stella S Yan Date: Wed, 15 Jan 2025 00:02:03 +0000 Subject: [PATCH] Fix cache init when JAX Array is created early (#25768) --- jax/_src/compilation_cache.py | 12 ++++++----- tests/compilation_cache_test.py | 37 +++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index d8724e42975e..3b3b1bc88924 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -131,13 +131,19 @@ def _initialize_cache() -> None: with _cache_initialized_mutex: if _cache_initialized: return - _cache_initialized = True + + path: str | None = config.compilation_cache_dir.value + # If the path is not set, the cache will not be built. + if not path: + return # Nothing to do if the cache is disabled. if not _is_cache_enabled(): logger.debug("_initialize_cache: cache is disabled!") return + _cache_initialized = True + # Set the minimum cache size entry only if the flag # --jax_persistent_cache_min_entry_size_bytes has not been set. if config.persistent_cache_min_entry_size_bytes.value == 0: @@ -146,10 +152,6 @@ def _initialize_cache() -> None: global _cache assert _cache is None, "The cache has already been initialized!" - path: str | None = config.compilation_cache_dir.value - # If the path is not set, the cache will not be enabled. - if not path: - return cache_and_path = get_file_cache(path) if cache_and_path is None: diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index ef245bc8d3ff..476409e59849 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -198,6 +198,23 @@ def test_jit(self): f(1.0) self.assertEqual(count_cache_items(), 2) + def test_set_compilation_cache_dir_reset_cache(self): + with config.compilation_cache_dir(None): + cc.reset_cache() + backend = xla_bridge.get_backend() + + a = jnp.zeros((2,3)) + self.assertFalse(cc.is_persistent_cache_enabled()) + cache = cc._get_cache(backend) + self.assertIsNone(cache) # Not able to create cache + + with config.compilation_cache_dir("jax-cache"): + f = jit(lambda x: x + 1) + f(a) # Compile and cache + self.assertTrue(cc.is_persistent_cache_enabled()) + cache = cc._get_cache(backend) + self.assertIsNotNone(cache) # Cache is created + def test_xla_autofdo_profile_version(self): original_profile_version = config.jax_xla_profile_version.value with config.jax_xla_profile_version(original_profile_version + 1): @@ -587,6 +604,26 @@ def test_jit(self): f(1) self.assertEqual(count_cache_items(), 0) + def test_enable_compilation_cache(self): + with ( + config.enable_compilation_cache(False), + config.compilation_cache_dir("jax-cache") + ): + cc.reset_cache() # reset cache before testing + backend = xla_bridge.get_backend() + f = jit(lambda x: x + 1) + f(1) # Compile and cache + cache = cc._get_cache(backend) + self.assertIsNone(cache) # Cache should not exist + + with config.enable_compilation_cache(True): + cc.reset_cache() + backend = xla_bridge.get_backend() + g = jit(lambda x: x * 3) + g(2) + cache = cc._get_cache(backend) + self.assertIsNotNone(cache) # Cache should be initalized + def test_tasks_disable_cache_metric(self): with config.enable_compilation_cache(False): count_before_first_use = _counts[