Skip to content

Commit

Permalink
verifier and type inference changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Dec 4, 2023
1 parent afd8f5b commit c668b0e
Show file tree
Hide file tree
Showing 11 changed files with 928 additions and 63 deletions.
20 changes: 15 additions & 5 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ LogicalResult ReduceScatterOp::verify() {
}

INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
Expand Down Expand Up @@ -917,6 +916,15 @@ LogicalResult AllReduceOp::verify() {
getComputation());
}

LogicalResult AllReduceOp::inferReturnTypeComponents(
MLIRContext*, std::optional<Location> location, ValueShapeRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
AllReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferAllReduceOp(location, adaptor.getOperand(),
adaptor.getComputation(), inferredReturnShapes);
}

//===----------------------------------------------------------------------===//
// BatchNormGradOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1378,7 +1386,7 @@ LogicalResult ReduceWindowOp::inferReturnTypeComponents(
location, adaptor.getInputs(), adaptor.getInitValues(),
adaptor.getWindowDimensions(), adaptor.getWindowStrides(),
adaptor.getBaseDilations(), adaptor.getWindowDilations(),
adaptor.getPadding(), inferredReturnShapes);
adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes);
}

LogicalResult ReduceWindowOp::verify() {
Expand Down Expand Up @@ -1781,7 +1789,8 @@ LogicalResult ReduceOp::inferReturnTypeComponents(
ReduceOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferReduceOp(location, adaptor.getInputs().getTypes(),
adaptor.getInitValues().getTypes(),
adaptor.getDimensions(), inferredReturnShapes);
adaptor.getDimensions(), adaptor.getBody(),
inferredReturnShapes);
}

