Skip to content

Commit

Permalink
Make jax.numpy.where()'s condition, x, y arguments positional-only to…
Browse files Browse the repository at this point in the history
… match numpy.where.

PiperOrigin-RevId: 584377134
  • Loading branch information
hawkinsp authored and jax authors committed Nov 21, 2023
1 parent 0388792 commit 84c1e82
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 20 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ Remember to align the itemized text with the first line of an item within a list
* Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or
within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`.
It currently is converted to NaN, and in the future will raise a {obj}`TypeError`.

* Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` by
keyword arguments has been deprecated, to match `numpy.where`.


## jaxlib 0.4.21

Expand Down
57 changes: 42 additions & 15 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,25 +1062,27 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
return jitted_interp(x, xp, fp, left, right, period)


@overload
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
size: int | None = None,
@overload # type: ignore[no-overload-impl]
def where(condition: ArrayLike, x: Literal[None] = None,
y: Literal[None] = None, /, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> tuple[Array, ...]: ...

@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, / ,*,
size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array: ...

@overload
def where(condition: ArrayLike, x: ArrayLike | None = None,
y: ArrayLike | None = None, *, size: int | None = None,
y: ArrayLike | None = None, /, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array | tuple[Array, ...]: ...

@util._wraps(np.where,
_DEPRECATED_WHERE_ARG = object()

@util._wraps(np.where, # type: ignore[no-redef]
lax_description=_dedent("""
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
Expand All @@ -1104,18 +1106,43 @@ def where(condition: ArrayLike, x: ArrayLike | None = None,
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(condition: ArrayLike, x: ArrayLike | None = None,
y: ArrayLike | None = None, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> Array | tuple[Array, ...]:
if x is None and y is None:
util.check_arraylike("where", condition)
return nonzero(condition, size=size, fill_value=fill_value)
def where(
acondition = None, if_true = None, if_false = None, /, *,
size=None, fill_value=None,
# Deprecated keyword-only names.
condition = _DEPRECATED_WHERE_ARG, x = _DEPRECATED_WHERE_ARG,
y = _DEPRECATED_WHERE_ARG
) -> Array | tuple[Array, ...]:
if (condition is not _DEPRECATED_WHERE_ARG or x is not _DEPRECATED_WHERE_ARG
or y is not _DEPRECATED_WHERE_ARG):
# TODO(phawkins): deprecated Nov 17 2023, remove after deprecation expires.
warnings.warn(
"Passing condition, x, or y to jax.numpy.where via keyword arguments "
"is deprecated.",
DeprecationWarning,
stacklevel=2,
)
if condition is not _DEPRECATED_WHERE_ARG:
if acondition is not None:
raise ValueError("condition should be a positional-only argument")
acondition = condition
if x is not _DEPRECATED_WHERE_ARG:
if if_true is not None:
raise ValueError("x should be a positional-only argument")
if_true = x
if y is not _DEPRECATED_WHERE_ARG:
if if_false is not None:
raise ValueError("y should be a positional-only argument")
if_false = y

if if_true is None and if_false is None:
util.check_arraylike("where", acondition)
return nonzero(acondition, size=size, fill_value=fill_value)
else:
util.check_arraylike("where", condition, x, y)
util.check_arraylike("where", acondition, if_true, if_false)
if size is not None or fill_value is not None:
raise ValueError("size and fill_value arguments cannot be used in three-term where function.")
return util._where(condition, x, y)
return util._where(acondition, if_true, if_false)


@util._wraps(np.select)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ def _solve(A, b):
x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

failed = jnp.isnan(_norm(x))
info = jnp.where(failed, x=-1, y=0)
info = jnp.where(failed, -1, 0)
return x, info


Expand Down
6 changes: 3 additions & 3 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -831,19 +831,19 @@ def vstack(tup: Union[_np.ndarray, Array, Sequence[ArrayLike]],

@overload
def where(condition: ArrayLike, x: Literal[None] = ..., y: Literal[None] = ...,
*, size: Optional[int] = ...,
/, *, size: Optional[int] = ...,
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
) -> tuple[Array, ...]: ...

@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /, *,
size: Optional[int] = ...,
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
) -> Array: ...

@overload
def where(condition: ArrayLike, x: Optional[ArrayLike] = ...,
y: Optional[ArrayLike] = ..., *, size: Optional[int] = ...,
y: Optional[ArrayLike] = ..., /, *, size: Optional[int] = ...,
fill_value: Union[None, ArrayLike, tuple[ArrayLike, ...]] = ...
) -> Union[Array, tuple[Array, ...]]: ...

Expand Down

0 comments on commit 84c1e82

Please sign in to comment.