From 7214a3a82a68bba806949c790124a2023cfc5c9e Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 5 Dec 2024 02:06:35 -0800 Subject: [PATCH] [AutoPGLE] Add multi-process test case PiperOrigin-RevId: 703031689 --- jax/_src/cache_key.py | 65 ++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 6e7a421482ce..2ec645cee407 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -238,6 +238,38 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend): _hash_devices(hash_obj, accelerators) _hash_platform(hash_obj, backend) +# LINT.IfChange(xla_flags) +xla_flags_to_exclude_from_cache_key = [ + "--xla_dump_compress_protos", + "--xla_dump_module_metadata", + "--xla_dump_max_hlo_modules", + "--xla_dump_include_timestamp", + "--xla_dump_hlo_pass_re", + "--xla_dump_hlo_module_re", + "--xla_dump_hlo_snapshots", + "--xla_dump_fusion_visualization", + "--xla_dump_hlo_as_url", + "--xla_dump_hlo_as_proto", + "--xla_dump_hlo_as_text", + "--xla_dump_hlo_as_long_text", + "--xla_dump_hlo_as_html", + "--xla_dump_hlo_as_dot", + "--xla_dump_to", + "--xla_force_host_platform_device_count", + "--xla_dump_disable_metadata", + "--xla_dump_hlo_pipeline_re", + "--xla_tpu_sdc_checker_streamz_metric", + "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", + "--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks", + "--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present", + "--xla_gpu_cuda_data_dir", + "--xla_gpu_experimental_autotune_cache_mode", +] + +env_override_flags_to_exclude_from_cache_key = { + x.strip("-") for x in xla_flags_to_exclude_from_cache_key +} +# LINT.ThenChange(:debug_options) def _hash_serialized_compile_options(hash_obj, compile_options_obj, strip_device_assignment=False): @@ -284,6 +316,11 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj, debug_options.xla_gpu_cuda_data_dir = "" # LINT.ThenChange(:xla_flags) + compile_options_copy.env_option_overrides = [ + flag_value + for flag_value in compile_options_copy.env_option_overrides + if flag_value[0] not in env_override_flags_to_exclude_from_cache_key + ] if strip_device_assignment and compile_options_copy.device_assignment: replica_count = compile_options_copy.device_assignment.replica_count() computation_count = compile_options_copy.device_assignment.computation_count() @@ -301,34 +338,6 @@ def _hash_platform(hash_obj, backend): def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]): - # LINT.IfChange(xla_flags) - xla_flags_to_exclude_from_cache_key = [ - "--xla_dump_compress_protos", - "--xla_dump_module_metadata", - "--xla_dump_max_hlo_modules", - "--xla_dump_include_timestamp", - "--xla_dump_hlo_pass_re", - "--xla_dump_hlo_module_re", - "--xla_dump_hlo_snapshots", - "--xla_dump_fusion_visualization", - "--xla_dump_hlo_as_url", - "--xla_dump_hlo_as_proto", - "--xla_dump_hlo_as_text", - "--xla_dump_hlo_as_long_text", - "--xla_dump_hlo_as_html", - "--xla_dump_hlo_as_dot", - "--xla_dump_to", - "--xla_force_host_platform_device_count", - "--xla_dump_disable_metadata", - "--xla_dump_hlo_pipeline_re", - "--xla_tpu_sdc_checker_streamz_metric", - "--xla_tpu_sdc_checker_enable_sdc_event_callbacks", - "--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks", - "--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present", - "--xla_gpu_cuda_data_dir", - ] - # LINT.ThenChange(:debug_options) - xla_flags = [] xla_flags_env_var = os.getenv("XLA_FLAGS")