Skip to content

Commit

Permalink
[better_errors] Ensure that debug_info.arg_names is never None.
Browse files Browse the repository at this point in the history
Most places in the code were already assuming this anyway.

PiperOrigin-RevId: 717538999
  • Loading branch information
gnecula authored and Google-ML-Automation committed Jan 20, 2025
1 parent a43edb4 commit 9542cd7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 38 deletions.
57 changes: 30 additions & 27 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Callable, Iterable, Sequence
import inspect
import logging
import operator
from functools import partial, lru_cache
from typing import Any
Expand Down Expand Up @@ -590,21 +591,6 @@ def _dtype(x):
def api_hook(fun, tag: str):
return fun

# TODO(necula): replace usage with tracing_debug_info
def debug_info(
traced_for: str, fun_src_info: str | None,
fun_signature: inspect.Signature | None,
args: tuple[Any, ...], kwargs: dict[str, Any],
static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]
) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)


def tracing_debug_info(
traced_for: str,
Expand All @@ -618,15 +604,16 @@ def tracing_debug_info(
# TODO(necula): check if we really need this, e.g., to speed up tracing.
sourceinfo: str | None = None,
signature: inspect.Signature | None = None,
) -> TracingDebugInfo:
) -> TracingDebugInfo | None:
if sourceinfo is None:
sourceinfo = fun_sourceinfo(fun)
if signature is None:
signature = fun_signature(fun)
arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums,
static_argnames)
# TODO(necula): remove type: ignore once we fix arg_names to never be None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) # type: ignore
if arg_names is None:
return None
return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk)


def fun_signature(fun: Callable) -> inspect.Signature | None:
Expand Down Expand Up @@ -660,19 +647,35 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...] | None:
if fn_signature is None: return None
) -> tuple[str | None, ...]:
"""Returns the names of the non-static arguments.
If the `fn_signature` is given then we gets from it the names of the
top-level arguments, else we use names like `args[0[]`, `args[1]`, etc.
"""
static = object()
static_argnums_ = _ensure_inbounds(True, len(args), static_argnums)
static_argnames_ = set(static_argnames)
args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)]
kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
try:
ba = fn_signature.bind(*args_, **kwargs)
except (ValueError, TypeError):
return None
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()}
if fn_signature is not None:
try:
ba = fn_signature.bind(*args_, **kwargs_)
except (ValueError, TypeError) as e:
logging.info("xxx Failed to bind signature: %s", e)
else:
return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items()
for path, l in generate_key_paths(x) if l is not static)
args_arg_names = tuple(f'args{keystr(path)}'
for path, l in generate_key_paths(args_)
if l is not static)
kwargs_arg_names = tuple(f'kwargs{keystr(path)}'
for path, l in generate_key_paths(kwargs_)
if l is not static)
arg_names = args_arg_names + kwargs_arg_names
logging.info("xxx arg_names: %s", arg_names)
return arg_names


@lu.transformation_with_aux2
def result_paths(_fun, _store, *args, **kwargs):
Expand Down
18 changes: 7 additions & 11 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, check_callable, resolve_argnums,
argnames_partial_except, debug_info, tracing_debug_info, result_paths, add_jaxpr_debug_info,
argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info,
hoist_obj_attrs, _check_no_aliased_ref_args,
_check_no_aliased_closed_over_refs)
from jax._src.interpreters import partial_eval as pe
Expand Down Expand Up @@ -565,17 +565,13 @@ def _infer_params_impl(
"device is also specified as an argument to jit.")

axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)
dbg = tracing_debug_info('jit', fun, args, kwargs,
static_argnums=ji.static_argnums,
static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
sourceinfo=ji.fun_sourceinfo,
signature=ji.fun_signature)

dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
# TODO(necula): replace the above with below.
# haiku/_src/integration:hk_transforms_test fails
# dbg = tracing_debug_info('jit', fun, args, kwargs,
# static_argnums=ji.static_argnums,
# static_argnames=ji.static_argnames,
# TODO(necula): do we really need this, e.g., for tracing speed
# sourceinfo = ji.fun_sourceinfo,
# signature = ji.fun_signature)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)
Expand Down

0 comments on commit 9542cd7

Please sign in to comment.