diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index d8724e42975e..c5910f852bbd 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -131,6 +131,11 @@ def _initialize_cache() -> None: with _cache_initialized_mutex: if _cache_initialized: return + + path: str | None = config.compilation_cache_dir.value + # If the path is not set, the cache will not be built. + if not path: + return _cache_initialized = True # Nothing to do if the cache is disabled. @@ -146,10 +151,7 @@ def _initialize_cache() -> None: global _cache assert _cache is None, "The cache has already been initialized!" - path: str | None = config.compilation_cache_dir.value - # If the path is not set, the cache will not be enabled. - if not path: - return + cache_and_path = get_file_cache(path) if cache_and_path is None: diff --git a/jax/_src/config.py b/jax/_src/config.py index 99d8f33158db..7100d0d3696d 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -705,6 +705,7 @@ def validator(new_val): update_thread_local_hook=update_thread_local_hook, validator=validator) + def string_or_object_state( name: str, default: Any, diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index ef245bc8d3ff..f2be97233b9d 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -198,6 +198,30 @@ def test_jit(self): f(1.0) self.assertEqual(count_cache_items(), 2) + + def test_change_compilation_cache_dir_reset_cache(self): + original_value = config.compilation_cache_dir.value + try: + cc.reset_cache() + self.assertFalse(cc._cache_initialized) + self.assertFalse(cc.is_persistent_cache_enabled()) + + a = jnp.zeros((2,3)) + self.assertFalse(cc._cache_initialized) + self.assertFalse(cc.is_persistent_cache_enabled()) + + jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") + @jax.jit + def f(x): + return x * 10 + f(a) + self.assertTrue(cc._cache_initialized) + self.assertTrue(cc.is_persistent_cache_enabled()) + + finally: + config.update("jax_compilation_cache_dir", original_value) + + def test_xla_autofdo_profile_version(self): original_profile_version = config.jax_xla_profile_version.value with config.jax_xla_profile_version(original_profile_version + 1): @@ -568,6 +592,9 @@ def test_persistent_cache_enable_xla_caches(self): self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, f"jax-cache{s}xla_gpu_per_fusion_autotune_cache_dir") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + + + @jtu.with_config( jax_enable_compilation_cache=False, jax_persistent_cache_min_compile_time_secs=0, @@ -614,5 +641,6 @@ def test_persistent_cache_enable_xla_caches_disabled(self): self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())