Skip to content

Commit

Permalink
Merge pull request #25992 from gnecula:debug_info_arg_names
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718216003
  • Loading branch information
Google-ML-Automation committed Jan 22, 2025
2 parents 908df65 + 3f73f7b commit e304e9e
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 47 deletions.
10 changes: 6 additions & 4 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,10 @@ def f_(*args):
out_tree = lambda: tree_structure(out_shape)
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.tracing_debug_info(f, in_tree, out_tree, True, "saved_residuals")
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore
return _saved_residuals(jaxpr, dbg.arg_names)

def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
def _saved_residuals(jaxpr: core.Jaxpr,
arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

Expand All @@ -471,7 +472,7 @@ def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:

for i, v in enumerate(jaxpr.invars):
if v in res_vars:
if arg_names is not None:
if arg_names[i] is not None:
src = f'from the argument {arg_names[i]}'
else:
src = 'from the argument at flattened index {i}'
Expand Down Expand Up @@ -587,7 +588,8 @@ def remat_partial_eval(trace: pe.JaxprTrace, *tracers: core.Tracer,
_, staged_unk = partition_list(in_used_staged, in_unknowns)
res_invars, _ = partition_list(staged_unk, jaxpr_unknown.invars[num_res:])
res_outvars = jaxpr_known.outvars[len(jaxpr_known.outvars) - num_res:]
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars), None)
body_res = _saved_residuals(jaxpr_known.replace(outvars=res_outvars),
("",) * len(jaxpr_known.invars))
logger.log(log_level,
'remat-decorated function ' +
'saving inputs with shapes:\n' * bool(res_invars) +
Expand Down
63 changes: 43 additions & 20 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
import operator
from functools import partial, lru_cache
import re
from typing import Any

from jax._src import core
Expand Down Expand Up @@ -603,15 +604,13 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo | None:
) -> TracingDebugInfo:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


Expand All @@ -624,12 +623,13 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
def save_wrapped_fun_sourceinfo(wrapper: Callable, wrapped: Callable):
# Prefer this to functools.wraps because it does not create a reference to
# the wrapped function.
sourceinfo = fun_sourceinfo(wrapped)
if sourceinfo is not None:
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))
setattr(wrapper, "__fun_sourceinfo__", fun_sourceinfo(wrapped))

_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")

# TODO(mattjj): make this function internal to this module
def fun_sourceinfo(fun: Callable) -> str | None:
def fun_sourceinfo(fun: Callable) -> str:
# See TracingDebugInfo.fun_src_info
res = getattr(fun, "__fun_sourceinfo__", None)
if res is not None: return res
while isinstance(fun, partial):
Expand All @@ -639,28 +639,51 @@ def fun_sourceinfo(fun: Callable) -> str | None:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
return f"{fun.__name__} at {filename}:{lineno}"
except AttributeError:
return None
except AttributeError as e:
try:
fun_str = str(fun)
except:
return "<unknown>"
# By contract, the function name has no spaces; also, we want to avoid
# fun_sourceinfo of the form "<object Foo at 0x1234>", because it makes
# lowering non-deterministic.
if m := _fun_name_re.match(fun_str):
return m.group(1)
return "<unknown>"


# TODO(necula): this should never return None
def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None:
return None
) -> tuple[str | None, ...]:
"""Returns the names of the non-static arguments.
If the `fn_signature` is given then we get from it the names of the
top-level arguments. In other cases, including when the `args` and `kwargs`
do not match the signature, we use names like `args[0]`, `args[1]`, etc.
"""
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
try:
ba = fn_signature.bind(*args_, **kwargs)
except (ValueError, TypeError):
return None
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
if fn_signature is not None:
try:
ba = fn_signature.bind(*args_, **kwargs_)
except (ValueError, TypeError):
pass
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names
return arg_names

