From 3f73f7b0eb7d94dd79c8715f4d1aa608440da9b7 Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 20 Jan 2025 17:17:44 +0100 Subject: [PATCH] [better_errors] Ensure debug_info.arg_names is never None. Most places in the code assumed this already, but often that usage is error reporting code, which is not yet well tested. When we cannot get the `inspect.Signature` or when the args and kwargs do not match the signature, we generate the flattened argument names as: `args[0]`, `args[1]`, `kwargs['foo']`, ... Previously, in these cases we returned `arg_names` as None, and then the whole debug_info ended up being `None`, throwing away even the available information. We also add support for `api_util.fun_sourceinfo` even for cases when the `fun.__code__` is not available. In those cases we used to say that `fun_sourceinfo` is `None`. Now, we use the string representation of `fun` to get the name of built-in functions, or we use "". --- jax/_src/ad_checkpoint.py | 10 +-- jax/_src/api_util.py | 63 ++++++++++++------ jax/_src/core.py | 13 ++-- jax/_src/interpreters/partial_eval.py | 4 +- jax/_src/linear_util.py | 5 +- jax/_src/pallas/core.py | 10 ++- jax/_src/pallas/pallas_call.py | 2 +- jax/_src/pjit.py | 7 +- tests/debug_info_test.py | 93 +++++++++++++++++++++++++++ tests/pallas/pallas_test.py | 15 ++--- 10 files changed, 175 insertions(+), 47 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index a55432c5fe5c..ae755e5ef92d 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -454,9 +454,10 @@ def f_(*args): out_tree = lambda: tree_structure(out_shape) assert len(jaxpr.invars) == len(in_leaves) dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals") - return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore + return _saved_residuals(jaxpr, dbg.arg_names) -def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]: +def _saved_residuals(jaxpr: core.Jaxpr, + arg_names: tuple[str | 
None, ...]) -> list[tuple[core.AbstractValue, str]]: res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)] res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)} @@ -471,7 +472,7 @@ def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]: for i, v in enumerate(jaxpr.invars): if v in res_vars: - if arg_names is not None: + if arg_names[i] is not None: src = f'from the argument {arg_names[i]}' else: src = 'from the argument at flattened index {i}' @@ -587,7 +588,8 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer, _, staged_unk = partition_list(in_used_staged, in_unknowns) res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:]) res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:] - body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None) + body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), + ("",) * len(jaxpr_known.invars)) logger.log(log_level, 'remat-decorated function ' + 'saving inputs with shapes:\n' * bool(res_invars) + diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index f35ae05850e9..5c3b5b253814 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -18,6 +18,7 @@ import inspect import operator from functools import partial, lru_cache +import re from typing import Any from jax._src import core @@ -603,15 +604,13 @@ def tracing_debug_info( # TODO(necula): check if we really need this, e.g., to speed up tracing. 
sourceinfo: str | None = None, signature: inspect.Signature | None = None, -) -> TracingDebugInfo | None: +) -> TracingDebugInfo: if sourceinfo is None: sourceinfo = fun_sourceinfo(fun) if signature is None: signature = fun_signature(fun) arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, static_argnames) - if arg_names is None: - return None return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) @@ -624,12 +623,13 @@ def fun_signature(fun: Callable) -> inspect.Signature | None: def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable): # Prefer this to functools.wraps because it does not create a reference to # the wrapped function. - sourceinfo = fun_sourceinfo(wrapped) - if sourceinfo is not None: - setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped)) + setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped)) + +_fun_name_re = re.compile(r"(?:)") # TODO(mattjj): make this function internal to this module -def fun_sourceinfo(fun: Callable) -> str | None: +def fun_sourceinfo(fun: Callable) -> str: + # See TracingDebugInfo.fun_src_info res = getattr(fun, "__fun_sourceinfo__", None) if res is not None: return res while isinstance(fun, partial): @@ -639,28 +639,51 @@ def fun_sourceinfo(fun: Callable) -> str | None: filename = fun.__code__.co_filename lineno = fun.__code__.co_firstlineno return f"{fun.__name__} at {filename}:{lineno}" - except AttributeError: - return None + except AttributeError as e: + try: + fun_str = str(fun) + except: + return "" + # By contract, the function name has no spaces; also, we want to avoid + # fun_sourceinfo of the form "", because it makes + # lowering non-deterministic. 
+ if m := _fun_name_re.match(fun_str): + return m.group(1) + return "" + + -# TODO(necula): this should never return None def _non_static_arg_names(fn_signature: inspect.Signature | None, args: Sequence[Any], kwargs: dict[str, Any], static_argnums: Sequence[int], static_argnames: Sequence[str], - ) -> tuple[str | None, ...] | None: - if fn_signature is None: - return None + ) -> tuple[str | None, ...]: + """Returns the names of the non-static arguments. + + If the `fn_signature` is given then we get from it the names of the + top-level arguments. In other cases, including when the `args` and `kwargs` + do not match the signature, we use names like `args[0]`, `args[1]`, etc. + """ static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] - kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} - try: - ba = fn_signature.bind(*args_, **kwargs) - except (ValueError, TypeError): - return None - return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() - for path, l in generate_key_paths(x) if l is not static) + kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} + if fn_signature is not None: + try: + ba = fn_signature.bind(*args_, **kwargs_) + except (ValueError, TypeError): + pass + else: + return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() + for path, l in generate_key_paths(x) if l is not static) + args_arg_names = tuple(f'args{keystr(path)}' + for path, l in generate_key_paths(args_) + if l is not static) + kwargs_arg_names = tuple(f'kwargs{keystr(path)}' + for path, l in generate_key_paths(kwargs_) + if l is not static) + arg_names = args_arg_names + kwargs_arg_names + return arg_names @lu.transformation_with_aux2 def result_paths(_fun, _store, *args, **kwargs): diff --git a/jax/_src/core.py b/jax/_src/core.py index 
df061d5f8b8f..26038dbeaa22 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -77,12 +77,17 @@ EffectTypeSet = effects.EffectTypeSet no_effects: Effects = effects.no_effects + +# TODO(necula): make this an extension of TracingDebugInfo class JaxprDebugInfo(NamedTuple): - traced_for: str # e.g. 'jit', 'scan', etc - func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}' - arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... ) + # An extension of lu.TracingDebugInfo; see comments there + traced_for: str + func_src_info: str + arg_names: tuple[str | None, ...] + # This is formed after tracing, when we have concrete `result_paths` result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...) + class Jaxpr: __slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns', '_effects', '_debug_info'] @@ -140,7 +145,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var], self._eqns = list(eqns) self._effects = effects self._debug_info = debug_info - assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars) + assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) def __str__(self): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index cc3399f4e4e2..ea1df444cd9c 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1545,7 +1545,7 @@ def _origin_msg(self): return "" origin = ("The error occurred while tracing the function " - f"{dbg.func_src_info or ''} for {dbg.traced_for}. ") + f"{dbg.func_src_info} for {dbg.traced_for}. 
") if invar_pos and dbg.arg_names: try: arg_names = [dbg.arg_names[i] for i in invar_pos] @@ -2116,7 +2116,7 @@ def tracing_debug_info( out_tree_thunk: Callable[[], PyTreeDef], has_kwargs: bool, traced_for: str -) -> lu.TracingDebugInfo | None: +) -> lu.TracingDebugInfo: # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead # We just have to make sure we grad the debugging information when we have # the unflattened args diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 919cd90c3521..ef2c534e0606 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -259,7 +259,10 @@ class TracingDebugInfo(NamedTuple): Formed just before staging to a jaxpr and read in trace-time error messages. """ traced_for: str # e.g. 'jit', 'scan', etc - func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}' + # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have + # no source location information. The first word is always the function name, + # which may be ''. + func_src_info: str # The paths of the flattened non-static argnames, # e.g. ('x', 'dict_arg["a"]', ... ). diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index df825f4e20a1..5930d214904f 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -75,6 +75,7 @@ class CompilerParams(Protocol): __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] +# TODO(necula): clean up the splitting of the fun_sourceinfo @dataclasses.dataclass(frozen=True) class NameAndSrcInfo: #: The name of the pallas_call or the name of the kernel function. 
@@ -108,9 +109,12 @@ def from_pallas_call(pallas_call_name: str | None, if pallas_call_name is not None: return NameAndSrcInfo(pallas_call_name, f"for kernel function {src_info}") - src_info_parts = src_info.split(" ") - return NameAndSrcInfo(src_info_parts[0], - " ".join(src_info_parts[1:])) + src_info_parts = src_info.split(" at ") + if len(src_info_parts) > 1: + return NameAndSrcInfo(src_info_parts[0], + "at " + " ".join(src_info_parts[1:])) + else: + return NameAndSrcInfo(src_info_parts[0], "") split_list = util.split_list diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 64cd93ba1136..87a63db3928a 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1814,7 +1814,7 @@ def wrapped(*args): "pallas_call kernel", kernel, [1] * len(kernel_fun_sig.parameters), {}) - arg_names = kernel_debug_info and kernel_debug_info.arg_names + arg_names = kernel_debug_info.arg_names del kernel_debug_info in_origins = tuple(in_path_to_input_origin(p, arg_names) for p in in_paths) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 636d8a68f142..a608e54ff392 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -143,7 +143,7 @@ class PjitInfo(NamedTuple): In other words, this structure contains arguments to jit()/pjit(), preprocessed and validated. """ - fun_sourceinfo: str | None + fun_sourceinfo: str fun_signature: inspect.Signature | None # Shardings, as specified by the user. These can either be UNSPECIFIED or they # can be a tree (prefix) of shardings or None. @@ -537,7 +537,7 @@ class PjitParams(NamedTuple): in_tree: PyTreeDef out_tree: PyTreeDef donated_invars: tuple[bool, ...] - arg_names: tuple[str | None, ...] | None + arg_names: tuple[str | None, ...] num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] abstract_mesh: AbstractMesh @@ -1189,7 +1189,8 @@ def unpack(key): # have we seen this function before at all? 
fun_name = getattr(f, '__qualname__', f) if debug_info is not None and debug_info.func_src_info: - _, _, *rest = debug_info.func_src_info.split(' ') + # TODO(necula): clean up the extraction of the source info + _, *rest = debug_info.func_src_info.split(' at ') src_info = " defined at " + ' '.join(rest) else: src_info = '' diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index 3722e83e74df..4ebeddd34361 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -22,6 +22,7 @@ from absl.testing import absltest, parameterized import jax from jax import lax +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import test_util as jtu @@ -40,6 +41,98 @@ class DebugInfoTest(jtu.JaxTestCase): + def test_debug_info_basic(self): + def my_f(x, y, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4)) + self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+") + self.assertEqual(dbg.arg_names, ("x", "y", "z", "w")) + self.assertIsNone(dbg.result_paths_thunk) + + def test_debug_info_arg_passed_as_kwarg(self): + def my_f(x, y, z): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3)) + self.assertEqual(dbg.arg_names, ("x", "y", "z")) + + def test_debug_info_pytrees(self): + def my_f(x_tree, *, y_tree): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),), + dict(y_tree=dict(z=3, w=4))) + self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]", + "y_tree['w']", "y_tree['z']")) + + def test_debug_info_with_statics(self): + def my_f(x, y, *, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4), + static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x", "z")) + + def test_debug_info_with_pytrees_and_statics(self): + def my_f(x, y, *, z, w): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)), + dict(z=(3, 4), w=(5, 6)), + 
static_argnums=(1,), + static_argnames=("w",)) + self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]")) + + def test_debug_info_too_many_args(self): + def my_f(x): + pass + + dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3)) + self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']")) + + def test_debug_info_no_source_info_built_in(self): + # built-in function "max" does not have an inspect.Signature + dbg = api_util.tracing_debug_info("jit", max, (1,), {}) + self.assertEqual(dbg.func_src_info, "max") + self.assertEqual(dbg.arg_names, ("args[0]",)) + + def test_debug_info_lambda(self): + # a lambda has both a source location and an inspect.Signature + dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {}) + self.assertRegex(dbg.func_src_info, r"^ at .*debug_info_test.py:\d+") + self.assertEqual(dbg.arg_names, ("my_arg",)) + + def test_debug_info_no_source_info_not_callable(self): + # False is not callable and does not have an inspect.Signature + dbg = api_util.tracing_debug_info("jit", False, (1,), {}) + self.assertEqual(dbg.func_src_info, "") + self.assertEqual(dbg.arg_names, ("args[0]",)) + + def test_debug_info_no_source_info_callable(self): + class Foo: + x: int + def __call__(self, y): + return self.x + y + + dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) + self.assertRegex(dbg.func_src_info, "") + self.assertEqual(dbg.arg_names, ("y",)) + + def test_debug_info_no_source_info_callable_with_repr_errors(self): + class Foo: + x: int + def __call__(self, y): + return self.x + y + + def __repr__(self): + raise NotImplementedError + + dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {}) + self.assertRegex(dbg.func_src_info, "") + self.assertEqual(dbg.arg_names, ("y",)) + def helper_save_tracer(self, x): self._saved_tracer = x return x diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index b5d517ee43e7..0fc73e041a87 100644 --- a/tests/pallas/pallas_test.py +++ 
b/tests/pallas/pallas_test.py @@ -966,9 +966,8 @@ def my_index_map(): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - # TODO(necula): the function name should be "my_index_map" - "Index map function unknown .* " - "must return 1 values to match .*" + "Index map function my_index_map at .*pallas_test.py.* " + "for x_ref must return 1 values to match .*" "Currently returning 2 values."): f(a) @@ -982,9 +981,8 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - # TODO(necula): the function name should be "my_index_map" - "Index map function unknown .* " - "must return integer scalars. Output\\[0\\] has " + "Index map function my_index_map at .*pallas_test.py.* " + "for x_ref must return integer scalars. Output\\[0\\] has " "type .*float"): f(a) @@ -998,9 +996,8 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - # TODO(necula): the function name should be "my_index_map" - "Index map function unknown .* " - "must return integer scalars. Output\\[0\\] has " + "Index map function my_index_map at .*pallas_test.py.* " + "for x_ref must return integer scalars. Output\\[0\\] has " "type .*int32\\[4\\]"): f(a)