Skip to content

Commit

Permalink
In progress experimentation for supporting JAX Arrays with variable-w…
Browse files Browse the repository at this point in the history
…idth strings (i.e., with dtype = StringDType).

PiperOrigin-RevId: 703603535
  • Loading branch information
Google-ML-Automation committed Dec 14, 2024
1 parent d05ab5b commit be7a009
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 19 deletions.
47 changes: 34 additions & 13 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,45 @@
from __future__ import annotations

from collections.abc import Callable, Iterable, Sequence
from functools import lru_cache, partial
import inspect
import operator
from functools import partial, lru_cache
from typing import Any

import numpy as np

from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import traceback_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map,
treedef_children, generate_key_paths, keystr, broadcast_prefix,
prefix_errors)
from jax._src.tree_util import _replace_nones
from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo
from jax._src.util import (safe_map, WrapKwArgs, Hashable, HashableFunction,
Unhashable, safe_zip)
from jax._src import traceback_util
from jax._src.tree_util import _replace_nones
from jax._src.tree_util import (
PyTreeDef,
broadcast_prefix,
generate_key_paths,
keystr,
prefix_errors,
tree_flatten,
tree_map,
tree_unflatten,
treedef_children,
)
from jax._src.util import (
Hashable,
HashableFunction,
Unhashable,
WrapKwArgs,
safe_map,
safe_zip,
)
import numpy as np

try:
from numpy import dtypes as np_dtypes
except ImportError:
np_dtypes = None

traceback_util.register_exclusion(__file__)

map = safe_map
Expand Down Expand Up @@ -614,7 +632,10 @@ def _str_abstractify(x):

def _numpy_array_abstractify(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)

if np_dtypes is not None and not isinstance(dtype, np_dtypes.StringDType): # type: ignore
dtypes.check_valid_dtype(dtype)

return ShapedArray(x.shape,
dtypes.canonicalize_dtype(dtype, allow_extended_dtype=True))
_shaped_abstractify_handlers[np.ndarray] = _numpy_array_abstractify
Expand Down
16 changes: 10 additions & 6 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@
from functools import partial
from typing import Any, Union

import numpy as np

from jax._src import core
from jax._src import dtypes
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ShapedArray
from jax._src.util import safe_zip, safe_map

from jax._src.lib import xla_client as xc
from jax._src.typing import Shape
from jax._src.util import safe_map, safe_zip
import numpy as np

try:
import numpy.dtypes as np_dtypes
except ImportError:
np_dtypes = None

from jax._src.lib import xla_client as xc

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -170,7 +173,8 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
if np_dtypes is not None and not isinstance(dtype, np_dtypes.StringDType): # type: ignore
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))


Expand Down
19 changes: 19 additions & 0 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
PartitionSpec as P)
from jax.tree_util import tree_flatten, tree_leaves, tree_map
import numpy as np

try:
from numpy import dtypes as np_dtypes
except ImportError:
np_dtypes = None
import opt_einsum

export = set_module('jax.numpy')
Expand Down Expand Up @@ -374,6 +379,10 @@ def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) ->
# Note: this will only work for files created via np.save(), not np.savez().
out = np.load(file, *args, **kwargs)
if isinstance(out, np.ndarray):

if out.dtype == np.object_:
return out

# numpy does not recognize bfloat16, so arrays are serialized as void16
if out.dtype == 'V2':
out = out.view(bfloat16)
Expand Down Expand Up @@ -5575,6 +5584,16 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# Keep the output uncommitted.
return jax.device_put(object)

# 2DO: Comment.
if isinstance(object, np.ndarray) and (
np_dtypes is not None and isinstance(dtype, np_dtypes.StringDType)
):
if (ndmin > 0) and (ndmin != object.ndim):
raise TypeError(
f"ndmin {ndmin} does not match ndims {object.ndim} of input array"
)
return jax.device_put(x=object, device=device)

# For Python scalar literals, call coerce_to_array to catch any overflow
# errors. We don't use dtypes.is_python_scalar because we don't want this
# triggering for traced values. We do this here because it matters whether or
Expand Down

0 comments on commit be7a009

Please sign in to comment.