Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Ensure debug_info.arg_names is never None. #25992

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,10 @@ def f_(*args):
out_tree = lambda: tree_structure(out_shape)
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
return _saved_residuals(jaxpr, dbg.arg_names)

def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
def _saved_residuals(jaxpr: core.Jaxpr,
arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

Expand All @@ -471,7 +472,7 @@ def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:

for i, v in enumerate(jaxpr.invars):
if v in res_vars:
if arg_names is not None:
if arg_names[i] is not None:
src = f'from the argument {arg_names[i]}'
else:
src = 'from the argument at flattened index {i}'
Expand Down Expand Up @@ -587,7 +588,8 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
_, staged_unk = partition_list(in_used_staged, in_unknowns)
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars),
("",) * len(jaxpr_known.invars))
logger.log(log_level,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
Expand Down
63 changes: 43 additions & 20 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
import operator
from functools import partial, lru_cache
import re
from typing import Any

from jax._src import core
Expand Down Expand Up @@ -603,15 +604,13 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo | None:
) -> TracingDebugInfo:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


Expand All @@ -624,12 +623,13 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
# Prefer this to functools.wraps because it does not create a reference to
# the wrapped function.
sourceinfo = fun_sourceinfo(wrapped)
if sourceinfo is not None:
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))

_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")

# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str | None:
def fun_sourceinfo(fun: Callable) -> str:
# See TracingDebugInfo.fun_src_info
res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res
while isinstance(fun, partial):
Expand All @@ -639,28 +639,51 @@ def fun_sourceinfo(fun: Callable) -> str | None:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
return f"{fun.__name__} at {filename}:{lineno}"
except AttributeError:
return None
except AttributeError as e:
try:
fun_str = str(fun)
except:
return "<unknown>"
# By contract, the function name has no spaces; also, we want to avoid
# fun_sourceinfo of the form "<object Foo at 0x1234>", because it makes
# lowering non-deterministic.
if m := _fun_name_re.match(fun_str):
return m.group(1)
return "<unknown>"


# TODO(necula): this should never return None
def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None:
return None
) -> tuple[str | None, ...]:
"""Returns the names of the non-static arguments.

If the `fn_signature` is given then we get from it the names of the
top-level arguments. In other cases, including when the `args` and `kwargs`
do not match the signature, we use names like `args[0]`, `args[1]`, etc.
"""
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
try:
ba = fn_signature.bind(*args_, **kwargs)
except (ValueError, TypeError):
return None
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
if fn_signature is not None:
try:
ba = fn_signature.bind(*args_, **kwargs_)
except (ValueError, TypeError):
pass
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names
return arg_names

@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):
Expand Down
13 changes: 9 additions & 4 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,17 @@
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects


# TODO(necula): make this an extension of TracingDebugInfo
class JaxprDebugInfo(NamedTuple):
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
# An extension of lu.TracingDebugInfo; see comments there
traced_for: str
func_src_info: str
arg_names: tuple[str | None, ...]
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)


class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
'_effects', '_debug_info']
Expand Down Expand Up @@ -140,7 +145,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)

def __str__(self):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,7 @@ def _origin_msg(self):
return ""