@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):
Expand Down
13 changes: 9 additions & 4 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,17 @@
EffectTypeSet = effects.EffectTypeSet
no_effects: Effects = effects.no_effects


# TODO(necula): make this an extension of TracingDebugInfo
class JaxprDebugInfo(NamedTuple):
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: tuple[str | None, ...] # e.g. ('args[0]', ... )
# An extension of lu.TracingDebugInfo; see comments there
traced_for: str
func_src_info: str
arg_names: tuple[str | None, ...]
# This is formed after tracing, when we have concrete `result_paths`
result_paths: tuple[str, ...] # e.g. ('[0]', '[1]', ...)


class Jaxpr:
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
'_effects', '_debug_info']
Expand Down Expand Up @@ -140,7 +145,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info
assert (not debug_info or debug_info.arg_names is None or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)

def __str__(self):
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,7 @@ def _origin_msg(self):
return ""

origin = ("The error occurred while tracing the function "
f"{dbg.func_src_info or '<unknown>'} for {dbg.traced_for}. ")
f"{dbg.func_src_info} for {dbg.traced_for}. ")
if invar_pos and dbg.arg_names:
try:
arg_names = [dbg.arg_names[i] for i in invar_pos]
Expand Down Expand Up @@ -2116,7 +2116,7 @@ def tracing_debug_info(
out_tree_thunk: Callable[[], PyTreeDef],
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo | None:
) -> lu.TracingDebugInfo:
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
# We just have to make sure we grab the debugging information when we have
# the unflattened args
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,10 @@ class TracingDebugInfo(NamedTuple):
Formed just before staging to a jaxpr and read in trace-time error messages.
"""
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
# no source location information. The first word is always the function name,
# which may be '<unknown>'.
func_src_info: str

# The paths of the flattened non-static argnames,
# e.g. ('x', 'dict_arg["a"]', ... ).
Expand Down
10 changes: 7 additions & 3 deletions jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class CompilerParams(Protocol):
__dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]


# TODO(necula): clean up the splitting of the fun_sourceinfo
@dataclasses.dataclass(frozen=True)
class NameAndSrcInfo:
#: The name of the pallas_call or the name of the kernel function.
Expand Down Expand Up @@ -108,9 +109,12 @@ def from_pallas_call(pallas_call_name: str | None,
if pallas_call_name is not None:
return NameAndSrcInfo(pallas_call_name,
f"for kernel function {src_info}")
src_info_parts = src_info.split(" ")
return NameAndSrcInfo(src_info_parts[0],
" ".join(src_info_parts[1:]))
src_info_parts = src_info.split(" at ")
if len(src_info_parts) > 1:
return NameAndSrcInfo(src_info_parts[0],
"at " + " ".join(src_info_parts[1:]))
else:
return NameAndSrcInfo(src_info_parts[0], "")


split_list = util.split_list
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/pallas_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,7 @@ def wrapped(*args):
"pallas_call kernel",
kernel,
[1] * len(kernel_fun_sig.parameters), {})
arg_names = kernel_debug_info and kernel_debug_info.arg_names
arg_names = kernel_debug_info.arg_names
del kernel_debug_info
in_origins = tuple(in_path_to_input_origin(p, arg_names)
for p in in_paths)
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class PjitInfo(NamedTuple):
In other words, this structure contains arguments to jit()/pjit(),
preprocessed and validated.
"""
fun_sourceinfo: str | None
fun_sourceinfo: str
fun_signature: inspect.Signature | None
# Shardings, as specified by the user. These can either be UNSPECIFIED or they
# can be a tree (prefix) of shardings or None.
Expand Down Expand Up @@ -537,7 +537,7 @@ class PjitParams(NamedTuple):
in_tree: PyTreeDef
out_tree: PyTreeDef
donated_invars: tuple[bool, ...]
arg_names: tuple[str | None, ...] | None
arg_names: tuple[str | None, ...]
num_consts: int
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
abstract_mesh: AbstractMesh
Expand Down Expand Up @@ -1189,7 +1189,8 @@ def unpack(key):
# have we seen this function before at all?
fun_name = getattr(f, '__qualname__', f)
if debug_info is not None and debug_info.func_src_info:
_, _, *rest = debug_info.func_src_info.split(' ')
# TODO(necula): clean up the extraction of the source info
_, *rest = debug_info.func_src_info.split(' at ')
src_info = " defined at " + ' '.join(rest)
else:
src_info = ''
Expand Down
93 changes: 93 additions & 0 deletions tests/debug_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from absl.testing import absltest, parameterized
import jax
from jax import lax
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
Expand All @@ -40,6 +41,98 @@

