Skip to content

Commit

Permalink
[shape_poly] Remove the deprecated PolyShape object for specifying sy…
Browse files Browse the repository at this point in the history
…mbolic dimensions

PolyShape has been deprecated in January 2024. The constructor
has been raising a DeprecationWarning since then.

PiperOrigin-RevId: 702742514
  • Loading branch information
gnecula authored and Google-ML-Automation committed Jan 7, 2025
1 parent 7997f08 commit d45090a
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 45 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes:
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
as shortcuts of the corresponding `tree_util` functions.

* The `jax2tf.PolyShape` has been removed (was deprecated since January 2024).
Use instead string specifications for the symbolic dimensions.
E.g., instead of `PolyShape("d1", "d2")` you can use `"d1, d2"`.
* Deprecations
* a number of APIs in the internal `jax.core` namespace have been deprecated.
Most were no-ops, were little-used, or can be replaced by APIs of the same
Expand All @@ -65,6 +67,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
* from {mod}`jax.numpy`: `round_`.


* New Features
* {func}`jax.export.export` can be used for device-polymorphic export with
shardings constructed with {func}`jax.sharding.AbstractMesh`.
Expand Down
28 changes: 0 additions & 28 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import operator as op
import tokenize
from typing import Any, Union, overload
import warnings

import numpy as np
import opt_einsum
Expand Down Expand Up @@ -1351,31 +1350,6 @@ def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *,
mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering)


class PolyShape(tuple):
"""Tuple of polymorphic dimension specifications.
See docstring of :func:`jax2tf.convert`.
"""

def __init__(self, *dim_specs):
warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
DeprecationWarning, stacklevel=2)
tuple.__init__(dim_specs)

def __new__(cls, *dim_specs):
warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
DeprecationWarning, stacklevel=2)
for ds in dim_specs:
if not isinstance(ds, (int, str)) and ds != ...:
msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string "
"representing a dimension variable, or an integer, or ...")
raise ValueError(msg)
return tuple.__new__(PolyShape, dim_specs)

def __str__(self):
return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")"


def symbolic_shape(shape_spec: str | None,
*,
constraints: Sequence[str] = (),
Expand Down Expand Up @@ -1416,8 +1390,6 @@ def symbolic_shape(shape_spec: str | None,
shape_spec_repr = repr(shape_spec)
if shape_spec is None:
shape_spec = "..."
elif isinstance(shape_spec, PolyShape): # TODO: deprecate
shape_spec = str(shape_spec)
elif not isinstance(shape_spec, str):
raise ValueError("polymorphic shape spec should be None or a string. "
f"Found {shape_spec_repr}.")
Expand Down
3 changes: 1 addition & 2 deletions jax/experimental/jax2tf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from jax.experimental.jax2tf.jax2tf import (
convert as convert,
eval_polymorphic_shape as eval_polymorphic_shape,
dtype_of_val as dtype_of_val,
split_to_logical_devices as split_to_logical_devices,
DisabledSafetyCheck as DisabledSafetyCheck,
PolyShape as PolyShape # TODO: deprecate
)
from jax.experimental.jax2tf.call_tf import call_tf as call_tf
24 changes: 10 additions & 14 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@
# pylint: enable=g-direct-tensorflow-import

NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape # TODO: deprecate
DType = Any

DisabledSafetyCheck = export.DisabledSafetyCheck
Expand Down Expand Up @@ -266,26 +265,23 @@ def convert(fun_jax: Callable,
It is meant to be sound, but it is known to reject some JAX programs
that are shape polymorphic. The details of this feature can change.
It should be `None` (all arguments are monomorphic), a single PolyShape
or string (applies to all arguments), or a tuple/list of the same length
It should be `None` (all arguments are monomorphic), a single
string (applies to all arguments), or a tuple/list of the same length
as the function arguments. For each argument the shape specification
should be `None` (monomorphic argument), or a Python object with the
same pytree structure as the argument.
See [how optional parameters are matched to
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
A shape specification for an array argument should be an object
`PolyShape(dim0, dim1, ..., dimn)`
where each `dim` is a dimension specification: a positive integer denoting
a monomorphic dimension of the given size, or a string denoting a
dimension variable assumed to range over non-zero dimension sizes, or
the special placeholder string "_" denoting a monomorphic dimension
whose size is given by the actual argument. As a shortcut, an Ellipsis
A shape specification for an array argument should be a string containing
a comma-separated list of dimension specifications each being either a
positive integer denoting a known dimension of the given size,
or a string denoting a dimension variable assumed to range over
non-zero dimension sizes, or the special placeholder string "_" denoting a
constant dimension whose size is given by the actual argument.
As a shortcut, an Ellipsis
suffix in the list of dimension specifications stands for a list of "_"
placeholders.
For convenience, a shape specification can also be given as a string
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
placeholders. E.g.: "batch, ...", "batch, height, width, _", possibly
with surrounding parentheses: "(batch, ...)".
The lowering fails if it cannot ensure that the it would produce the same
Expand Down

0 comments on commit d45090a

Please sign in to comment.