diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py
index 34395756f25a..f89e4d53a476 100644
--- a/jax/_src/lax/control_flow/__init__.py
+++ b/jax/_src/lax/control_flow/__init__.py
@@ -38,6 +38,7 @@
   cond_p as cond_p,
   switch as switch,
   platform_dependent as platform_dependent,
+  platform_index_p as platform_index_p,
 )
 from jax._src.lax.control_flow.solves import (
   custom_linear_solve as custom_linear_solve,
diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py
index fb77759c321b..171a87f4f031 100644
--- a/jax/_src/lax/control_flow/conditionals.py
+++ b/jax/_src/lax/control_flow/conditionals.py
@@ -934,6 +934,7 @@ def other_platforms_code(*args): ...
   platform_index = platform_index_p.bind(
     platforms=tuple(tuple(ps) for ps in platforms_lists),
     has_default=(default is not None))
+
   if default is not None:
     branches = branches + (default,)
   # Use a switch, to get the proper transformation rules for free. Since
@@ -946,6 +947,8 @@ def other_platforms_code(*args): ...
   # recognized on the compilation platform. Detect eager mode and keep only the
   # needed branch.
   try:
+    # Note/TODO(mvoz): This actually rarely seems to concretize - we could look into
+    # core.ensure_compile_time_eval to get better single-branch selection.
     platform_index_concrete = core.concrete_or_error(operator.index, platform_index)
   except core.ConcretizationTypeError:
     return switch(platform_index, branches, *args)
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index 3da4aa462f16..9c52fa459be9 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -8345,18 +8345,41 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0,
     Array([4, 8], dtype=int32)
   """
   util.check_arraylike("diagonal", a)
-  a_shape = shape(a)
+
   if ndim(a) < 2:
     raise ValueError("diagonal requires an array of at least two dimensions.")
   offset = core.concrete_or_error(operator.index, offset,
                                   "'offset' argument of jnp.diagonal()")

-  a = moveaxis(a, (axis1, axis2), (-2, -1))
+  def _default_diag(a):
+    a_shape = shape(a)
+
+    a = moveaxis(a, (axis1, axis2), (-2, -1))
+
+    diag_size = max(
+        0, min(a_shape[axis1] + min(offset, 0), a_shape[axis2] - max(offset, 0))
+    )
+    i = arange(diag_size)
+    j = arange(abs(offset), abs(offset) + diag_size)
+    return a[..., i, j] if offset >= 0 else a[..., j, i]
+
+  def _mosaic_diag(a):
+    def _sum(x, axis):
+      return lax.reduce(
+          x,
+          np.array(0, x.dtype),
+          lax.add if x.dtype != bool_ else lax.bitwise_or,
+          (axis,),
+      )
+
+    if a.shape[0] != a.shape[1]:
+      raise ValueError("Mosaic diagonal requires a square array for now.")
+
+    a_shape_eye = eye(a.shape[0])
+    original_a_dtype = a.dtype
+    a_shape_eye, a = util.promote_dtypes(a_shape_eye, a)
+    return _sum(lax.mul(a_shape_eye, a), axis=0).astype(original_a_dtype)

-  diag_size = max(0, min(a_shape[axis1] + min(offset, 0),
-                         a_shape[axis2] - max(offset, 0)))
-  i = arange(diag_size)
-  j = arange(abs(offset), abs(offset) + diag_size)
-  return a[..., i, j] if offset >= 0 else a[..., j, i]
+  return lax.platform_dependent(a, mosaic=_mosaic_diag, default=_default_diag)


 @export
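For context on the lax_numpy.py change: `lax.platform_dependent` binds `platform_index_p` and dispatches to the branch registered for the lowering platform, which is how `diagonal` now picks `_mosaic_diag` inside Pallas/Mosaic kernels and `_default_diag` everywhere else. Below is a minimal runnable sketch of the same dispatch pattern; the toy `scale` function and its branches are illustrative only, not part of this change:

```python
import jax
import jax.numpy as jnp
from jax import lax

def scale(x):
  # Hypothetical per-platform implementations: every branch must accept
  # the same arguments and produce the same output shape/dtype.
  def _cpu(x):
    return x * 2.0
  def _default(x):
    return x + x
  # The platform index is resolved when the computation is lowered for a
  # concrete backend, so only the matching branch survives compilation.
  return lax.platform_dependent(x, cpu=_cpu, default=_default)

print(jax.jit(scale)(jnp.ones(3)))  # [2. 2. 2.] regardless of backend
```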
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index f798c8e07bc2..c55e12e13ab3 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -117,6 +117,7 @@ class LoweringContext:
   def grid_rank(self):
     return len(self.grid_sizes)

+  @contextlib.contextmanager
   def grid_name_context(self):
     # TODO(b/355036977): generalize this across other platforms
@@ -547,9 +548,13 @@ def lower_jaxpr_to_module(
   module_name = name_and_src_info.name
   attrs["sym_name"] = ir.StringAttr.get(module_name)
   sym_tab = ir.SymbolTable(m.operation)
+
   func_op = lower_jaxpr_to_func(
-      ctx, jaxpr, mosaic_grid_mapping=mosaic_grid_mapping,
-      name="main", for_verification=for_verification,
+      ctx,
+      jaxpr,
+      mosaic_grid_mapping=mosaic_grid_mapping,
+      name="main",
+      for_verification=for_verification,
   )
   m.body.append(func_op)
   sym_tab.insert(func_op)
@@ -568,6 +573,7 @@ def lower_jaxpr_to_module(
       # We checked above that the block does not require windowing.
       window_params.append(ir.DictAttr.get())
       continue
+
     mlir_func = lower_jaxpr_to_transform_func(
         ctx,
         bm.index_map_jaxpr.jaxpr,
@@ -1965,6 +1971,36 @@ def _add_lowering_rule(ctx: LoweringRuleContext, x, y):
 skip_mlir_conversions.add(ad_util.add_any_p)


+class FoldingError(Exception):
+  pass
+
+
+def _fold_and_get_constant_value(x):
+  def _fold(x, fuel):
+    if fuel <= 0:
+      raise FoldingError("Folding depth exceeded")
+    op_name = getattr(x.owner, "name", None)
+    binop_folds = {
+        "arith.maxsi": max,
+        "arith.minsi": min,
+    }
+    if op_name == "arith.constant":
+      if ir.IntegerType.isinstance(x.type):
+        return ir.IntegerAttr(x.owner.attributes["value"]).value
+      elif ir.FloatType.isinstance(x.type):
+        return ir.FloatAttr(x.owner.attributes["value"]).value
+      else:
+        raise ValueError(f"Unsupported constant type: {x.type}")
+    if op_name in binop_folds:
+      return binop_folds[op_name](_fold(v, fuel - 1) for v in x.owner.operands)
+    raise FoldingError(f"Folding not supported for {x.owner}")
+
+  try:
+    return _fold(x, 10)
+  except FoldingError:
+    return None
+
+
 def _max_lowering_rule(ctx: LoweringRuleContext, x, y):
   x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
   (aval_out,) = ctx.avals_out
@@ -2681,6 +2717,11 @@ def _while_lowering_rule(

 def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches):
   index, *args = args
+  constant_index = _fold_and_get_constant_value(index)
+  if constant_index is not None:
+    return jaxpr_subcomp(
+        ctx.lowering_context, branches[constant_index].jaxpr, *args
+    )
   out_types = map(aval_to_ir_type, ctx.avals_out)
   pred = arith.cmpi(
       arith.CmpIPredicate.ne, index, ir_constant(0, index.type)
@@ -3351,3 +3392,24 @@ def _pad(val):


 lowering_rules[lax.pad_p] = _pad_lowering_rule
+
+
+def _platform_index_lowering(
+    ctx: mlir.LoweringRuleContext,
+    *,
+    platforms: Sequence[Sequence[str]],
+    has_default: bool,
+):
+  for i, ps in enumerate(platforms):
+    # Note: platforms is a seq[seq[str]]; entry i lists the platforms sharing branch i.
+    if "mosaic" in ps:
+      return ir_constant(i)
+
+  if has_default:
+    return ir_constant(len(platforms))
+
+  raise NotImplementedError(
+      "No mosaic or default platform indexing rule found."
+  )
+
+lowering_rules[lax.platform_index_p] = _platform_index_lowering
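The `_fold_and_get_constant_value` helper above walks an MLIR value's defining ops with a fuel budget, so `_cond_lowering_rule` can inline a single branch when the branch index folds to a constant, and otherwise falls back to emitting a real conditional. Below is a pure-Python analogue of that fuel-limited fold over a toy `("const", v)` / `("max" | "min", *children)` expression tree; all names here are illustrative, not the Mosaic implementation:

```python
class FoldingError(Exception):
  pass

def fold_constant(node, fuel=10):
  """Return the folded value of `node`, or None when folding fails."""
  def _fold(n, fuel):
    if fuel <= 0:
      raise FoldingError("Folding depth exceeded")
    op, *children = n
    if op == "const":
      return children[0]            # leaf: a known constant
    if op in ("max", "min"):        # the only binops this toy folds
      f = max if op == "max" else min
      return f(_fold(c, fuel - 1) for c in children)
    raise FoldingError(f"Folding not supported for {op}")

  try:
    return _fold(node, fuel)
  except FoldingError:
    return None  # the caller falls back to a genuine conditional

# max(min(3, 7), 5) folds to 5; an opaque op gives up and yields None.
print(fold_constant(("max", ("min", ("const", 3), ("const", 7)), ("const", 5))))
print(fold_constant(("load", "x")))
```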
diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py
index 321b1dda19cf..3c8ab4f252d9 100644
--- a/jax/lax/__init__.py
+++ b/jax/lax/__init__.py
@@ -332,6 +332,7 @@
   map as map,
   scan as scan,
   scan_p as scan_p,
+  platform_index_p as platform_index_p,
   switch as switch,
   while_loop as while_loop,
   while_p as while_p,
diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py
index 400ad72ed486..e84b12e169fa 100644
--- a/tests/pallas/ops_test.py
+++ b/tests/pallas/ops_test.py
@@ -2095,6 +2095,18 @@ def kernel(x_ref, o_ref):
       )
       self.assertTrue(acceptable_errors, "Failed with error: " + str(e))

+  @parameterized.parameters((128, 128), (256, 256))
+  def test_jnp_diagonal_pallas(self, n, m):
+    x = jnp.arange(n * m, dtype=jnp.float32).reshape((n, m))
+
+    def kernel(x_ref, out_ref):
+      out_ref[...] = jnp.diagonal(x_ref[...])
+
+    out = self.pallas_call(
+        kernel, out_shape=jax.ShapeDtypeStruct((n,), jnp.float32)
+    )(x)
+    np.testing.assert_array_equal(out, np.diagonal(x))
+

 class OpsInterpretTest(OpsTest):
   INTERPRET = True
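Outside a Mosaic kernel nothing changes observably: `jnp.diagonal` takes the `_default_diag` path, so rectangular inputs and nonzero offsets behave exactly as before. A quick check against NumPy (assuming any non-Mosaic backend such as CPU):

```python
import jax.numpy as jnp
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
for k in (-2, -1, 0, 1, 2):
  # _default_diag handles rectangular inputs and nonzero offsets; only the
  # Mosaic branch is limited to the main diagonal of a square array.
  np.testing.assert_array_equal(jnp.diagonal(x, offset=k),
                                np.diagonal(x, offset=k))
print("jnp.diagonal matches np.diagonal for all offsets")
```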