Skip to content

Commit

Permalink
[shape_poly] Remove the deprecated PolyShape object for specifying sy…
Browse files Browse the repository at this point in the history
…mbolic dimensions

PiperOrigin-RevId: 702742514
  • Loading branch information
gnecula authored and Google-ML-Automation committed Dec 4, 2024
1 parent 46eb77b commit 7b20f53
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 64 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
use `uses_global_constants`.
* the `lowering_platforms` kwarg for {func}`jax.export.export`: use
`platforms` instead.
* The kwargs `symbolic_scope` and `symbolic_constraints` from
{func}`jax.export.symbolic_args_specs` have been removed. They were
deprecated in June 2024. Use `scope` and `constraints` instead.
* The `jax2tf.PolyShape` has been removed. Use instead string specifications
for the symbolic dimensions. E.g., instead of `PolyShape("d1", "d2")` you
can use `"d1, d2"`.
* Hashing of tracers, which has been deprecated since version 0.4.30, now
results in a `TypeError`.
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
Expand Down
51 changes: 0 additions & 51 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import operator as op
import tokenize
from typing import Any, Union, overload
import warnings

import numpy as np
import opt_einsum
Expand Down Expand Up @@ -1198,12 +1197,6 @@ def is_symbolic_dim(p: DimSize) -> bool:
"""
return isinstance(p, _DimExpr)

def is_poly_dim(p: DimSize) -> bool:
# TODO: deprecated January 2024, remove June 2024.
warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim",
DeprecationWarning, stacklevel=2)
return is_symbolic_dim(p)

dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]

def _einsum_contract_path(*operands, **kwargs):
Expand Down Expand Up @@ -1331,31 +1324,6 @@ def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *,
mlir.register_lowering(dim_as_value_p, _dim_as_value_lowering)


class PolyShape(tuple):
"""Tuple of polymorphic dimension specifications.
See docstring of :func:`jax2tf.convert`.
"""

def __init__(self, *dim_specs):
warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
DeprecationWarning, stacklevel=2)
tuple.__init__(dim_specs)

def __new__(cls, *dim_specs):
warnings.warn("PolyShape is deprecated, use string specifications for symbolic shapes",
DeprecationWarning, stacklevel=2)
for ds in dim_specs:
if not isinstance(ds, (int, str)) and ds != ...:
msg = (f"Invalid polymorphic shape element: {ds!r}; must be a string "
"representing a dimension variable, or an integer, or ...")
raise ValueError(msg)
return tuple.__new__(PolyShape, dim_specs)

def __str__(self):
return "(" + ", ".join(["..." if d is ... else str(d) for d in self]) + ")"


def symbolic_shape(shape_spec: str | None,
*,
constraints: Sequence[str] = (),
Expand Down Expand Up @@ -1396,8 +1364,6 @@ def symbolic_shape(shape_spec: str | None,
shape_spec_repr = repr(shape_spec)
if shape_spec is None:
shape_spec = "..."
elif isinstance(shape_spec, PolyShape): # TODO: deprecate
shape_spec = str(shape_spec)
elif not isinstance(shape_spec, str):
raise ValueError("polymorphic shape spec should be None or a string. "
f"Found {shape_spec_repr}.")
Expand All @@ -1413,8 +1379,6 @@ def symbolic_args_specs(
shapes_specs, # prefix pytree of strings
constraints: Sequence[str] = (),
scope: SymbolicScope | None = None,
symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24
symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24
):
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
Expand All @@ -1435,25 +1399,10 @@ def symbolic_args_specs(
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
constraints: as for :func:`jax.export.symbolic_shape`.
scope: as for :func:`jax.export.symbolic_shape`.
symbolic_constraints: DEPRECATED, use `constraints`.
symbolic_scope: DEPRECATED, use `scope`.
Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes
replaced with symbolic dimensions as specified by `shapes_specs`.
"""
if symbolic_constraints:
warnings.warn("symbolic_constraints is deprecated, use constraints",
DeprecationWarning, stacklevel=2)
if constraints:
raise ValueError("Cannot use both symbolic_constraints and constraints")
constraints = symbolic_constraints
if symbolic_scope is not None:
warnings.warn("symbolic_scope is deprecated, use scope",
DeprecationWarning, stacklevel=2)
if scope is not None:
raise ValueError("Cannot use both symbolic_scope and scope")
scope = symbolic_scope

polymorphic_shapes = shapes_specs
args_flat, args_tree = tree_util.tree_flatten(args)

Expand Down
1 change: 0 additions & 1 deletion jax/experimental/jax2tf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,5 @@
dtype_of_val as dtype_of_val,
split_to_logical_devices as split_to_logical_devices,
DisabledSafetyCheck as DisabledSafetyCheck,
PolyShape as PolyShape # TODO: deprecate
)
from jax.experimental.jax2tf.call_tf import call_tf as call_tf
20 changes: 8 additions & 12 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
# pylint: enable=g-direct-tensorflow-import

NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape # TODO: deprecate
DType = Any

DisabledSafetyCheck = export.DisabledSafetyCheck
Expand Down Expand Up @@ -275,18 +274,15 @@ def convert(fun_jax: Callable,
See [how optional parameters are matched to
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
A shape specification for an array argument should be an object
`PolyShape(dim0, dim1, ..., dimn)`
where each `dim` is a dimension specification: a positive integer denoting
a monomorphic dimension of the given size, or a string denoting a
dimension variable assumed to range over non-zero dimension sizes, or
the special placeholder string "_" denoting a monomorphic dimension
whose size is given by the actual argument. As a shortcut, an Ellipsis
A shape specification for an array argument should be a string containing
a comma-separated list of dimension specifications each being either a
positive integer denoting a known dimension of the given size,
or a string denoting a dimension variable assumed to range over
non-zero dimension sizes, or the special placeholder string "_" denoting a
constant dimension whose size is given by the actual argument.
As a shortcut, an Ellipsis
suffix in the list of dimension specifications stands for a list of "_"
placeholders.
For convenience, a shape specification can also be given as a string
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
placeholders. E.g.: "batch, ...", "batch, height, width, _", possibly
with surrounding parentheses: "(batch, ...)".
The lowering fails if it cannot ensure that the it would produce the same
Expand Down

0 comments on commit 7b20f53

Please sign in to comment.