Skip to content

Commit

Permalink
[AutoPGLE] Add multi-process test case
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703031689
  • Loading branch information
Google-ML-Automation committed Dec 5, 2024
1 parent 8163e74 commit 7214a3a
Showing 1 changed file with 37 additions and 28 deletions.
65 changes: 37 additions & 28 deletions jax/_src/cache_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,38 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend):
_hash_devices(hash_obj, accelerators)
_hash_platform(hash_obj, backend)

# LINT.IfChange(xla_flags)
xla_flags_to_exclude_from_cache_key = [
"--xla_dump_compress_protos",
"--xla_dump_module_metadata",
"--xla_dump_max_hlo_modules",
"--xla_dump_include_timestamp",
"--xla_dump_hlo_pass_re",
"--xla_dump_hlo_module_re",
"--xla_dump_hlo_snapshots",
"--xla_dump_fusion_visualization",
"--xla_dump_hlo_as_url",
"--xla_dump_hlo_as_proto",
"--xla_dump_hlo_as_text",
"--xla_dump_hlo_as_long_text",
"--xla_dump_hlo_as_html",
"--xla_dump_hlo_as_dot",
"--xla_dump_to",
"--xla_force_host_platform_device_count",
"--xla_dump_disable_metadata",
"--xla_dump_hlo_pipeline_re",
"--xla_tpu_sdc_checker_streamz_metric",
"--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
"--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks",
"--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present",
"--xla_gpu_cuda_data_dir",
"--xla_gpu_experimental_autotune_cache_mode",
]

env_override_flags_to_exclude_from_cache_key = {
x.strip("-") for x in xla_flags_to_exclude_from_cache_key
}
# LINT.ThenChange(:debug_options)

def _hash_serialized_compile_options(hash_obj, compile_options_obj,
strip_device_assignment=False):
Expand Down Expand Up @@ -284,6 +316,11 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
debug_options.xla_gpu_cuda_data_dir = ""
# LINT.ThenChange(:xla_flags)

compile_options_copy.env_option_overrides = [
flag_value
for flag_value in compile_options_copy.env_option_overrides
if flag_value[0] not in env_override_flags_to_exclude_from_cache_key
]
if strip_device_assignment and compile_options_copy.device_assignment:
replica_count = compile_options_copy.device_assignment.replica_count()
computation_count = compile_options_copy.device_assignment.computation_count()
Expand All @@ -301,34 +338,6 @@ def _hash_platform(hash_obj, backend):


def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]):
# LINT.IfChange(xla_flags)
xla_flags_to_exclude_from_cache_key = [
"--xla_dump_compress_protos",
"--xla_dump_module_metadata",
"--xla_dump_max_hlo_modules",
"--xla_dump_include_timestamp",
"--xla_dump_hlo_pass_re",
"--xla_dump_hlo_module_re",
"--xla_dump_hlo_snapshots",
"--xla_dump_fusion_visualization",
"--xla_dump_hlo_as_url",
"--xla_dump_hlo_as_proto",
"--xla_dump_hlo_as_text",
"--xla_dump_hlo_as_long_text",
"--xla_dump_hlo_as_html",
"--xla_dump_hlo_as_dot",
"--xla_dump_to",
"--xla_force_host_platform_device_count",
"--xla_dump_disable_metadata",
"--xla_dump_hlo_pipeline_re",
"--xla_tpu_sdc_checker_streamz_metric",
"--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
"--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks",
"--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present",
"--xla_gpu_cuda_data_dir",
]
# LINT.ThenChange(:debug_options)

xla_flags = []

xla_flags_env_var = os.getenv("XLA_FLAGS")
Expand Down

0 comments on commit 7214a3a

Please sign in to comment.