From 45a352041cd2aa8bb2826165d820402658b810dd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 17 Jan 2025 14:38:13 -0800 Subject: [PATCH] internal: check integer overflow in lax.asarray --- jax/_src/lax/lax.py | 6 ++++-- jax/_src/numpy/util.py | 2 -- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8a41eca8c01d..89364f7c882f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -133,8 +133,10 @@ def asarray(x: ArrayLike) -> Array: """Lightweight conversion of ArrayLike input to Array output.""" if isinstance(x, Array): return x - if isinstance(x, (np.ndarray, np.generic, bool, int, float, builtins.complex)): - return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x)) # type: ignore[unused-ignore,bad-return-type] + elif isinstance(x, (bool, np.ndarray, np.generic)): + return _convert_element_type(x, weak_type=False) # type: ignore[bad-return-type] + elif isinstance(x, (int, float, builtins.complex)): + return _convert_element_type(dtypes.coerce_to_array(x), weak_type=True) # type: ignore[bad-return-type] else: raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.") diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 7a5adfc40145..4cf1d7b6e57b 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -137,8 +137,6 @@ def _arraylike_asarray(x: Any) -> Array: """Convert an array-like object to an array.""" if hasattr(x, '__jax_array__'): x = x.__jax_array__() - elif isinstance(x, (bool, int, float, complex)): - x = dtypes.coerce_to_array(x) return lax.asarray(x)