Skip to content

Commit

Permalink
Always use the same code for array avals
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 17, 2024
1 parent 05ad393 commit c9afc89
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
7 changes: 0 additions & 7 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,6 @@

array_types: set[type] = {np.ndarray} | numpy_scalar_types # pylint: disable=g-bare-generic

def canonical_concrete_aval(val, weak_type=None):
weak_type = dtypes.is_weakly_typed(val) if weak_type is None else weak_type
dtype = dtypes.canonicalize_dtype(np.result_type(val))
dtypes.check_valid_dtype(dtype)
sharding = core._get_abstract_sharding(val)
return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)


def masked_array_error(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
Expand Down
9 changes: 5 additions & 4 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import operator as op
from typing import Any, TYPE_CHECKING, cast

from jax._src import abstract_arrays
from jax._src import api
from jax._src import api_util
from jax._src import basearray
Expand Down Expand Up @@ -1027,18 +1026,20 @@ def make_array_from_single_device_arrays(
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
committed=True)


core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity

def _get_aval_array(self):
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
return self.aval.update(sharding=NamedSharding(
self.sharding.mesh.abstract_mesh,
self.sharding.spec._normalized_spec(self.ndim)))
else:
return self.aval

api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
core.xla_pytype_aval_mappings[ArrayImpl] = _get_aval_array

# TODO(jakevdp) replace this with true inheritance at the C++ level.
basearray.Array.register(ArrayImpl)

Expand Down

0 comments on commit c9afc89

Please sign in to comment.