From 7b20f5302b486afac25954bf9f39bcb0a13697e4 Mon Sep 17 00:00:00 2001 From: George Necula Date: Wed, 4 Dec 2024 08:53:39 -0800 Subject: [PATCH] [shape_poly] Remove the deprecated PolyShape object for specifying symbolic dimensions PiperOrigin-RevId: 702742514 --- CHANGELOG.md | 6 ++++ jax/_src/export/shape_poly.py | 51 ----------------------------- jax/experimental/jax2tf/__init__.py | 1 - jax/experimental/jax2tf/jax2tf.py | 20 +++++------ 4 files changed, 14 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce8b040439c0..2e1281cf49af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. use `uses_global_constants`. * the `lowering_platforms` kwarg for {func}`jax.export.export`: use `platforms` instead. + * The kwargs `symbolic_scope` and `symbolic_constraints` from + {func}`jax.export.symbolic_args_specs` have been removed. They were + deprecated in June 2024. Use `scope` and `constraints` instead. + * The `jax2tf.PolyShape` has been removed. Use instead string specifications + for the symbolic dimensions. E.g., instead of `PolyShape("d1", "d2")` you + can use `"d1, d2"`. * Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`. * Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 15f99533d59e..b146795a8563 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -29,7 +29,6 @@ import operator as op import tokenize from typing import Any, Union, overload -import warnings import numpy as np import opt_einsum @@ -1198,12 +1197,6 @@ def is_symbolic_dim(p: DimSize) -> bool: """ return isinstance(p, _DimExpr) -def is_poly_dim(p: DimSize) -> bool: - # TODO: deprecated January 2024, remove June 2024. - warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim", - DeprecationWarning, stacklevel=2) - return is_symbolic_dim(p) - dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int] def _einsum_contract_path(*operands, **kwargs): @@ -1331,31 +1324,6 @@ def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering) -class PolyShape(tuple): - """Tuple of polymorphic dimension specifications. - - See docstring of :func:`jax2tf.convert`. - """ - - def __init__(self, *dim_specs): - warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes", - DeprecationWarning, stacklevel=2) - tuple.__init__(dim_specs) - - def __new__(cls, *dim_specs): - warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes", - DeprecationWarning, stacklevel=2) - for ds in dim_specs: - if not isinstance(ds, (int, str)) and ds != ...: - msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string " - "representing a dimension variable, or an integer, or ...") - raise ValueError(msg) - return tuple.__new__(PolyShape, dim_specs) - - def __str__(self): - return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")" - - def symbolic_shape(shape_spec: str | None, *, constraints: Sequence[str] = (), @@ -1396,8 +1364,6 @@ def symbolic_shape(shape_spec: str | None, shape_spec_repr = repr(shape_spec) if shape_spec is None: shape_spec = "..." - elif isinstance(shape_spec, PolyShape): # TODO: deprecate - shape_spec = str(shape_spec) elif not isinstance(shape_spec, str): raise ValueError("polymorphic shape spec should be None or a string. " f"Found {shape_spec_repr}.") @@ -1413,8 +1379,6 @@ def symbolic_args_specs( shapes_specs, # prefix pytree of strings constraints: Sequence[str] = (), scope: SymbolicScope | None = None, - symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24 - symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24 ): """Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`. @@ -1435,25 +1399,10 @@ def symbolic_args_specs( arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). constraints: as for :func:`jax.export.symbolic_shape`. scope: as for :func:`jax.export.symbolic_shape`. - symbolic_constraints: DEPRECATED, use `constraints`. - symbolic_scope: DEPRECATED, use `scope`. Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes replaced with symbolic dimensions as specified by `shapes_specs`. """ - if symbolic_constraints: - warnings.warn("symbolic_constraints is deprecated, use constraints", - DeprecationWarning, stacklevel=2) - if constraints: - raise ValueError("Cannot use both symbolic_constraints and constraints") - constraints = symbolic_constraints - if symbolic_scope is not None: - warnings.warn("symbolic_scope is deprecated, use scope", - DeprecationWarning, stacklevel=2) - if scope is not None: - raise ValueError("Cannot use both symbolic_scope and scope") - scope = symbolic_scope - polymorphic_shapes = shapes_specs args_flat, args_tree = tree_util.tree_flatten(args) diff --git a/jax/experimental/jax2tf/__init__.py b/jax/experimental/jax2tf/__init__.py index acd2adb2d562..2b58a04c3d23 100644 --- a/jax/experimental/jax2tf/__init__.py +++ b/jax/experimental/jax2tf/__init__.py @@ -18,6 +18,5 @@ dtype_of_val as dtype_of_val, split_to_logical_devices as split_to_logical_devices, DisabledSafetyCheck as DisabledSafetyCheck, - PolyShape as PolyShape # TODO: deprecate ) from jax.experimental.jax2tf.call_tf import call_tf as call_tf diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 188ffeb6d670..02c754602a0f 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -87,7 +87,6 @@ # pylint: enable=g-direct-tensorflow-import NameStack = source_info_util.NameStack -PolyShape = shape_poly.PolyShape # TODO: deprecate DType = Any DisabledSafetyCheck = export.DisabledSafetyCheck @@ -275,18 +274,15 @@ def convert(fun_jax: Callable, See [how optional parameters are matched to arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees). - A shape specification for an array argument should be an object - `PolyShape(dim0, dim1, ..., dimn)` - where each `dim` is a dimension specification: a positive integer denoting - a monomorphic dimension of the given size, or a string denoting a - dimension variable assumed to range over non-zero dimension sizes, or - the special placeholder string "_" denoting a monomorphic dimension - whose size is given by the actual argument. As a shortcut, an Ellipsis + A shape specification for an array argument should be a string containing + a comma-separated list of dimension specifications each being either a + positive integer denoting a known dimension of the given size, + or a string denoting a dimension variable assumed to range over + non-zero dimension sizes, or the special placeholder string "_" denoting a + constant dimension whose size is given by the actual argument. + As a shortcut, an Ellipsis suffix in the list of dimension specifications stands for a list of "_" - placeholders. - - For convenience, a shape specification can also be given as a string - representation, e.g.: "batch, ...", "batch, height, width, _", possibly + placeholders. E.g.: "batch, ...", "batch, height, width, _", possibly with surrounding parentheses: "(batch, ...)". The lowering fails if it cannot ensure that the it would produce the same