origin = ("The error occurred while tracing the function "
f"{dbg.func_src_info or '<unknown>'} for {dbg.traced_for}. ")
f"{dbg.func_src_info} for {dbg.traced_for}. ")
if invar_pos and dbg.arg_names:
try:
arg_names = [dbg.arg_names[i] for i in invar_pos]
Expand Down Expand Up @@ -2116,7 +2116,7 @@ def tracing_debug_info(
out_tree_thunk: Callable[[], PyTreeDef],
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo | None:
) -> lu.TracingDebugInfo:
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
# We just have to make sure we grab the debugging information when we have
# the unflattened args
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ class TracingDebugInfo(NamedTuple):
Formed just before staging to a jaxpr and read in trace-time error messages.
"""
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
# no source location information. The first word is always the function name,
# which may be '<unknown>'.
func_src_info: str

# The paths of the flattened non-static argnames,
# e.g. ('x', 'dict_arg["a"]', ... ).
Expand Down
10 changes: 7 additions & 3 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class CompilerParams(Protocol):
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


# TODO(necula): clean up the splitting of the fun_sourceinfo
@dataclasses.dataclass(frozen=True)
class NameAndSrcInfo:
#: The name of the pallas_call or the name of the kernel function.
Expand Down Expand Up @@ -108,9 +109,12 @@ def from_pallas_call(pallas_call_name: str | None,
if pallas_call_name is not None:
return NameAndSrcInfo(pallas_call_name,
f"for kernel function {src_info}")
src_info_parts = src_info.split(" ")
return NameAndSrcInfo(src_info_parts[0],
" ".join(src_info_parts[1:]))
src_info_parts = src_info.split(" at ")
if len(src_info_parts) > 1:
return NameAndSrcInfo(src_info_parts[0],
"at " + " ".join(src_info_parts[1:]))
else:
return NameAndSrcInfo(src_info_parts[0], "")


split_list = util.split_list
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,7 @@ def wrapped(*args):
"pallas_call kernel",
kernel,
[1] * len(kernel_fun_sig.parameters), {})
arg_names = kernel_debug_info and kernel_debug_info.arg_names
arg_names = kernel_debug_info.arg_names
del kernel_debug_info
in_origins = tuple(in_path_to_input_origin(p, arg_names)
for p in in_paths)
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class PjitInfo(NamedTuple):
In other words, this structure contains arguments to jit()/pjit(),
preprocessed and validated.
"""
fun_sourceinfo: str | None
fun_sourceinfo: str
fun_signature: inspect.Signature | None
# Shardings, as specified by the user. These can either be UNSPECIFIED or they
# can be a tree (prefix) of shardings or None.
Expand Down Expand Up @@ -537,7 +537,7 @@ class PjitParams(NamedTuple):
in_tree: PyTreeDef
out_tree: PyTreeDef
donated_invars: tuple[bool, ...]
arg_names: tuple[str | None, ...] | None
arg_names: tuple[str | None, ...]
num_consts: int
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
abstract_mesh: AbstractMesh
Expand Down Expand Up @@ -1189,7 +1189,8 @@ def unpack(key):
# have we seen this function before at all?
fun_name = getattr(f, '__qualname__', f)
if debug_info is not None and debug_info.func_src_info:
_, _, *rest = debug_info.func_src_info.split(' ')
# TODO(necula): clean up the extraction of the source info
_, *rest = debug_info.func_src_info.split(' at ')
src_info = " defined at " + ' '.join(rest)
else:
src_info = ''
Expand Down
93 changes: 93 additions & 0 deletions tests/debug_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl.testing import absltest, parameterized
import jax
from jax import lax
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
Expand All @@ -40,6 +41,98 @@

class DebugInfoTest(jtu.JaxTestCase):

def test_debug_info_basic(self):
def my_f(x, y, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
self.assertIsNone(dbg.result_paths_thunk)

def test_debug_info_arg_passed_as_kwarg(self):
def my_f(x, y, z):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3))
self.assertEqual(dbg.arg_names, ("x", "y", "z"))

def test_debug_info_pytrees(self):
def my_f(x_tree, *, y_tree):
pass

dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),),
dict(y_tree=dict(z=3, w=4)))
self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
"y_tree['w']", "y_tree['z']"))

def test_debug_info_with_statics(self):
def my_f(x, y, *, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x", "z"))

def test_debug_info_with_pytrees_and_statics(self):
def my_f(x, y, *, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)),
dict(z=(3, 4), w=(5, 6)),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))

def test_debug_info_too_many_args(self):
def my_f(x):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3))
self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))

def test_debug_info_no_source_info_built_in(self):
# built-in function "max" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", max, (1,), {})
self.assertEqual(dbg.func_src_info, "max")
self.assertEqual(dbg.arg_names, ("args[0]",))

def test_debug_info_lambda(self):
# a lambda has source info and a signature, but its name is "<lambda>"
dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {})
self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("my_arg",))

def test_debug_info_no_source_info_not_callable(self):
# a non-callable object yields "<unknown>" source info and generic arg names
dbg = api_util.tracing_debug_info("jit", False, (1,), {})
self.assertEqual(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("args[0]",))

def test_debug_info_no_source_info_callable(self):
class Foo:
x: int
def __call__(self, y):
return self.x + y

dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",))

def test_debug_info_no_source_info_callable_with_repr_errors(self):
class Foo:
x: int
def __call__(self, y):
return self.x + y

def __repr__(self):
raise NotImplementedError

dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",))

def helper_save_tracer(self, x):
self._saved_tracer = x
return x
Expand Down
Loading
Loading