Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[better_errors] Ensure that debug_info.arg_names is never None. #25990

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,21 +590,6 @@ def _dtype(x):
def api_hook(fun, tag: str):
return fun

# TODO(necula): replace usage with tracing_debug_info
def debug_info(
traced_for: str, fun_src_info: str | None,
fun_signature: inspect.Signature | None,
args: tuple[Any, ...], kwargs: dict[str, Any],
static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]
) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)


def tracing_debug_info(
traced_for: str,
Expand All @@ -618,15 +603,16 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
) -> TracingDebugInfo | None:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
# TODO(necula): remove type: ignore once we fix arg_names to never be None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) # type: ignore
if arg_names is None:
return None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


def fun_signature(fun: Callable) -> inspect.Signature | None:
Expand Down Expand Up @@ -660,19 +646,34 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None: return None
) -> tuple[str | None, ...]:
"""Returns the names of the non-static arguments.

If the `fn_signature` is given then we gets from it the names of the
top-level arguments, else we use names like `args[0[]`, `args[1]`, etc.
"""
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
try:
ba = fn_signature.bind(*args_, **kwargs)
except (ValueError, TypeError):
return None
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
if fn_signature is not None:
try:
ba = fn_signature.bind(*args_, **kwargs_)
except (ValueError, TypeError):
pass
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names
return arg_names


@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2116,7 +2116,7 @@ def tracing_debug_info(
out_tree_thunk: Callable[[], PyTreeDef],
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo:
) -> lu.TracingDebugInfo | None:
# TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead
# We just have to make sure we grad the debugging information when we have
# the unflattened args
Expand Down
18 changes: 7 additions & 11 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums,
argnames_partial_except, debug_info, tracing_debug_info, result_paths, add_jaxpr_debug_info,
argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
Expand Down Expand Up @@ -565,17 +565,13 @@ def _infer_params_impl(
"device is also specified as an argument to jit.")

axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
dbg = tracing_debug_info('jit', fun, args, kwargs,
static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)

dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
# TODO(necula): replace the above with below.
# haiku/_src/integration:hk_transforms_test fails
# dbg = tracing_debug_info('jit', fun, args, kwargs,
# static_argnums=ji.static_argnums,
# static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
# sourceinfo = ji.fun_sourceinfo,
# signature = ji.fun_signature)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
Expand Down
Loading