From c2798fe7a0ad4532525660cc7d31e7a60d9b14e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Mon, 16 Dec 2024 14:35:59 -0800 Subject: [PATCH] [Pallas TPU] Use vector.broadcast instead of vector.BroadcastOp to fix type check failure This returns an ir.Value instead of an operation and avoids a type check failure in write_env in jaxpr_subcomp PiperOrigin-RevId: 706839044 --- jax/_src/pallas/mosaic/lowering.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index f798c8e07bc2..4620b8b445b3 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -1699,9 +1699,9 @@ def _dot_general_lowering_rule( list(bcast_shape), _dtype_to_ir_type(ctx.avals_out[0].dtype) ) if ctx.avals_in[0].shape != bcast_shape: - x = vector.BroadcastOp(bcast_shape, x) + x = vector.broadcast(bcast_shape, x) if ctx.avals_in[1].shape != bcast_shape: - y = vector.BroadcastOp(bcast_shape, y) + y = vector.broadcast(bcast_shape, y) red_type = aval_to_ir_type(lhs_aval.update(shape=(lhs_aval.shape[0],))) acc = arith.ConstantOp( red_type, ir.DenseElementsAttr.get_splat(red_type, val) @@ -1942,10 +1942,10 @@ def _bcast(x, y, x_aval, y_aval, out_aval): out_shape = list(out_aval.shape) if x_aval.shape != out_aval.shape: x_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(x_dtype)) - x = vector.BroadcastOp(x_ty, x) + x = vector.broadcast(x_ty, x) if y_aval.shape != out_aval.shape: y_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(y_dtype)) - y = vector.BroadcastOp(y_ty, y) + y = vector.broadcast(y_ty, y) return x, y @@ -2173,7 +2173,7 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x): if aval_out.shape == (): one = ir_constant(1.0, mlir_type=out_type) else: - one = vector.BroadcastOp(out_type, ir_constant(1.0)) + one = vector.broadcast(out_type, ir_constant(1.0)) denom = arith.addf(one, exp_neg_x) return arith.divf(one, denom) @@ -3309,10 +3309,7 @@ def _pad(val): ) if isinstance(padding_value, ir.OpResult): - pad = vector.BroadcastOp( - pad_vec_type, - padding_value, - ).result + pad = vector.broadcast(pad_vec_type, padding_value) else: scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value) pad = arith.ConstantOp(