diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index a55432c5fe5c..ae755e5ef92d 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -454,9 +454,10 @@ def f_(*args): out_tree = lambda: tree_structure(out_shape) assert len(jaxpr.invars) == len(in_leaves) dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals") - return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore + return _saved_residuals(jaxpr, dbg.arg_names) -def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]: +def _saved_residuals(jaxpr: core.Jaxpr, + arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]: res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)] res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)} @@ -471,7 +472,7 @@ def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]: for i, v in enumerate(jaxpr.invars): if v in res_vars: - if arg_names is not None: + if arg_names[i] is not None: src = f'from the argument {arg_names[i]}' else: src = 'from the argument at flattened index {i}' @@ -587,7 +588,8 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, _, staged_unk = partition_list(in_used_staged, in_unknowns) res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:]) res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:] - body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None) + body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), + ("",) * len(jaxpr_known.invars)) logger.log(log_level, 'remat-decorated function ' + 'saving inputs with shapes:\n' * bool(res_invars) + diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index f35ae05850e9..5c3b5b253814 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -18,6 +18,7 @@ import inspect import operator from functools import partial, lru_cache +import re from typing import Any from jax._src import core @@ -603,15 +604,13 @@ def tracing_debug_info( # TODO(necula): check if we really need this, e.g., to speed up tracing. sourceinfo: str | None = None, signature: inspect.Signature | None = None, -) -> TracingDebugInfo | None: +) -> TracingDebugInfo: if sourceinfo is None: sourceinfo = fun_sourceinfo(fun) if signature is None: signature = fun_signature(fun) arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, static_argnames) - if arg_names is None: - return None return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) @@ -624,12 +623,13 @@ def fun_signature(fun: Callable) -> inspect.Signature | None: def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable): # Prefer this to functools.wraps because it does not create a reference to # the wrapped function. - sourceinfo = fun_sourceinfo(wrapped) - if sourceinfo is not None: - setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped)) + setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped)) + +_fun_name_re = re.compile(r"(?:)") # TODO(mattjj): make this function internal to this module -def fun_sourceinfo(fun: Callable) -> str | None: +def fun_sourceinfo(fun: Callable) -> str: + # See TracingDebugInfo.fun_src_info res = getattr(fun, "__fun_sourceinfo__", None) if res is not None: return res while isinstance(fun, partial): @@ -639,28 +639,51 @@ def fun_sourceinfo(fun: Callable) -> str | None: filename = fun.__code__.co_filename lineno = fun.__code__.co_firstlineno return f"{fun.__name__} at {filename}:{lineno}" - except AttributeError: - return None + except AttributeError as e: + try: + fun_str = str(fun) + except: + return "" + # By contract, the function name has no spaces; also, we want to avoid + # fun_sourceinfo of the form "", because it makes + # lowering non-deterministic. + if m := _fun_name_re.match(fun_str): + return m.group(1) + return "" + -# TODO(necula): this should never return None def _non_static_arg_names(fn_signature: inspect.Signature | None, args: Sequence[Any], kwargs: dict[str, Any], static_argnums: Sequence[int], static_argnames: Sequence[str], - ) -> tuple[str | None, ...] | None: - if fn_signature is None: - return None + ) -> tuple[str | None, ...]: + """Returns the names of the non-static arguments. + + If the `fn_signature` is given then we get from it the names of the + top-level arguments. In other cases, including when the `args` and `kwargs` + do not match the signature, we use names like `args[0[]`, `args[1]`, etc. + """ static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] - kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} - try: - ba = fn_signature.bind(*args_, **kwargs) - except (ValueError, TypeError): - return None - return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() - for path, l in generate_key_paths(x) if l is not static) + kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} + if fn_signature is not None: + try: + ba = fn_signature.bind(*args_, **kwargs_) + except (ValueError, TypeError): + pass + else: + return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() + for path, l in generate_key_paths(x) if l is not static) + args_arg_names = tuple(f'args{keystr(path)}' + for path, l in generate_key_paths(args_) + if l is not static) + kwargs_arg_names = tuple(f'kwargs{keystr(path)}' + for path, l in generate_key_paths(kwargs_) + if l is not static) + arg_names = args_arg_names + kwargs_arg_names + return arg_names @lu.transformation_with_aux2 def result_paths(_fun, _store, *args, **kwargs): diff --git a/jax/_src/core.py b/jax/_src/core.py index df061d5f8b8f..3c3ae5315dfc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -77,10 +77,20 @@ EffectTypeSet = effects.EffectTypeSet no_effects: Effects = effects.no_effects + +# TODO(necula): make this an extension of TracingDebugInfo class JaxprDebugInfo(NamedTuple): traced_for: str # e.g. 'jit', 'scan', etc - func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}' - arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... ) + # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have + # no source location information. The first word is always the function name, + # which may be ''. + func_src_info: str + + # The paths of the flattened non-static argnames, + # e.g. ('x', 'dict_arg["a"]', ... ). + # Uses `None` for the args that do not correspond to user-named arguments, + # e.g., tangent args in jax.jvp. + arg_names: tuple[str | None, ...] result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...) class Jaxpr: @@ -140,7 +150,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], self._eqns = list(eqns) self._effects = effects self._debug_info = debug_info - assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars) + assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) def __str__(self): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index cc3399f4e4e2..ea1df444cd9c 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1545,7 +1545,7 @@ def _origin_msg(self): return "" origin = ("The error occurred while tracing the function " - f"{dbg.func_src_info or ''} for {dbg.traced_for}. ") + f"{dbg.func_src_info} for {dbg.traced_for}. ") if invar_pos and dbg.arg_names: try: arg_names = [dbg.arg_names[i] for i in invar_pos] @@ -2116,7 +2116,7 @@ def tracing_debug_info( out_tree_thunk: Callable[[], PyTreeDef], has_kwargs: bool, traced_for: str -) -> lu.TracingDebugInfo | None: +) -> lu.TracingDebugInfo: # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead # We just have to make sure we grad the debugging information when we have # the unflattened args diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 919cd90c3521..ef2c534e0606 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -259,7 +259,10 @@ class TracingDebugInfo(NamedTuple): Formed just before staging to a jaxpr and read in trace-time error messages. """ traced_for: str # e.g. 'jit', 'scan', etc - func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}' + # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have + # no source location information. The first word is always the function name, + # which may be ''. + func_src_info: str # The paths of the flattened non-static argnames, # e.g. ('x', 'dict_arg["a"]', ... ). diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index df825f4e20a1..5930d214904f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -75,6 +75,7 @@ class CompilerParams(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +# TODO(necula): clean up the splitting of the fun_sourceinfo @dataclasses.dataclass(frozen=True) class NameAndSrcInfo: #: The name of the pallas_call or the name of the kernel function. @@ -108,9 +109,12 @@ def from_pallas_call(pallas_call_name: str | None, if pallas_call_name is not None: return NameAndSrcInfo(pallas_call_name, f"for kernel function {src_info}") - src_info_parts = src_info.split(" ") - return NameAndSrcInfo(src_info_parts[0], - " ".join(src_info_parts[1:])) + src_info_parts = src_info.split(" at ") + if len(src_info_parts) > 1: + return NameAndSrcInfo(src_info_parts[0], + "at " + " ".join(src_info_parts[1:])) + else: + return NameAndSrcInfo(src_info_parts[0], "") split_list = util.split_list diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 64cd93ba1136..87a63db3928a 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1814,7 +1814,7 @@ def wrapped(*args): "pallas_call kernel", kernel, [1] * len(kernel_fun_sig.parameters), {}) - arg_names = kernel_debug_info and kernel_debug_info.arg_names + arg_names = kernel_debug_info.arg_names del kernel_debug_info in_origins = tuple(in_path_to_input_origin(p, arg_names) for p in in_paths) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 636d8a68f142..a608e54ff392 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -143,7 +143,7 @@ class PjitInfo(NamedTuple): In other words, this structure contains arguments to jit()/pjit(), preprocessed and validated. """ - fun_sourceinfo: str | None + fun_sourceinfo: str fun_signature: inspect.Signature | None # Shardings, as specified by the user. These can either be UNSPECIFIED or they # can be a tree (prefix) of shardings or None. @@ -537,7 +537,7 @@ class PjitParams(NamedTuple): in_tree: PyTreeDef out_tree: PyTreeDef donated_invars: tuple[bool, ...] - arg_names: tuple[str | None, ...] | None + arg_names: tuple[str | None, ...] num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] abstract_mesh: AbstractMesh @@ -1189,7 +1189,8 @@ def unpack(key): # have we seen this function before at all? fun_name = getattr(f, '__qualname__', f) if debug_info is not None and debug_info.func_src_info: - _, _, *rest = debug_info.func_src_info.split(' ') + # TODO(necula): clean up the extraction of the source info + _, *rest = debug_info.func_src_info.split(' at ') src_info = " defined at " + ' '.join(rest) else: src_info = '' diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 3722e83e74df..4ebeddd34361 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -22,6 +22,7 @@ from absl.testing import absltest, parameterized import jax from jax import lax +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import test_util as jtu @@ -40,6 +41,98 @@ class DebugInfoTest(jtu.JaxTestCase): + def test_debug_info_basic(self): + def my_f(x, y, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) + self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") + self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) + self.assertIsNone(dbg.result_paths_thunk) + + def test_debug_info_arg_passed_as_kwarg(self): + def my_f(x, y, z): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3)) + self.assertEqual(dbg.arg_names, ("x", "y", "z")) + + def test_debug_info_pytrees(self): + def my_f(x_tree, *, y_tree): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),), + dict(y_tree=dict(z=3, w=4))) + self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]", + "y_tree['w']", "y_tree['z']")) + + def test_debug_info_with_statics(self): + def my_f(x, y, *, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x", "z")) + + def test_debug_info_with_pytrees_and_statics(self): + def my_f(x, y, *, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)), + dict(z=(3, 4), w=(5, 6)), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) + + def test_debug_info_too_many_args(self): + def my_f(x): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3)) + self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) + + def test_debug_info_no_source_info_built_in(self): + # built-in function "int" does not have an inspect.Signature + dbg = api_util.tracing_debug_info("jit", max, (1,), {}) + self.assertEqual(dbg.func_src_info, "max") + self.assertEqual(dbg.arg_names, ("args[0]",)) + + def test_debug_info_lambda(self): + # built-in function "int" does not have an inspect.Signature + dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {}) + self.assertRegex(dbg.func_src_info, r"^ at .*debug_info_test.py:\d+") + self.assertEqual(dbg.arg_names, ("my_arg",)) + + def test_debug_info_no_source_info_not_callable(self): + # built-in function "int" does not have an inspect.Signature + dbg = api_util.tracing_debug_info("jit", False, (1,), {}) + self.assertEqual(dbg.func_src_info, "") + self.assertEqual(dbg.arg_names, ("args[0]",)) + + def test_debug_info_no_source_info_callable(self): + class Foo: + x: int + def __call__(self, y): + return self.x + y + + dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) + self.assertRegex(dbg.func_src_info, "") + self.assertEqual(dbg.arg_names, ("y",)) + + def test_debug_info_no_source_info_callable_with_repr_errors(self): + class Foo: + x: int + def __call__(self, y): + return self.x + y + + def __repr__(self): + raise NotImplementedError + + dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) + self.assertRegex(dbg.func_src_info, "") + self.assertEqual(dbg.arg_names, ("y",)) + def helper_save_tracer(self, x): self._saved_tracer = x return x diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index b5d517ee43e7..0fc73e041a87 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -966,9 +966,8 @@ def my_index_map(): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - # TODO(necula): the function name should be "my_index_map" - "Index map function unknown .* " - "must return 1 values to match .*" + "Index map function my_index_map at .*pallas_test.py.* " + "for x_ref must return 1 values to match .*" "Currently returning 2 values."): f(a) @@ -982,9 +981,8 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - # TODO(necula): the function name should be "my_index_map" - "Index map function unknown .* " - "must return integer scalars. Output\\[0\\] has " + "Index map function my_index_map at .*pallas_test.py.* " + "for x_ref must return integer scalars. Output\\[0\\] has " "type .*float"): f(a) @@ -998,9 +996,8 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - # TODO(necula): the function name should be "my_index_map" - "Index map function unknown .* " - "must return integer scalars. Output\\[0\\] has " + "Index map function my_index_map at .*pallas_test.py.* " + "for x_ref must return integer scalars. Output\\[0\\] has " "type .*int32\\[4\\]"): f(a)