Skip to content

Commit

Permalink
Add custom builder to reduce op allowing type inference (#1965)
Browse files Browse the repository at this point in the history
The PR implements the a custom `reduce` op builder similar to what we
have for mhlo
[code](https://github.com/openxla/xla/blob/50aec2b3b54ce7a861f45bc3b0ae9b2cc2ee2a28/xla/mlir_hlo/mhlo/IR/hlo_ops.cc#L3917).

## Background
#1869 allows the block
arguments of reduce op to have different element types than that of the
input arguments of reduce op and the output element type of the reduce
op has to equal to those block arguments. As a consequence the output
type of reduce op can no longer be inferred from the operand types. The
auto-generated builders creates a reduce op with empty block and, as a
result, does not allow inferring the type.

The proposed solution is to create a custom builder which takes the
element-type of the block arguments as arguments allowing result type
inference.
  • Loading branch information
sdasgup3 authored Jan 30, 2024
1 parent 5cad234 commit 728a7b1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
42 changes: 42 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,48 @@ LogicalResult ReduceOp::inferReturnTypeComponents(
inferredReturnShapes);
}

void ReduceOp::build(OpBuilder&, OperationState& odsState, ValueRange inputs,
ValueRange initValues, DenseI64ArrayAttr dimensions,
TypeRange elementTypes) {
odsState.addOperands(inputs);
odsState.addOperands(initValues);
odsState.addAttribute(getDimensionsAttrName(odsState.name), dimensions);
(void)odsState.addRegion();

SmallVector<int64_t> newDimensions;
Attribute encoding;
ReduceOp::Adaptor adaptor(
odsState.operands,
odsState.attributes.getDictionary(odsState.getContext()), {},
odsState.regions);

SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(adaptor.getInputs().getTypes(),
[](Type t) { return t.cast<ShapedType>(); })};
SmallVector<ShapedType> initValueTensorTypes{
llvm::map_range(adaptor.getInitValues().getTypes(),
[](Type t) { return t.cast<ShapedType>(); })};

if (failed(hlo::verifyReduceOpInputsAndInferShape(
odsState.location, inputArgTensorTypes, dimensions, newDimensions,
encoding)))
llvm::report_fatal_error("Failed to infer result type(s).");

SmallVector<Type> inferredReturnTypes;
for (auto [inputTy, elementTy] :
llvm::zip(inputArgTensorTypes, elementTypes)) {
if (inputTy.hasRank()) {
inferredReturnTypes.push_back(
RankedTensorType::get(newDimensions, elementTy, encoding));
} else {
if (encoding != nullptr)
llvm::report_fatal_error("attribute not supported.");
inferredReturnTypes.push_back(UnrankedTensorType::get(elementTy));
}
}
odsState.addTypes(inferredReturnTypes);
}

LogicalResult ReduceOp::verify() {
return hlo::verifyReduceOp(getLoc(), getInputs(), getInitValues(),
getDimensions(), getBody());
Expand Down
8 changes: 8 additions & 0 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,14 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
);
let regions = (region SizedRegion<1>:$body /*reduce_i4*/);

// Builder
// The following custom builder allows inferring the operation type using the
// 'element_types' of the arguments of the 'body'.
let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$init_values,
"DenseI64ArrayAttr":$dimensions, "TypeRange":$element_types)>,
];

let results = (outs Variadic<HLO_Tensor>);

let hasCustomAssemblyFormat = 1;
Expand Down

0 comments on commit 728a7b1

Please sign in to comment.