diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index d1ea2396db88..0280c11f3917 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -590,21 +590,6 @@ def _dtype(x): def api_hook(fun, tag: str): return fun -# TODO(necula): replace usage with tracing_debug_info -def debug_info( - traced_for: str, fun_src_info: str | None, - fun_signature: inspect.Signature | None, - args: tuple[Any, ...], kwargs: dict[str, Any], - static_argnums: tuple[int, ...], - static_argnames: tuple[str, ...] -) -> TracingDebugInfo | None: - """Try to build trace-time debug info for fun when applied to args/kwargs.""" - arg_names = _non_static_arg_names(fun_signature, args, kwargs, static_argnums, - static_argnames) - if arg_names is None: - return None - return TracingDebugInfo(traced_for, fun_src_info, arg_names, None) - def tracing_debug_info( traced_for: str, @@ -618,15 +603,16 @@ def tracing_debug_info( # TODO(necula): check if we really need this, e.g., to speed up tracing. sourceinfo: str | None = None, signature: inspect.Signature | None = None, -) -> TracingDebugInfo: +) -> TracingDebugInfo | None: if sourceinfo is None: sourceinfo = fun_sourceinfo(fun) if signature is None: signature = fun_signature(fun) arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, static_argnames) - # TODO(necula): remove type: ignore once we fix arg_names to never be None - return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) # type: ignore + if arg_names is None: + return None + return TracingDebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) def fun_signature(fun: Callable) -> inspect.Signature | None: @@ -660,19 +646,34 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None, args: Sequence[Any], kwargs: dict[str, Any], static_argnums: Sequence[int], static_argnames: Sequence[str], - ) -> tuple[str | None, ...] | None: - if fn_signature is None: return None + ) -> tuple[str | None, ...]: + """Returns the names of the non-static arguments. + + If the `fn_signature` is given then we gets from it the names of the + top-level arguments, else we use names like `args[0[]`, `args[1]`, etc. + """ static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] - kwargs = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} - try: - ba = fn_signature.bind(*args_, **kwargs) - except (ValueError, TypeError): - return None - return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() - for path, l in generate_key_paths(x) if l is not static) + kwargs_ = {k:static if k in static_argnames_ else x for k, x in kwargs.items()} + if fn_signature is not None: + try: + ba = fn_signature.bind(*args_, **kwargs_) + except (ValueError, TypeError): + pass + else: + return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() + for path, l in generate_key_paths(x) if l is not static) + args_arg_names = tuple(f'args{keystr(path)}' + for path, l in generate_key_paths(args_) + if l is not static) + kwargs_arg_names = tuple(f'kwargs{keystr(path)}' + for path, l in generate_key_paths(kwargs_) + if l is not static) + arg_names = args_arg_names + kwargs_arg_names + return arg_names + @lu.transformation_with_aux2 def result_paths(_fun, _store, *args, **kwargs): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index c0d20d1b21f3..fae1824eb0ba 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -2116,7 +2116,7 @@ def tracing_debug_info( out_tree_thunk: Callable[[], PyTreeDef], has_kwargs: bool, traced_for: str -) -> lu.TracingDebugInfo: +) -> lu.TracingDebugInfo | None: # TODO(necula): we should not need this function, and can use api_util.tracing_debug_info instead # We just have to make sure we grad the debugging information when we have # the unflattened args diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 7c9449062de9..2a0b00211319 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -49,7 +49,7 @@ from jax._src.api_util import ( argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, donation_vector, check_callable, resolve_argnums, - argnames_partial_except, debug_info, tracing_debug_info, result_paths, add_jaxpr_debug_info, + argnames_partial_except, tracing_debug_info, result_paths, add_jaxpr_debug_info, hoist_obj_attrs, _check_no_aliased_ref_args, _check_no_aliased_closed_over_refs) from jax._src.interpreters import partial_eval as pe @@ -565,17 +565,13 @@ def _infer_params_impl( "device is also specified as an argument to jit.") axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs) + dbg = tracing_debug_info('jit', fun, args, kwargs, + static_argnums=ji.static_argnums, + static_argnames=ji.static_argnames, + # TODO(necula): do we really need this, e.g., for tracing speed + sourceinfo=ji.fun_sourceinfo, + signature=ji.fun_signature) - dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs, - ji.static_argnums, ji.static_argnames) - # TODO(necula): replace the above with below. - # haiku/_src/integration:hk_transforms_test fails - # dbg = tracing_debug_info('jit', fun, args, kwargs, - # static_argnums=ji.static_argnums, - # static_argnames=ji.static_argnames, - # TODO(necula): do we really need this, e.g., for tracing speed - # sourceinfo = ji.fun_sourceinfo, - # signature = ji.fun_signature) f = lu.wrap_init(fun) f, res_paths = result_paths(f) f, dyn_args = argnums_partial_except(f, ji.static_argnums, args, allow_invalid=True)