From 2c3f21d9debb678407f969a81197f8a84863513d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 20 Nov 2024 17:16:04 -0800 Subject: [PATCH] Reverts 525b646c0ebd5205f4fa0639c94adb2de47e1cf0 PiperOrigin-RevId: 698574302 --- CHANGELOG.md | 3 + docs/jax.lax.rst | 1 + jax/_src/lax/lax.py | 104 +++++++++++++++++++++++++---- jax/_src/numpy/array_methods.py | 3 +- jax/_src/numpy/lax_numpy.py | 31 ++++----- jax/_src/pallas/mosaic/lowering.py | 21 ++++++ jax/experimental/jax2tf/jax2tf.py | 6 ++ jax/experimental/jet.py | 5 +- jax/lax/__init__.py | 2 + tests/lax_autodiff_test.py | 18 +++++ tests/lax_test.py | 27 ++++++++ tests/lax_vmap_test.py | 18 +++++ tests/pjit_test.py | 29 ++++++++ 13 files changed, 237 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec33122064b4..b51faacd9830 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * {func}`jax.export.export` can be used for device-polymorphic export with shardings constructed with {func}`jax.sharding.AbstractMesh`. See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export). + * Added {func}`jax.lax.split`. This is a primitive version of + {func}`jax.numpy.split`, added because it yields a more compact + transpose during automatic differentiation. ## jax 0.4.37 (Dec 9, 2024) diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 065127718c54..d8a28bc399c8 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -154,6 +154,7 @@ Operators slice_in_dim sort sort_key_val + split sqrt square squeeze diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 87292c2efc00..0758dda0b506 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -673,6 +673,26 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: return concatenate_p.bind(*operands, dimension=dimension) +def split(operand: ArrayLike, sizes: Sequence[int], + axis: int = 0) -> Sequence[Array]: + """Splits an array along ``axis``. + + Args: + operand: an array to split + sizes: the sizes of the split arrays. The sum of the sizes must be equal + to the size of the ``axis`` dimension of ``operand``. + axis: the axis along which to split the array. + + Returns: + A sequence of ``len(sizes)`` arrays. If ``sizes`` is + ``[s1, s2, ...]``, this function returns chunks of sizes ``s1``, ``s2``, + taken along ``axis``. + """ + operand = asarray(operand) + return split_p.bind(operand, sizes=tuple(sizes), + axis=canonicalize_axis(axis, operand.ndim)) + + _precision_strings: dict[Any, Precision] = {} class Precision(enum.Enum): @@ -4454,18 +4474,8 @@ def _concatenate_transpose_rule(t, *operands, dimension): return [ad_util.Zero(o.aval) if ad.is_undefined_primal(o) else None for o in operands] else: - limit_points = np.cumsum( - [shape[dimension] for shape in operand_shapes]).tolist() - starts = np.zeros((len(operands), t.ndim), dtype=int).tolist() - limits = np.tile(t.shape, (len(operands), 1)).tolist() - - for i, s in enumerate(starts[1:]): - s[dimension] = limit_points[:-1][i] - for i, l in enumerate(limits): - l[dimension] = limit_points[i] - - return [slicing.slice(t, start, limit) if ad.is_undefined_primal(o) - else None for o, start, limit in zip(operands, starts, limits)] + return split(t, tuple(shape[dimension] for shape in operand_shapes), + axis=dimension) def _concatenate_batch_rule(batched_args, batch_dims, *, dimension): size = next(op.shape[bdim] for op, bdim in zip(batched_args, batch_dims) @@ -4499,6 +4509,76 @@ def _concatenate_lower(ctx, *xs, dimension): mlir.register_lowering(concatenate_p, _concatenate_lower) +def _split_shape_rule(operand, *, sizes, axis): + shapes = [] + shape = list(operand.shape) + if any(s < 0 for s in sizes): + raise ValueError( + f"Sizes passed to split must be nonnegative, got {list(sizes)}") + if operand.shape[axis] != np.sum(sizes): + raise ValueError( + f"Sum of sizes {np.sum(sizes)} must be equal to dimension {axis} of the " + f"operand shape {list(operand.shape)}") + for size in sizes: + shape[axis] = size + shapes.append(tuple(shape)) + return shapes + +def _split_dtype_rule(operand, *, sizes, axis): + return (operand.dtype,) * len(sizes) + +def _split_weak_type_rule(operand, *, sizes, axis): + return (operand.weak_type,) * len(sizes) + +def _split_transpose_rule(cotangents, operand, *, sizes, axis): + assert ad.is_undefined_primal(operand) + if all(type(t) is ad_util.Zero for t in cotangents): + return ad_util.Zero(operand.aval), + cotangents = [ + _zeros(t.aval) if type(t) is ad_util.Zero else t + for t in cotangents + ] + return concatenate(cotangents, dimension=axis), + +def _split_batch_rule(batched_args, batch_dims, *, sizes, axis): + operand, = batched_args + bdim, = batch_dims + new_bdims = (bdim,) * len(sizes) + out = split(operand, sizes=sizes, axis=axis + 1 if axis >= bdim else axis) + return out, new_bdims + +def _split_lower(ctx, x, *, sizes, axis): + x_aval, = ctx.avals_in + start_indices = [0] * x_aval.ndim + limit_indices = list(x_aval.shape) + strides = (1,) * x_aval.ndim + outs = [] + for aval_out in ctx.avals_out: + limit_indices[axis] = start_indices[axis] + aval_out.shape[axis] + out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, + limit_indices=limit_indices, strides=strides) + outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out) + if config.sharding_in_types.value else out) + start_indices[axis] = limit_indices[axis] + return outs + +def _split_sharding_rule(operand, *, sizes, axis): + # TODO(yashkatariya): Once JAX supports uneven sharding at the top level, + # change this logic to `return operand.sharding` directly. + out_shapes = _split_shape_rule(operand, sizes=sizes, axis=axis) + return [slicing._get_sharding_for_varying_out_shape(out_sh, operand, 'split') + for out_sh in out_shapes] + +split_p = core.Primitive('split') +split_p.multiple_results = True +split_p.def_abstract_eval( + partial(standard_multi_result_abstract_eval, split_p, _split_shape_rule, + _split_dtype_rule, _split_weak_type_rule, _split_sharding_rule)) +split_p.def_impl(partial(dispatch.apply_primitive, split_p)) +ad.deflinear2(split_p, _split_transpose_rule) +batching.primitive_batchers[split_p] = _split_batch_rule +mlir.register_lowering(split_p, _split_lower) + def _pad_dtype_rule(operand, padding_value, *, padding_config): if operand.dtype != padding_value.dtype: msg = "pad operand and padding_value must be same dtype: got {} and {}." diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 4768a8126c72..617213ca03de 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -629,7 +629,8 @@ def _multi_slice(self: Array, # avoid circular imports. @jax.jit def _unstack(x: Array) -> list[Array]: - return [lax.index_in_dim(x, i, keepdims=False) for i in range(x.shape[0])] + dims = (0,) + return [lax.squeeze(t, dims) for t in lax.split(x, (1,) * x.shape[0])] def _chunk_iter(x, size): if size > x.shape[0]: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3da4aa462f16..9aa131420b5f 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -68,7 +68,7 @@ ) from jax._src.util import ( NumpyComplexWarning, canonicalize_axis as _canonicalize_axis, - ceil_of_ratio, partition_list, safe_zip, set_module, subvals,unzip2, + ceil_of_ratio, partition_list, safe_zip, set_module, unzip2, tuple_replace) from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding, PartitionSpec as P) @@ -3273,10 +3273,10 @@ def _split(op: str, ary: ArrayLike, if (isinstance(indices_or_sections, (tuple, list)) or isinstance(indices_or_sections, (np.ndarray, Array)) and indices_or_sections.ndim > 0): - indices_or_sections = [ + split_indices = np.asarray([0] + [ core.concrete_dim_or_error(i_s, f"in jax.numpy.{op} argument 1") - for i_s in indices_or_sections] - split_indices = [0] + list(indices_or_sections) + [size] + for i_s in indices_or_sections] + [size]) + sizes = list(np.diff(split_indices)) else: if core.is_symbolic_dim(indices_or_sections): raise ValueError(f"jax.numpy.{op} with a symbolic number of sections is " @@ -3285,21 +3285,14 @@ def _split(op: str, ary: ArrayLike, f"in jax.numpy.{op} argument 1") part_size, r = divmod(size, num_sections) if r == 0: - split_indices = [i * part_size - for i in range(num_sections + 1)] + sizes = [part_size] * num_sections elif op == "array_split": - split_indices = ( - [i * (part_size + 1) for i in range(r + 1)] + - [i * part_size + ((r + 1) * (part_size + 1) - 1) - for i in range(num_sections - r)]) + sizes = [(part_size + 1)] * r + [part_size] * (num_sections - r) else: raise ValueError(f"array split does not result in an equal division: rest is {r}") - split_indices = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] - for i in split_indices] - starts, ends = [0] * ndim(ary), shape(ary) - _subval = lambda x, i, v: subvals(x, [(i, v)]) - return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) - for start, end in zip(split_indices[:-1], split_indices[1:])] + sizes = [i if core.is_symbolic_dim(i) else np.int64(i) # type: ignore[misc] + for i in sizes] + return list(lax.split(ary, sizes, axis=axis)) @export @@ -4662,7 +4655,11 @@ def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: "Unstack requires arrays with rank > 0, however a scalar array was " "passed." ) - return tuple(moveaxis(x, axis, 0)) + dimensions = (axis,) + return tuple( + lax.squeeze(t, dimensions) + for t in lax.split(x, (1,) * x.shape[axis], axis=axis) + ) @export diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 4620b8b445b3..d0acc655a6a5 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1901,6 +1901,27 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule +def _split_lowering_rule( + ctx: LoweringRuleContext, x, *, sizes, axis +): + (x_aval,) = ctx.avals_in + slice_size = np.array(x_aval.shape, dtype=np.int64) + starts = np.zeros_like(slice_size) + strides = np.ones_like(slice_size) + outs = [] + for size, aval_out in zip(sizes, ctx.avals_out): + slice_size[axis] = size + outs.append( + vector.extract_strided_slice( + aval_to_ir_type(aval_out), x, starts, slice_size, strides + ) + ) + starts[axis] += size + return outs + +lowering_rules[lax.split_p] = _split_lowering_rule + + def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): out_type = aval_to_ir_type(ctx.avals_out[0]) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 3376e4e8b8d8..700cb07ca847 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -2087,6 +2087,12 @@ def _concatenate(*operands, dimension): tf_impl[lax.concatenate_p] = _concatenate +def _split(operand, *, sizes, axis): + return tf.split(operand, _eval_shape(sizes), axis=axis) + +tf_impl[lax.split_p] = _split + + def _conv_general_dimension_numbers_proto(dimension_numbers): """Converts a ConvDimensionNumbers to an XLA ConvolutionDimensionNumbers.""" assert isinstance(dimension_numbers, lax.ConvDimensionNumbers) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 2681ad1a2a7b..ef158ba635f7 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -73,7 +73,7 @@ from jax._src.api_util import shaped_abstractify from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal -from jax._src.util import unzip2, weakref_lru_cache +from jax._src.util import unzip2, weakref_lru_cache, safe_zip def jet(fun, primals, series): @@ -310,6 +310,8 @@ def deflinear(prim): def linear_prop(prim, primals_in, series_in, **params): primal_out = prim.bind(*primals_in, **params) series_out = [prim.bind(*terms_in, **params) for terms_in in zip(*series_in)] + if prim.multiple_results: + series_out = safe_zip(*series_out) return primal_out, series_out deflinear(lax.neg_p) @@ -323,6 +325,7 @@ def linear_prop(prim, primals_in, series_in, **params): deflinear(lax.convert_element_type_p) deflinear(lax.broadcast_in_dim_p) deflinear(lax.concatenate_p) +deflinear(lax.split_p) deflinear(lax.pad_p) deflinear(lax.reshape_p) deflinear(lax.squeeze_p) diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 321b1dda19cf..a73de9bf5ebf 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -203,6 +203,8 @@ sort as sort, sort_key_val as sort_key_val, sort_p as sort_p, + split as split, + split_p as split_p, sqrt as sqrt, sqrt_p as sqrt_p, square as square, diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index fa67b6f43bbd..a69f44f37754 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -276,6 +276,24 @@ def testConcatenateGrad(self, dim, base_shape, dtype, num_arrs): concatenate = lambda *args: lax.concatenate(args, dim) check_grads(concatenate, operands, 2, ["fwd", "rev"], eps=1.) + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(base_shape)) + ], + num_pieces=range(3), + dtype=float_dtypes, + ) + def testSplitGrad(self, axis, base_shape, dtype, num_pieces): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + operands = (rng(shape, dtype),) + split = lambda x: lax.split(x, sizes, axis) + check_grads(split, operands, 2, ["fwd", "rev"], eps=1.) + + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape, strides=strides) for lhs_shape, rhs_shape, all_strides in itertools.chain( diff --git a/tests/lax_test.py b/tests/lax_test.py index 1f7083b3a88a..5da58b38aab7 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -285,6 +285,33 @@ def testConcatenateAgainstNumpy(self, dim, base_shape, dtype, num_arrs): numpy_op = lambda *args: lax_reference.concatenate(args, dim) self._CheckAgainstNumpy(numpy_op, op, args_maker) + @jtu.sample_product( + [dict(base_shape=shape, axis=axis) for shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(shape))], + num_pieces=range(3), + dtype=lax_test_util.default_dtypes, + ) + def testSplit(self, axis, base_shape, dtype, num_pieces): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + op = lambda x: lax.split(x, sizes, axis=axis) + def numpy_op(x): + return np.split(x, np.cumsum(sizes[:-1]), axis=axis) + self._CompileAndCheck(op, args_maker) + self._CheckAgainstNumpy(numpy_op, op, args_maker) + + def testSplitErrors(self): + with self.assertRaisesRegex(ValueError, + "Sizes passed to split must be nonnegative"): + lax.split(np.arange(5), [-1]) + with self.assertRaisesRegex(ValueError, "Sum of sizes 6 must be equal"): + lax.split(np.arange(5), [6]) + with self.assertRaisesRegex(ValueError, "axis 1 is out of bounds"): + lax.split(np.arange(5), sizes=(), axis=1) + @jtu.sample_product( [ dict(lhs_shape=(b, i, 9, 10), rhs_shape=(j, i, 4, 5)) diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index bfe9fecd6c7e..2fd817c5a45e 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -344,6 +344,24 @@ def testSlice(self, shape, dtype, starts, limits, strides, bdims): op = lambda x: lax.slice(x, starts, limits, strides) self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng) + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis, bdims=bdims) + for base_shape in [(4,), (3, 4), (2, 3, 4)] + for axis in range(len(base_shape)) + for bdims in lax_test_util.all_bdims(base_shape) + ], + num_pieces=range(3), + dtype=lax_test_util.default_dtypes, + ) + def testSplit(self, base_shape, dtype, num_pieces, axis, bdims): + sizes = jtu.rand_int(self.rng(), 5)((num_pieces + 1,), np.int64) + shape = list(base_shape) + shape[axis] = np.sum(sizes) + rng = jtu.rand_default(self.rng()) + op = lambda x: lax.split(x, sizes, axis) + self._CheckBatching(op, 5, bdims, (shape,), (dtype,), rng, + multiple_results=True) + @jtu.sample_product( [dict(shape=shape, perm=perm, bdims=bdims) for shape, perm in [ diff --git a/tests/pjit_test.py b/tests/pjit_test.py index ae69e8c6ba40..619de3f02615 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5732,6 +5732,35 @@ def test_sharding_cast_src_dst_mesh_mismatch(self): ValueError, "Mesh shape of the input.*does not match"): jax.jit(f)(arr) + @jtu.with_user_mesh((2, 2), ('x', 'y')) + def test_split(self, mesh): + np_inp = np.arange(16.).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @partial(jax.jit, static_argnums=(1, 2)) + def f(x, sizes=(4, 4), axis=0): + ys = lax.split(x, sizes, axis=axis) + self.assertEqual(ys[0].sharding.spec, P('x', 'y')) + self.assertEqual(ys[1].sharding.spec, P('x', 'y')) + return ys + + f(arr) + self.assertIn('@Sharding', f.lower(arr).as_text()) + + with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"): + f(arr, sizes=(1, 1), axis=1) + + def g(x): + out = f(x) + return jnp.square(jnp.sum(jnp.stack(out))) + + out = jax.grad(g)(arr) + self.assertEqual(out.sharding, s) + + out = jax.jit(jax.grad(g))(arr) + self.assertEqual(out.sharding, s) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):