Verifier and Type inference changes for reduction operations #1869

Merged · 4 commits · Dec 20, 2023
20 changes: 15 additions & 5 deletions stablehlo/dialect/StablehloOps.cpp
@@ -156,7 +156,6 @@ LogicalResult ReduceScatterOp::verify() {
}

INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
@@ -919,6 +918,15 @@ LogicalResult AllReduceOp::verify() {
getComputation());
}

LogicalResult AllReduceOp::inferReturnTypeComponents(
MLIRContext*, std::optional<Location> location, ValueShapeRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
AllReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferAllReduceOp(location, adaptor.getOperand(),
adaptor.getComputation(), inferredReturnShapes);
}

//===----------------------------------------------------------------------===//
// BatchNormGradOp
//===----------------------------------------------------------------------===//
@@ -1379,7 +1387,7 @@ LogicalResult ReduceWindowOp::inferReturnTypeComponents(
location, adaptor.getInputs(), adaptor.getInitValues(),
adaptor.getWindowDimensions(), adaptor.getWindowStrides(),
adaptor.getBaseDilations(), adaptor.getWindowDilations(),
adaptor.getPadding(), inferredReturnShapes);
adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes);
}

LogicalResult ReduceWindowOp::verify() {
@@ -1782,7 +1790,8 @@ LogicalResult ReduceOp::inferReturnTypeComponents(
ReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferReduceOp(location, adaptor.getInputs().getTypes(),
adaptor.getInitValues().getTypes(),
adaptor.getDimensions(), inferredReturnShapes);
adaptor.getDimensions(), adaptor.getBody(),
inferredReturnShapes);
}

LogicalResult ReduceOp::verify() {
@@ -2295,8 +2304,8 @@ LogicalResult SelectAndScatterOp::inferReturnTypes(
SmallVectorImpl<Type>& inferredReturnTypes) {
SelectAndScatterOp::Adaptor adaptor(operands, attributes, properties,
regions);
return hlo::inferSelectAndScatterOp(adaptor.getOperand(),
inferredReturnTypes);
return hlo::inferSelectAndScatterOp(
adaptor.getOperand(), adaptor.getScatter(), inferredReturnTypes);
}

LogicalResult SelectAndScatterOp::verify() {
@@ -2316,6 +2325,7 @@ LogicalResult ScatterOp::inferReturnTypes(
SmallVectorImpl<Type>& inferredReturnTypes) {
ScatterOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferScatterOp(location, adaptor.getInputs(),
adaptor.getUpdateComputation(),
inferredReturnTypes);
}

16 changes: 7 additions & 9 deletions stablehlo/dialect/StablehloOps.td
@@ -1327,7 +1327,7 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
}

def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
[HLO_CompatibleOperandsAndResultType /*all_reduce_c6*/]> {
[InferTensorType /*all_reduce_c6, all_reduce_c7*/]> {
let summary = "AllReduce operation";
let description = [{
Within each process group in the process grid, applies a reduction function
@@ -1362,8 +1362,7 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
let hasVerifier = 1;
}

def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter",
[SameOperandsAndResultElementType /*reduce_scatter_c8*/]> {
def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", []> {
let summary = "ReduceScatter operation";
let description = [{
Within each process group in the process grid, performs reduction, using
@@ -1448,7 +1447,7 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
RecursiveMemoryEffects,
SameVariadicOperandSize /*reduce_c3*/,
InferTensorTypeWithReify /*reduce_c7*/,
InferTensorTypeWithReify /*reduce_c7, reduce_c8*/,
SingleBlockImplicitTerminator<"ReturnOp">
]> { /*reduce_c7*/
let summary = "Reduce operation";
@@ -1475,8 +1474,6 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
Variadic<HLO_Tensor>:$init_values, /*reduce_i2*/
I64ElementsAttr:$dimensions /*reduce_i3*/
);
// TODO(hinsu): Verify that the attached body arguments and results are
// compatible with reduce op's operands.
let regions = (region SizedRegion<1>:$body /*reduce_i4*/);

let results = (outs Variadic<HLO_Tensor>);
@@ -2514,7 +2511,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [

def StableHLO_ScatterOp: StableHLO_Op<"scatter", [RecursiveMemoryEffects,
SameVariadicOperandSize /*scatter_c5*/,
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16*/]> {
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16,
scatter_c17*/]> {
let summary = "Scatter operation";
let description = [{
Produces `results` tensors which are equal to `inputs` tensors except that
@@ -2587,8 +2585,8 @@ def StableHLO_SelectOp: StableHLO_Op<"select", [Pure, HLO_BroadcastingElementwis
}

def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
[DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11*/,
RecursiveMemoryEffects]> {
[DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
select_and_scatter_c12*/, RecursiveMemoryEffects]> {
let summary = "SelectAndScatter operation";
let description = [{
Scatters the values from the `source` tensor using `scatter` based on the
155 changes: 127 additions & 28 deletions stablehlo/dialect/TypeInference.cpp
@@ -97,6 +97,52 @@ bool tensorsHaveSameElType(Type type1, Type type2,
return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision);
}

unsigned getBitWidth(Type type) {
if (auto complexTy = type.dyn_cast<ComplexType>())
return 2 * getBitWidth(complexTy.getElementType());
if (auto quantTy = type.dyn_cast<quant::QuantizedType>())
return getBitWidth(quantTy.getStorageType());
return type.getIntOrFloatBitWidth();
}

template <typename T>
bool matchesType(Type a, Type b) {
bool matches = a.isa<T>() && b.isa<T>();
// Check that expressed type matches for quantized types
if constexpr (std::is_same<T, quant::QuantizedType>::value) {
return matches && (a.cast<quant::QuantizedType>().getExpressedType() ==
b.cast<quant::QuantizedType>().getExpressedType());
}
return matches;
}

// Returns true if the element-type of type1 can be promoted to that of type2.
// An element-type 'x' is promotable to element-type 'y' if they have the same
// base type and bitwidth(x) <= bitwidth(y). When 'x' and 'y' are quantized
// element-types, then promotion is applied only to the 'storage_type'
// component.
bool isPromotableElementType(Type type1, Type type2,
bool ignoreFpPrecision = false) {
auto tensorTy1 = type1.dyn_cast<ShapedType>();
auto tensorTy2 = type2.dyn_cast<ShapedType>();

if (!tensorTy1 || !tensorTy2) return false;

Type tensorEl1 = tensorTy1.getElementType();
Type tensorEl2 = tensorTy2.getElementType();

bool isSameType = matchesType<IntegerType>(tensorEl1, tensorEl2) ||
matchesType<FloatType>(tensorEl1, tensorEl2) ||
matchesType<ComplexType>(tensorEl1, tensorEl2) ||
matchesType<quant::QuantizedType>(tensorEl1, tensorEl2);

if (!isSameType) return false;

if (ignoreFpPrecision && tensorEl1.isa<FloatType>()) return true;

return getBitWidth(tensorEl1) <= getBitWidth(tensorEl2);
}
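
As context for the new check: an element type promotes only to a type of the same base kind with an equal or larger bit width (complex types counting both components, quantized types measured by their storage type). Below is a minimal standalone sketch of that rule; the simplified `ElementType` struct and `isPromotable` helper are hypothetical stand-ins for the MLIR types used above, not part of the StableHLO sources.

```cpp
// Minimal standalone sketch of the promotion rule, using plain structs in
// place of MLIR types. All names here are hypothetical illustrations.
#include <cassert>

enum class BaseKind { Integer, Float, Complex, Quantized };

struct ElementType {
  BaseKind kind;
  unsigned bitWidth;  // Complex: 2x component width; Quantized: storage width.
};

// Mirrors isPromotableElementType: promotion requires the same base kind and
// a non-decreasing bit width.
bool isPromotable(ElementType from, ElementType to) {
  if (from.kind != to.kind) return false;
  return from.bitWidth <= to.bitWidth;
}

int main() {
  ElementType f32{BaseKind::Float, 32}, f64{BaseKind::Float, 64};
  ElementType i32{BaseKind::Integer, 32};
  ElementType c64{BaseKind::Complex, 64};  // complex<f32>: width 2 * 32

  assert(isPromotable(f32, f64));   // f32 -> f64: widening is accepted.
  assert(!isPromotable(f64, f32));  // f64 -> f32: narrowing is rejected.
  assert(!isPromotable(i32, f32));  // Different base kinds never promote.
  assert(isPromotable(c64, c64));   // Identical types trivially promote.
  return 0;
}
```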

// Return true if type1 and type2 are shape-compatible and have same element
// type. If 'ignoreFpPrecision' is True, then allow floats with different
// precisions while checking element-types.
@@ -405,12 +451,6 @@ SmallVector<int64_t> inferWindowOutputShape(ArrayRef<int64_t> baseShape,
return outputDimensions;
}

unsigned potentiallyComplexBitWidth(Type type) {
auto complexTy = type.dyn_cast<ComplexType>();
return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
: type.getIntOrFloatBitWidth();
}

LogicalResult verifyReplicaGroups(std::optional<Location> location,
DenseIntElementsAttr replicaGroups,
bool allGroupsMustHaveSameSize,
@@ -530,6 +570,17 @@ LogicalResult verifyReduceOpInputsAndInferShape(
return success();
}

// Returns the types of the terminator arguments of the input mlir::Block
// 'block'.
SmallVector<ShapedType> getAccumulatorTypes(Block& block) {
SmallVector<ShapedType> accumulatorSubShapes;
for (Value retOperand : block.getTerminator()->getOperands()) {
auto shapedTy = retOperand.getType().cast<ShapedType>();
accumulatorSubShapes.push_back(shapedTy);
}
return accumulatorSubShapes;
}

LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
ArrayRef<ShapedType> inputTypes,
ArrayRef<ShapedType> initValueTypes,
@@ -598,24 +649,35 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,

// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// reduce_window_i2, scatter_c6, scatter_c15, select_and_scatter_c10
if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
initValueTypes[inputIdx],
/*ignoreFpPrecision=*/true))
if (failed(verifyCompatibleShape(initValueTypes[inputIdx],
accumulatorSubShapes[inputIdx])))
return emitOptionalError(
loc, "The shape of reduction-region's result type at index ",
inputIdx, " differs from the op's corresponding init-value type: ",
accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);

if (!isPromotableElementType(initValueTypes[inputIdx],
accumulatorSubShapes[inputIdx],
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The type of reduction-region's result type at index ", inputIdx,
" differs from the op's corresponding init-value type: ",
loc, "The element-type of reduction-region's result type at index ",
inputIdx,
" is expected to be promotable from the op's corresponding "
"init-value element-type: ",
accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);

// reduce_c6, reduce_window_c3, scatter_c6, scatter_c15,
// select_and_scatter_c10
if (!tensorsHaveSameElType(
if (!isPromotableElementType(
inputTypes[inputIdx],
block.getArgument(numInputs + inputIdx).getType(), true))
block.getArgument(numInputs + inputIdx).getType(),
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The element-type of reduction-region's argument at index ",
numInputs + inputIdx, " is expected to be ",
numInputs + inputIdx, " is expected to be promotable from ",
inputTypes[inputIdx].getElementType(), ", but got ",
block.getArgument(numInputs + inputIdx).getType(), " as its type.");
getElementTypeOrSelf(
block.getArgument(numInputs + inputIdx).getType()));

Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
auto blockArgTensorTy = blockArgType.cast<ShapedType>();
@@ -1453,6 +1515,18 @@ LogicalResult inferAllToAllOp(
return success();
}

LogicalResult inferAllReduceOp(
std::optional<Location> location, Value operand, Region& computation,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// all_reduce_c6, all_reduce_c7
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(computation.front());
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnShapes.emplace_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
return success();
}
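
To illustrate the inference pattern introduced here: the result keeps the operand's shape but takes its element type from the reduction region's return value (the accumulator type). The sketch below is a minimal illustration under that assumption; the simplified `TensorType` struct and `inferAllReduceResult` helper are hypothetical stand-ins for MLIR's ShapedTypeComponents machinery.

```cpp
// Standalone sketch: the inferred result reuses the operand's shape and the
// element type returned by the reduction body. Hypothetical simplified types.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct TensorType {
  std::vector<int64_t> shape;
  std::string elementType;
};

// Mirrors the pattern in inferAllReduceOp: operand shape + accumulator
// element type.
TensorType inferAllReduceResult(const TensorType& operand,
                                const TensorType& accumulator) {
  return {operand.shape, accumulator.elementType};
}

int main() {
  TensorType operand{{4, 8}, "f32"};
  TensorType accumulator{{}, "f64"};  // reduction body returns tensor<f64>
  TensorType result = inferAllReduceResult(operand, accumulator);
  std::cout << "tensor<" << result.shape[0] << "x" << result.shape[1] << "x"
            << result.elementType << ">\n";  // prints tensor<4x8xf64>
  return 0;
}
```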

LogicalResult inferBatchNormGradOp(
std::optional<Location> location, Value operand, Value scale, Value mean,
Value variance, Value gradOutput, int64_t featureIndex,
@@ -2532,7 +2606,7 @@ LogicalResult inferRealOp(std::optional<Location>, Value operand,

LogicalResult inferReduceOp(
std::optional<Location> location, TypeRange inputTypes,
TypeRange initValueTypes, DenseIntElementsAttr dimensions,
TypeRange initValueTypes, DenseIntElementsAttr dimensions, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
@@ -2546,10 +2620,11 @@ LogicalResult inferReduceOp(
initValueTensorTypes, dimensions,
newDimensions, encoding)))
return failure();
// reduce_c2, reduce_c3, reduce_c7
// reduce_c3, reduce_c7, reduce_c8
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
ShapedType inputType = inputArgTensorTypes[inputIdx];
Type elementType = inputType.getElementType();
Type elementType = accumulatorTypes[inputIdx].getElementType();
if (inputType.hasRank())
inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
else
@@ -2565,7 +2640,7 @@ LogicalResult inferReduceWindowOp(
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<DenseIntElementsAttr> baseDilations,
std::optional<DenseIntElementsAttr> windowDilations,
std::optional<DenseIntElementsAttr> padding,
std::optional<DenseIntElementsAttr> padding, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
@@ -2582,21 +2657,22 @@
return failure();

// reduce_window_c1, reduce_window_c14...reduce_window_c16
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
for (size_t i = 0; i < inputTypes.size(); ++i) {
auto inputRankedType = inputs[i].getType().dyn_cast<RankedTensorType>();
if (!inputRankedType) {
inferredReturnShapes.emplace_back(inputTypes[i].getElementType());
inferredReturnShapes.emplace_back(accumulatorTypes[i].getElementType());
} else {
auto resultShape =
inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow);
auto inputBounds = encodingToBounds(inputRankedType.getEncoding());
if (inputBounds.empty()) {
inferredReturnShapes.emplace_back(resultShape,
inputTypes[i].getElementType());
accumulatorTypes[i].getElementType());
} else {
auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow);
inferredReturnShapes.emplace_back(
resultShape, inputTypes[i].getElementType(),
resultShape, accumulatorTypes[i].getElementType(),
boundsToEncoding(inputRankedType.getEncoding(), resultBounds));
}
}
@@ -2661,8 +2737,16 @@ LogicalResult inferRngOp(
}

LogicalResult inferScatterOp(std::optional<Location>, ValueRange inputs,
Region& updateComputation,
SmallVectorImpl<Type>& inferredReturnTypes) {
llvm::append_range(inferredReturnTypes, inputs.getTypes());
// scatter_c16, scatter_c17
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(updateComputation.front());
for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
inputShapedTy, accumulatorTypes[inputIdx].getElementType()));
}
return success();
}

@@ -2692,9 +2776,14 @@ LogicalResult inferSelectOp(
}

LogicalResult inferSelectAndScatterOp(
Value operand, SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11
inferredReturnTypes.push_back(operand.getType());
Value operand, Region& scatter,
SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11, select_and_scatter_c12
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(scatter.front());
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
return success();
}

@@ -3139,8 +3228,8 @@ LogicalResult verifyBitcastConvertOp(std::optional<Location> location,
location, "cannot convert between real and complex types, but got: ",
operandShapedType, " and ", targetShapedType);

auto targetEltBitWidth = potentiallyComplexBitWidth(targetElt);
auto operandEltBitWidth = potentiallyComplexBitWidth(operandElt);
auto targetEltBitWidth = getBitWidth(targetElt);
auto operandEltBitWidth = getBitWidth(operandElt);

auto operandType = operandShapedType.dyn_cast<RankedTensorType>();
auto targetType = targetShapedType.dyn_cast<RankedTensorType>();
@@ -3821,6 +3910,16 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,
operandType.getDimSize(index), ") and result (",
resultType.getDimSize(index), ")");
}

// reduce_scatter_c9
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(computation.front());
if (resultType.getElementType() != accumulatorTypes[0].getElementType()) {
return emitOptionalError(location, "result element-type is expected to be ",
accumulatorTypes[0].getElementType(), ", but got ",
resultType.getElementType());
}

return success();
}
