Skip to content

Commit

Permalink
[AutoPGLE] Explicitly ignore host callback pointers
Browse files Browse the repository at this point in the history
Before this change, users had to set the remove_custom_partitioning_ptr_from_cache_key config flag when using AutoPGLE.

PiperOrigin-RevId: 700289965
  • Loading branch information
Google-ML-Automation committed Nov 26, 2024
1 parent b6566c8 commit 231967f
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 55 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ pytype_strict_library(
name = "compiler",
srcs = ["_src/compiler.py"],
deps = [
":cache_key",
":compilation_cache_internal",
":config",
":mlir",
Expand Down
115 changes: 79 additions & 36 deletions jax/_src/cache_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import copy
import enum
import hashlib
import io
import logging
Expand Down Expand Up @@ -62,11 +63,23 @@ def custom_hook() -> str:
return ""


def get(module: ir.Module,
devices: np.ndarray,
compile_options: xla_client.CompileOptions,
backend: xla_client.Client,
compression_algorithm: str = "zstandard") -> str:
class IgnoreCallbacks(enum.IntEnum):
  """Selects which callback pointers are removed from precompiled IR."""

  # Keep every callback pointer in the precompiled IR.
  NO = enum.auto()
  # Strip all callback pointers from the precompiled IR.
  ALL = enum.auto()
  # Strip only the custom_partitioning callback pointer from the precompiled IR.
  CUSTOM_PARTITIONING = enum.auto()


def get(
module: ir.Module,
devices: np.ndarray,
compile_options: xla_client.CompileOptions,
backend: xla_client.Client,
compression_algorithm: str = "zstandard",
ignore_callbacks: IgnoreCallbacks = IgnoreCallbacks.NO,
) -> str:
"""Creates a hashed string to use as a key to the compilation cache.
Creates a cache key that is a hex-encoded string of a unique hash based on
Expand All @@ -79,28 +92,47 @@ def get(module: ir.Module,
backend: description of the platform (e.g., TPU version)
compression_algorithm: a string representing the compression algorithm used
for the executable before persisting in the cache
ignore_callbacks: which callback pointers, if any, to remove from the
computation before hashing (see IgnoreCallbacks).
Typical return value example:
'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
"""
entries = [
("computation",
lambda hash_obj: _hash_computation(hash_obj, module)),
("jax_lib version",
lambda hash_obj: hash_obj.update(
bytes(jaxlib_version_str.encode("utf-8")))),
("XLA flags",
lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
("compile_options",
lambda hash_obj: _hash_serialized_compile_options(
hash_obj, compile_options,
# In case of GPU multi-process tasks we need to strip device
# assignment to use cache key as invariant between processes.
strip_device_assignment=(backend.platform == "gpu"))),
("accelerator_config",
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
("compression",
lambda hash_obj: _hash_string(hash_obj, compression_algorithm)),
(
"computation",
lambda hash_obj: _hash_computation(
hash_obj, module, ignore_callbacks
),
),
(
"jax_lib version",
lambda hash_obj: hash_obj.update(
bytes(jaxlib_version_str.encode("utf-8"))
),
),
(
"XLA flags",
lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()),
),
(
"compile_options",
lambda hash_obj: _hash_serialized_compile_options(
hash_obj,
compile_options,
# In case of GPU multi-process tasks we need to strip device
# assignment to use cache key as invariant between processes.
strip_device_assignment=(backend.platform == "gpu"),
),
),
(
"accelerator_config",
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend),
),
(
"compression",
lambda hash_obj: _hash_string(hash_obj, compression_algorithm),
),
("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())),
]

Expand Down Expand Up @@ -131,45 +163,56 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
)


def _remove_custom_partitioning_ptr(m: ir.Module):
"""
Removes custom_partitioning callback pointer from precompiled IR.
def _remove_callbacks(m: ir.Module, ignore_callbacks: IgnoreCallbacks):
"""Removes callback pointers from precompiled IR.
Python function pointers are not deterministic across executions.
"""
def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult:
if (op.name == "stablehlo.custom_call" and
op.attributes["call_target_name"].value == "CustomSPMDPartitioning"):
if op.name == "stablehlo.custom_call" and (
(
ignore_callbacks == IgnoreCallbacks.ALL
and op.attributes["call_target_name"].value.endswith("callback")
)
or op.attributes["call_target_name"].value == "CustomSPMDPartitioning"
):
op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
return ir.WalkResult.ADVANCE

if ignore_callbacks == IgnoreCallbacks.NO:
return m

m.operation.walk(_update_bc_attribute)
return m


def _serialize_ir(m: ir.Module) -> bytes:
def _serialize_ir(m: ir.Module, ignore_callbacks: IgnoreCallbacks) -> bytes:
output = io.BytesIO()
if config.remove_custom_partitioning_ptr_from_cache_key.value:
m = _remove_custom_partitioning_ptr(type_cast(ir.Module,
m.operation.clone()))
if ignore_callbacks != IgnoreCallbacks.NO:
m = _remove_callbacks(
type_cast(ir.Module, m.operation.clone()), ignore_callbacks
)
m.operation.write_bytecode(file=output)
return output.getvalue()


def _canonicalize_ir(m_original: ir.Module) -> bytes:
def _canonicalize_ir(
m_original: ir.Module, ignore_callbacks: IgnoreCallbacks
) -> bytes:
with m_original.context:
m = type_cast(ir.Module, m_original.operation.clone())
passes = pm.PassManager.parse(
"builtin.module(strip-debuginfo)"
)
passes.run(m.operation)
return _serialize_ir(m)
return _serialize_ir(m, ignore_callbacks)


def _hash_computation(hash_obj, module):
def _hash_computation(hash_obj, module, ignore_callbacks: IgnoreCallbacks):
if config.compilation_cache_include_metadata_in_key.value:
canonical_ir = _serialize_ir(module)
canonical_ir = _serialize_ir(module, ignore_callbacks)
else:
canonical_ir = _canonicalize_ir(module)
canonical_ir = _canonicalize_ir(module, ignore_callbacks)
hash_obj.update(canonical_ir)


Expand Down
21 changes: 15 additions & 6 deletions jax/_src/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,21 @@ def put_executable_and_time(
cache.put(cache_key, executable_and_time)


def get_cache_key(module: ir.Module,
devices: np.ndarray,
compile_options,
backend) -> str:
return cache_key.get(module, devices, compile_options, backend,
"zstandard" if zstandard is not None else "zlib")
def get_cache_key(
    module: ir.Module,
    devices: np.ndarray,
    compile_options,
    backend,
    ignore_callbacks: cache_key.IgnoreCallbacks = cache_key.IgnoreCallbacks.NO,
) -> str:
  """Returns the compilation cache key for `module`.

  Thin wrapper around `cache_key.get` that supplies the compression algorithm
  name: "zstandard" when the `zstandard` module is available (not None),
  otherwise "zlib".

  Args:
    module: forwarded unchanged to `cache_key.get`.
    devices: forwarded unchanged to `cache_key.get`.
    compile_options: forwarded unchanged to `cache_key.get`.
    backend: forwarded unchanged to `cache_key.get`.
    ignore_callbacks: forwarded unchanged to `cache_key.get`; defaults to
      `IgnoreCallbacks.NO`.

  Returns:
    The string produced by `cache_key.get`.
  """
  compression_algorithm = "zlib" if zstandard is None else "zstandard"
  return cache_key.get(
      module,
      devices,
      compile_options,
      backend,
      compression_algorithm=compression_algorithm,
      ignore_callbacks=ignore_callbacks,
  )


def is_initialized() -> bool:
Expand Down
28 changes: 24 additions & 4 deletions jax/_src/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Any, Callable
import warnings

from jax._src import cache_key as cache_key_type
from jax._src import compilation_cache
from jax._src import config as config
from jax._src import distributed
Expand All @@ -33,8 +34,8 @@
from jax._src import profiler
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
import numpy as np

Expand Down Expand Up @@ -351,8 +352,18 @@ def compile_or_get_cached(
monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache')

try:
if config.remove_custom_partitioning_ptr_from_cache_key.value:
ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING
else:
ignore_callbacks = cache_key_type.IgnoreCallbacks.NO

cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend)
computation,
devices,
compile_options,
backend,
ignore_callbacks=ignore_callbacks,
)
except xc._xla.XlaRuntimeError as ex:
logger.error("compile_or_get_cached: unable to generate cache key, "
"skipping the cache: %s", ex)
Expand Down Expand Up @@ -385,7 +396,12 @@ def compile_or_get_cached(
compile_options.executable_build_options.fdo_profile = b"pgle profiled"

pgle_profiled_module_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend)
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
)
compile_options.executable_build_options.fdo_profile = fdo_profile

if _is_executable_in_cache(backend, pgle_profiled_module_key):
Expand Down Expand Up @@ -493,7 +509,11 @@ def _share_fdo_profiles(
compile_options.executable_build_options.fdo_profile = b""
profile_key = (
compilation_cache.get_cache_key(
computation, devices, compile_options, backend
computation,
devices,
compile_options,
backend,
cache_key_type.IgnoreCallbacks.ALL,
)
+ "_fdo_sync"
)
Expand Down
63 changes: 54 additions & 9 deletions tests/cache_key_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):

@custom_partitioning
def _cp_add(x, y):
return jax.numpy.add(x, y)
return jax.numpy.add(x, y)

_cp_add.def_partition(
infer_sharding_from_operands=_infer_sharding_from_operands,
Expand All @@ -199,14 +199,59 @@ def _cp_add(x, y):
r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
r'\}'
)
with config.remove_custom_partitioning_ptr_from_cache_key(True):
with computation.context:
updated_module = cache_key._remove_custom_partitioning_ptr(
type_cast(ir.Module, computation.operation.clone()))
bcs = [match[2] for
match in re.findall(pattern, str(updated_module), re.DOTALL)]
for bc in bcs:
self.assertEqual(bc, "REMOVED")
with computation.context:
updated_module = cache_key._remove_callbacks(
type_cast(ir.Module, computation.operation.clone()),
ignore_callbacks=cache_key.IgnoreCallbacks.ALL,
)
bcs = [
match[2]
for match in re.findall(pattern, str(updated_module), re.DOTALL)
]
for bc in bcs:
self.assertEqual(bc, "REMOVED")

compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
hash_without_callback_ptrs = cache_key.get(
computation,
devices,
compile_options,
backend,
ignore_callbacks=cache_key.IgnoreCallbacks.CUSTOM_PARTITIONING,
)
expected_hash = cache_key.get(
updated_module, devices, compile_options, backend
)
self.assertEqual(expected_hash, hash_without_callback_ptrs)

@jtu.skip_on_devices("cpu")
def test_host_callbacks_ptrs_removed(self):
  """Checks that IgnoreCallbacks.ALL rewrites string backend_configs.

  Lowers a jitted function that calls `jax.debug.print`, runs
  `cache_key._remove_callbacks` on a clone of the module, and asserts that
  every quoted `backend_config` value found in the printed module is
  "REMOVED".
  """
  def _host_callback(x, y):
    jax.debug.print("x={x[0]} y={y[0]}", x=x, y=y)

  x_spec = jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32)
  y_spec = jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32)
  computation = jax.jit(_host_callback).lower(x_spec, y_spec).compiler_ir()

  pattern = r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
  with computation.context:
    # Work on a clone so the lowered module itself is left untouched.
    cloned_module = type_cast(ir.Module, computation.operation.clone())
    updated_module = cache_key._remove_callbacks(
        cloned_module,
        ignore_callbacks=cache_key.IgnoreCallbacks.ALL,
    )
    # The second capture group is the quoted backend_config value.
    # NOTE(review): if the pattern matches nothing, the loop body never runs
    # and the test passes vacuously.
    matches = re.findall(pattern, str(updated_module), re.DOTALL)
    for _, backend_config in matches:
      self.assertEqual(backend_config, "REMOVED")

def test_different_device_assignment(self):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
Expand Down

0 comments on commit 231967f

Please sign in to comment.