Skip to content

Commit

Permalink
[Pallas TPU] Use vector.broadcast instead of vector.BroadcastOp to fi…
Browse files Browse the repository at this point in the history
…x type check failure

This returns an ir.Value instead of an operation and avoids a type check failure in write_env in jaxpr_subcomp

PiperOrigin-RevId: 706839044
  • Loading branch information
tlongeri authored and Google-ML-Automation committed Dec 17, 2024
1 parent 7dd401c commit c2798fe
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,9 +1699,9 @@ def _dot_general_lowering_rule(
list(bcast_shape), _dtype_to_ir_type(ctx.avals_out[0].dtype)
)
if ctx.avals_in[0].shape != bcast_shape:
x = vector.BroadcastOp(bcast_shape, x)
x = vector.broadcast(bcast_shape, x)
if ctx.avals_in[1].shape != bcast_shape:
y = vector.BroadcastOp(bcast_shape, y)
y = vector.broadcast(bcast_shape, y)
red_type = aval_to_ir_type(lhs_aval.update(shape=(lhs_aval.shape[0],)))
acc = arith.ConstantOp(
red_type, ir.DenseElementsAttr.get_splat(red_type, val)
Expand Down Expand Up @@ -1942,10 +1942,10 @@ def _bcast(x, y, x_aval, y_aval, out_aval):
out_shape = list(out_aval.shape)
if x_aval.shape != out_aval.shape:
x_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(x_dtype))
x = vector.BroadcastOp(x_ty, x)
x = vector.broadcast(x_ty, x)
if y_aval.shape != out_aval.shape:
y_ty = ir.VectorType.get(out_shape, _dtype_to_ir_type(y_dtype))
y = vector.BroadcastOp(y_ty, y)
y = vector.broadcast(y_ty, y)
return x, y


Expand Down Expand Up @@ -2173,7 +2173,7 @@ def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
if aval_out.shape == ():
one = ir_constant(1.0, mlir_type=out_type)
else:
one = vector.BroadcastOp(out_type, ir_constant(1.0))
one = vector.broadcast(out_type, ir_constant(1.0))
denom = arith.addf(one, exp_neg_x)
return arith.divf(one, denom)

Expand Down Expand Up @@ -3309,10 +3309,7 @@ def _pad(val):
)

if isinstance(padding_value, ir.OpResult):
pad = vector.BroadcastOp(
pad_vec_type,
padding_value,
).result
pad = vector.broadcast(pad_vec_type, padding_value)
else:
scalar_attr = ir.FloatAttr.get(operand.type.element_type, padding_value)
pad = arith.ConstantOp(
Expand Down

0 comments on commit c2798fe

Please sign in to comment.