LogicalResult ReduceOp::verify() {
Expand Down Expand Up @@ -2312,8 +2321,8 @@ LogicalResult SelectAndScatterOp::inferReturnTypes(
SmallVectorImpl<Type>& inferredReturnTypes) {
SelectAndScatterOp::Adaptor adaptor(operands, attributes, properties,
regions);
return hlo::inferSelectAndScatterOp(adaptor.getOperand(),
inferredReturnTypes);
return hlo::inferSelectAndScatterOp(
adaptor.getOperand(), adaptor.getScatter(), inferredReturnTypes);
}

LogicalResult SelectAndScatterOp::verify() {
Expand All @@ -2333,6 +2342,7 @@ LogicalResult ScatterOp::inferReturnTypes(
SmallVectorImpl<Type>& inferredReturnTypes) {
ScatterOp::Adaptor adaptor(operands, attributes, properties, regions);
return hlo::inferScatterOp(location, adaptor.getInputs(),
adaptor.getUpdateComputation(),
inferredReturnTypes);
}

Expand Down
14 changes: 7 additions & 7 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,7 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",
}

def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
[HLO_CompatibleOperandsAndResultType /*all_reduce_c6*/]> {
[InferTensorType /*all_reduce_c6, all_reduce_c7*/]> {
let summary = "AllReduce operation";
let description = [{
Within each process group in the process grid, applies a reduction function
Expand Down Expand Up @@ -1361,8 +1361,7 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",
let hasVerifier = 1;
}

def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter",
[SameOperandsAndResultElementType /*reduce_scatter_c8*/]> {
def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", []> {
let summary = "ReduceScatter operation";
let description = [{
Within each process group in the process grid, performs reduction, using
Expand Down Expand Up @@ -1447,7 +1446,7 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",
def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [
RecursiveMemoryEffects,
SameVariadicOperandSize /*reduce_c3*/,
InferTensorTypeWithReify /*reduce_c7*/,
InferTensorTypeWithReify /*reduce_c7, reduce_c8*/,
SingleBlockImplicitTerminator<"ReturnOp">
]> { /*reduce_c7*/
let summary = "Reduce operation";
Expand Down Expand Up @@ -2512,7 +2511,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [

def StableHLO_ScatterOp: StableHLO_Op<"scatter", [RecursiveMemoryEffects,
SameVariadicOperandSize /*scatter_c5*/,
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16*/]> {
DeclareOpInterfaceMethods<InferTypeOpInterface> /*scatter_c16,
scater_c17*/]> {
let summary = "Scatter operation";
let description = [{
Produces `results` tensors which are equal to `inputs` tensors except that
Expand Down Expand Up @@ -2585,8 +2585,8 @@ def StableHLO_SelectOp: StableHLO_Op<"select", [Pure, HLO_BroadcastingElementwis
}

def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
[DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11*/,
RecursiveMemoryEffects]> {
[DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
select_and_scatter_c12*/, RecursiveMemoryEffects]> {
let summary = "SelectAndScatter operation";
let description = [{
Scatters the values from the `source` tensor using `scatter` based on the
Expand Down
148 changes: 122 additions & 26 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,51 @@ bool tensorsHaveSameElType(Type type1, Type type2,
return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision);
}

unsigned potentiallyComplexBitWidth(Type type) {
auto complexTy = type.dyn_cast<ComplexType>();
return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
: type.getIntOrFloatBitWidth();
}

// Returns true if the element-type of type1 can be promoted to that of type2.
// An element-type 'x' is promotatble to element-type 'y' is they have the same
// base type and bitwidth(x) <= bitwidth(y). When 'x' and 'y' are quantized
// element-types, then promotion is applied only to the 'storage_type'
// component.
bool isPromotableElementType(Type type1, Type type2,
bool ignoreFpPrecision = false) {
auto tensorTy1 = type1.dyn_cast<ShapedType>();
auto tensorTy2 = type2.dyn_cast<ShapedType>();

if (!tensorTy1 || !tensorTy2) return false;

Type tensorEl1 = tensorTy1.getElementType();
Type tensorEl2 = tensorTy2.getElementType();

if (ignoreFpPrecision && tensorEl1.isa<FloatType>() &&
tensorTy2.getElementType().isa<FloatType>())
return true;

bool isSameType =
(tensorEl1.isa<IntegerType>() and tensorEl2.isa<IntegerType>()) ||
(tensorEl1.isa<FloatType>() and tensorEl2.isa<FloatType>()) ||
(tensorEl1.isa<ComplexType>() and tensorEl2.isa<ComplexType>()) ||
(tensorEl1.isa<quant::QuantizedType>() and
tensorEl2.isa<quant::QuantizedType>());

if (!isSameType) return false;

if (!tensorEl1.isa<quant::QuantizedType>())
return potentiallyComplexBitWidth(tensorEl1) <=
potentiallyComplexBitWidth(tensorEl2);

auto quantType1 = tensorEl1.cast<quant::QuantizedType>();
auto quantType2 = tensorEl2.cast<quant::QuantizedType>();
return quantType1.getExpressedType() == quantType2.getExpressedType() &&
potentiallyComplexBitWidth(quantType1.getStorageType()) <=
potentiallyComplexBitWidth(quantType2.getStorageType());
}

// Return true if type1 and type2 are shape-compatible and have same element
// type. If 'ignoreFpPrecision' is True, then allow floats with different
// precisions while checking element-types.
Expand Down Expand Up @@ -405,12 +450,6 @@ SmallVector<int64_t> inferWindowOutputShape(ArrayRef<int64_t> baseShape,
return outputDimensions;
}

unsigned potentiallyComplexBitWidth(Type type) {
auto complexTy = type.dyn_cast<ComplexType>();
return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth()
: type.getIntOrFloatBitWidth();
}

LogicalResult verifyReplicaGroups(std::optional<Location> location,
DenseIntElementsAttr replicaGroups,
bool allGroupsMustHaveSameSize,
Expand Down Expand Up @@ -530,6 +569,17 @@ LogicalResult verifyReduceOpInputsAndInferShape(
return success();
}

// Returns the types of the terminator arguments of the input mlir::Block
// 'block'.
SmallVector<ShapedType> getAccumulatorTypes(Block& block) {
SmallVector<ShapedType> accumulatorSubShapes;
for (Value retOperand : block.getTerminator()->getOperands()) {
auto shapedTy = retOperand.getType().dyn_cast<ShapedType>();
accumulatorSubShapes.push_back(shapedTy);
}
return accumulatorSubShapes;
}

LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,
ArrayRef<ShapedType> inputTypes,
ArrayRef<ShapedType> initValueTypes,
Expand Down Expand Up @@ -598,24 +648,37 @@ LogicalResult verifyReducerShape(std::optional<Location> loc, Block& block,

// all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13,
// reduce_window_i2, scatter_c6, scatter_c15, select_and_scatter_c10
if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
initValueTypes[inputIdx],
/*ignoreFpPrecision=*/true))
if (failed(verifyCompatibleShape(initValueTypes[inputIdx],
accumulatorSubShapes[inputIdx])))
return emitOptionalError(
loc, "The type of reduction-region's result type at index ", inputIdx,
" differs from the op's corresponding init-value type: ",
loc, "The shape of reduction-region's result type at index ",
inputIdx, " differs from the op's corresponding init-value type: ",
accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);

if (!isPromotableElementType(initValueTypes[inputIdx],
accumulatorSubShapes[inputIdx],
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The element-type of reduction-region's result type at index ",
inputIdx,
" is expected to be promotable from the op's corresponding "
"init-value element-type: ",
accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]);

// reduce_c6, reduce_window_c3, scatter_c6, scatter_c15,
// select_and_scatter_c10
if (!tensorsHaveSameElType(
if (!isPromotableElementType(
inputTypes[inputIdx],
block.getArgument(numInputs + inputIdx).getType(), true))
block.getArgument(numInputs + inputIdx).getType(),
/*ignoreFpPrecision=*/true))
return emitOptionalError(
loc, "The element-type of reduction-region's argument at index ",
numInputs + inputIdx, " is expected to be ",
numInputs + inputIdx, " is expected to be promotable from ",
inputTypes[inputIdx].getElementType(), ", but got ",
block.getArgument(numInputs + inputIdx).getType(), " as its type.");
block.getArgument(numInputs + inputIdx)
.getType()
.cast<ShapedType>()
.getElementType());

Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
auto blockArgTensorTy = blockArgType.cast<ShapedType>();
Expand Down Expand Up @@ -1453,6 +1516,17 @@ LogicalResult inferAllToAllOp(
return success();
}

LogicalResult inferAllReduceOp(
std::optional<Location> location, Value operand, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
// all_reduce_c6, all_reduce_c7
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnShapes.emplace_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
return success();
}

LogicalResult inferBatchNormGradOp(
std::optional<Location> location, Value operand, Value scale, Value mean,
Value variance, Value gradOutput, int64_t featureIndex,
Expand Down Expand Up @@ -2554,7 +2628,7 @@ LogicalResult inferRealOp(std::optional<Location>, Value operand,

LogicalResult inferReduceOp(
std::optional<Location> location, TypeRange inputTypes,
TypeRange initValueTypes, DenseIntElementsAttr dimensions,
TypeRange initValueTypes, DenseIntElementsAttr dimensions, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputArgTensorTypes{
llvm::map_range(inputTypes, [](Type t) { return t.cast<ShapedType>(); })};
Expand All @@ -2568,10 +2642,11 @@ LogicalResult inferReduceOp(
initValueTensorTypes, dimensions,
newDimensions, encoding)))
return failure();
// reduce_c2, reduce_c3, reduce_c7
// reduce_c3, reduce_c7, reduce_c8
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) {
ShapedType inputType = inputArgTensorTypes[inputIdx];
Type elementType = inputType.getElementType();
Type elementType = accumulatorTypes[inputIdx].getElementType();
if (inputType.hasRank())
inferredReturnShapes.emplace_back(newDimensions, elementType, encoding);
else
Expand All @@ -2587,7 +2662,7 @@ LogicalResult inferReduceWindowOp(
std::optional<DenseIntElementsAttr> windowStrides,
std::optional<DenseIntElementsAttr> baseDilations,
std::optional<DenseIntElementsAttr> windowDilations,
std::optional<DenseIntElementsAttr> padding,
std::optional<DenseIntElementsAttr> padding, Region& body,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<ShapedType> inputTypes{llvm::map_range(
inputs.getTypes(), [](Type t) { return t.cast<ShapedType>(); })};
Expand All @@ -2604,21 +2679,22 @@ LogicalResult inferReduceWindowOp(
return failure();

// reduce_window_c1, reduce_window_c14...reduce_window_c16
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
for (size_t i = 0; i < inputTypes.size(); ++i) {
auto inputRankedType = inputs[i].getType().dyn_cast<RankedTensorType>();
if (!inputRankedType) {
inferredReturnShapes.emplace_back(inputTypes[i].getElementType());
inferredReturnShapes.emplace_back(accumulatorTypes[i].getElementType());
} else {
auto resultShape =
inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow);
auto inputBounds = encodingToBounds(inputRankedType.getEncoding());
if (inputBounds.empty()) {
inferredReturnShapes.emplace_back(resultShape,
inputTypes[i].getElementType());
accumulatorTypes[i].getElementType());
} else {
auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow);
inferredReturnShapes.emplace_back(
resultShape, inputTypes[i].getElementType(),
resultShape, accumulatorTypes[i].getElementType(),
boundsToEncoding(inputRankedType.getEncoding(), resultBounds));
}
}
Expand Down Expand Up @@ -2683,8 +2759,15 @@ LogicalResult inferRngOp(
}

LogicalResult inferScatterOp(std::optional<Location>, ValueRange inputs,
Region& body,
SmallVectorImpl<Type>& inferredReturnTypes) {
llvm::append_range(inferredReturnTypes, inputs.getTypes());
// scatter_c16, scatter_c17
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) {
auto inputShapedTy = inputs[inputIdx].getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
inputShapedTy, accumulatorTypes[inputIdx].getElementType()));
}
return success();
}

Expand Down Expand Up @@ -2714,9 +2797,12 @@ LogicalResult inferSelectOp(
}

LogicalResult inferSelectAndScatterOp(
Value operand, SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11
inferredReturnTypes.push_back(operand.getType());
Value operand, Region& body, SmallVectorImpl<Type>& inferredReturnTypes) {
// select_and_scatter_c11, select_and_scatter_c12
SmallVector<ShapedType> accumulatorTypes = getAccumulatorTypes(body.front());
auto operandShapedTy = operand.getType().cast<ShapedType>();
inferredReturnTypes.push_back(getSameShapeTensorType(
operandShapedTy, accumulatorTypes[0].getElementType()));
return success();
}

Expand Down Expand Up @@ -3871,6 +3957,16 @@ LogicalResult verifyReduceScatterOp(std::optional<Location> location,
operandType.getDimSize(index), ") and result (",
resultType.getDimSize(index), ")");
}

// reduce_scatter_c9
SmallVector<ShapedType> accumulatorTypes =
getAccumulatorTypes(computation.front());
if (resultType.getElementType() != accumulatorTypes[0].getElementType()) {
return emitOptionalError(location, "result element-type is expected to be ",
accumulatorTypes[0].getElementType(), ", but got ",
resultType.getElementType());
}

return success();
}

Expand Down
Loading

0 comments on commit c668b0e

Please sign in to comment.