diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8e70b4f5de40..580154d012ac 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -133,8 +133,10 @@ def asarray(x: ArrayLike) -> Array: """Lightweight conversion of ArrayLike input to Array output.""" if isinstance(x, Array): return x - if isinstance(x, (np.ndarray, np.generic, bool, int, float, builtins.complex)): - return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x)) # type: ignore[unused-ignore,bad-return-type] + elif isinstance(x, (bool, np.ndarray, np.generic)): + return _convert_element_type(x, weak_type=False) # type: ignore[bad-return-type] + elif isinstance(x, (int, float, builtins.complex)): + return _convert_element_type(dtypes.coerce_to_array(x), weak_type=True) # type: ignore[bad-return-type] else: raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.") diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 7a5adfc40145..4cf1d7b6e57b 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -137,8 +137,6 @@ def _arraylike_asarray(x: Any) -> Array: """Convert an array-like object to an array.""" if hasattr(x, '__jax_array__'): x = x.__jax_array__() - elif isinstance(x, (bool, int, float, complex)): - x = dtypes.coerce_to_array(x) return lax.asarray(x)