class DebugInfoTest(jtu.JaxTestCase):

def test_debug_info_basic(self):
def my_f(x, y, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4))
self.assertRegex(dbg.func_src_info, r"^my_f at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("x", "y", "z", "w"))
self.assertIsNone(dbg.result_paths_thunk)

def test_debug_info_arg_passed_as_kwarg(self):
def my_f(x, y, z):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3))
self.assertEqual(dbg.arg_names, ("x", "y", "z"))

def test_debug_info_pytrees(self):
def my_f(x_tree, *, y_tree):
pass

dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2),),
dict(y_tree=dict(z=3, w=4)))
self.assertEqual(dbg.arg_names, ("x_tree[0]", "x_tree[1]",
"y_tree['w']", "y_tree['z']"))

def test_debug_info_with_statics(self):
def my_f(x, y, *, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2), dict(z=3, w=4),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x", "z"))

def test_debug_info_with_pytrees_and_statics(self):
def my_f(x, y, *, z, w):
pass

dbg = api_util.tracing_debug_info("jit", my_f, ((1, 2), (2, 3)),
dict(z=(3, 4), w=(5, 6)),
static_argnums=(1,),
static_argnames=("w",))
self.assertEqual(dbg.arg_names, ("x[0]", "x[1]", "z[0]", "z[1]"))

def test_debug_info_too_many_args(self):
def my_f(x):
pass

dbg = api_util.tracing_debug_info("jit", my_f, (1, 2, 3), dict(z=3))
self.assertEqual(dbg.arg_names, ('args[0]', 'args[1]', 'args[2]', "kwargs['z']"))

def test_debug_info_no_source_info_built_in(self):
# built-in function "max" does not have an inspect.Signature
dbg = api_util.tracing_debug_info("jit", max, (1,), {})
self.assertEqual(dbg.func_src_info, "max")
self.assertEqual(dbg.arg_names, ("args[0]",))

def test_debug_info_lambda(self):
# a lambda has source info and an inspect.Signature, so its real parameter name is used
dbg = api_util.tracing_debug_info("jit", lambda my_arg: False, (1,), {})
self.assertRegex(dbg.func_src_info, r"^<lambda> at .*debug_info_test.py:\d+")
self.assertEqual(dbg.arg_names, ("my_arg",))

def test_debug_info_no_source_info_not_callable(self):
# a non-callable object (here `False`) yields no source info and falls back to generic arg names
dbg = api_util.tracing_debug_info("jit", False, (1,), {})
self.assertEqual(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("args[0]",))

def test_debug_info_no_source_info_callable(self):
class Foo:
x: int
def __call__(self, y):
return self.x + y

dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",))

def test_debug_info_no_source_info_callable_with_repr_errors(self):
class Foo:
x: int
def __call__(self, y):
return self.x + y

def __repr__(self):
raise NotImplementedError

dbg = api_util.tracing_debug_info("jit", Foo(), (1,), {})
self.assertRegex(dbg.func_src_info, "<unknown>")
self.assertEqual(dbg.arg_names, ("y",))

def helper_save_tracer(self, x):
self._saved_tracer = x
return x
Expand Down
Loading

0 comments on commit e304e9e

Please sign in to comment.