Skip to content

Commit

Permalink
compatibility constraint checks
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Dec 15, 2023
1 parent 901073b commit cd14e1a
Show file tree
Hide file tree
Showing 10 changed files with 2,788 additions and 9 deletions.
2 changes: 1 addition & 1 deletion stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ bool isPromotableElementType(Type type1, Type type2,
Type tensorEl2 = tensorTy2.getElementType();

if (ignoreFpPrecision && tensorEl1.isa<FloatType>() &&
tensorTy2.getElementType().isa<FloatType>())
tensorEl2.isa<FloatType>())
return true;

bool isSameType =
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(0, 16, 3); }
static Version getCurrentVersion() { return Version(0, 17, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def VHLO_Dialect : Dialect {
0.14.0: MLIR bytecode version 3 => 5 (revised to 4 in #1827).
0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO.
0.16.0: Introduce `collective_broadcast` operation.
0.17.0: Allow reduce operations to promote to higher bitwidth.
}];

let useDefaultAttributePrinterParser = 0;
Expand Down
96 changes: 96 additions & 0 deletions stablehlo/dialect/VhloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
Expand All @@ -37,6 +38,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h"
#include "stablehlo/dialect/AssemblyFormat.h"
#include "stablehlo/dialect/VhloBytecode.h"
#include "stablehlo/dialect/VhloTypes.h"

namespace mlir {
namespace vhlo {
Expand Down Expand Up @@ -296,5 +298,99 @@ void VhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const {
assert(succeeded(result));
}

///////////////////////////
// Op Constraint Versioning
///////////////////////////
// These could be migrated to ODS in VhloOps.td if we figured out a better way
// to represent this sort of constraint in tablegen.

namespace {
bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes,
TypeRange resultTypes) {
SmallVector<ShapedType> inputShapedTypes{
llvm::map_range(operandTypes, [](Type t) {
return convertTypeToBuiltinForPrint(t).cast<ShapedType>();
})};
SmallVector<ShapedType> resultShapedTypes{
llvm::map_range(resultTypes, [](Type t) {
return convertTypeToBuiltinForPrint(t).cast<ShapedType>();
})};

int64_t numInputs = inputShapedTypes.size();
for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
if (inputShapedTypes[inputIdx].getElementType() !=
resultShapedTypes[inputIdx].getElementType())
return true;
}
return false;
}
} // namespace

LogicalResult AllReduceOpV1::validateConstraint(mlir::Operation* op,
Version targetVersion) {
// Allow mismatched operand and result types in v0.17.0
if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(),
getResult().getType()) &&
targetVersion < Version(0, 17, 0))
return failure();

return success();
}

LogicalResult ReduceOpV1::validateConstraint(mlir::Operation* op,
Version targetVersion) {
// Allow mismatched operand and result types in v0.17.0
if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(),
getResultTypes()) &&
targetVersion < Version(0, 17, 0))
return failure();

return success();
}

LogicalResult ReduceScatterOpV1::validateConstraint(mlir::Operation* op,
Version targetVersion) {
// Allow mismatched operand and result types in v0.17.0
if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(),
getResult().getType()) &&
targetVersion < Version(0, 17, 0))
return failure();

return success();
}

LogicalResult ReduceWindowOpV1::validateConstraint(mlir::Operation* op,
Version targetVersion) {
// Allow mismatched operand and result types in v0.17.0
if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(),
getResultTypes()) &&
targetVersion < Version(0, 17, 0))
return failure();

return success();
}

LogicalResult ScatterOpV1::validateConstraint(mlir::Operation* op,
Version targetVersion) {
// Allow mismatched operand and result types in v0.17.0
if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(),
getResultTypes()) &&
targetVersion < Version(0, 17, 0))
return failure();

return success();
}

LogicalResult SelectAndScatterOpV1::validateConstraint(mlir::Operation* op,
Version targetVersion) {
// Allow mismatched operand and result types in v0.17.0
if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(),
getResult().getType()) &&
targetVersion < Version(0, 17, 0))
return failure();

return success();
}

} // namespace vhlo
} // namespace mlir
29 changes: 23 additions & 6 deletions stablehlo/dialect/VhloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ def VHLO_VersionedOpInterface : OpInterface<"VersionedOpInterface"> {
];
}

def VHLO_VersionedOpConstraintInterface : OpInterface<"VersionedOpConstraintInterface"> {
let cppNamespace = "::mlir::vhlo";
let methods = [
InterfaceMethod<
[{Validate versioned constraints on a versioned op.
Used if the spec'ed constraints of an op change over time.}],
"mlir::LogicalResult", "validateConstraint",
(ins "mlir::Operation*":$op, "mlir::vhlo::Version":$targetVersion)>,
];
}

class VHLO_Op<string mnemonic, string minVersion, string maxVersion, list<Trait> traits = []> :
Op<VHLO_Dialect, mnemonic,
[DeclareOpInterfaceMethods<VHLO_VersionedOpInterface>] # traits> {
Expand Down Expand Up @@ -92,7 +103,8 @@ def VHLO_AllGatherOpV1 : VHLO_Op<"all_gather_v1", "0.9.0", "current"> {
let results = (outs VHLO_AnyType:$result);
}

def VHLO_AllReduceOpV1 : VHLO_Op<"all_reduce_v1", "0.9.0", "current"> {
def VHLO_AllReduceOpV1 : VHLO_Op<"all_reduce_v1", "0.9.0", "current",
[DeclareOpInterfaceMethods<VHLO_VersionedOpConstraintInterface>]> {
let arguments = (ins
VHLO_AnyType:$operand,
VHLO_AnyAttr:$replica_groups,
Expand Down Expand Up @@ -754,7 +766,8 @@ def VHLO_RecvOpV1 : VHLO_Op<"recv_v1", "0.9.0", "current"> {
let results = (outs Variadic<VHLO_AnyType>:$results);
}

def VHLO_ReduceOpV1 : VHLO_Op<"reduce_v1", "0.9.0", "current", [SameVariadicOperandSize]> {
def VHLO_ReduceOpV1 : VHLO_Op<"reduce_v1", "0.9.0", "current",
[SameVariadicOperandSize, DeclareOpInterfaceMethods<VHLO_VersionedOpConstraintInterface>]> {
let arguments = (ins
Variadic<VHLO_AnyType>:$inputs,
Variadic<VHLO_AnyType>:$init_values,
Expand All @@ -773,7 +786,8 @@ def VHLO_ReducePrecisionOpV1 : VHLO_Op<"reduce_precision_v1", "0.9.0", "current"
let results = (outs VHLO_AnyType:$output);
}

def VHLO_ReduceScatterOpV1 : VHLO_Op<"reduce_scatter_v1", "0.9.0", "current"> {
def VHLO_ReduceScatterOpV1 : VHLO_Op<"reduce_scatter_v1", "0.9.0", "current",
[DeclareOpInterfaceMethods<VHLO_VersionedOpConstraintInterface>]> {
let arguments = (ins
VHLO_AnyType:$operand,
VHLO_AnyAttr:$scatter_dimension,
Expand All @@ -785,7 +799,8 @@ def VHLO_ReduceScatterOpV1 : VHLO_Op<"reduce_scatter_v1", "0.9.0", "current"> {
let results = (outs VHLO_AnyType:$result);
}

def VHLO_ReduceWindowOpV1 : VHLO_Op<"reduce_window_v1", "0.9.0", "current", [SameVariadicOperandSize]> {
def VHLO_ReduceWindowOpV1 : VHLO_Op<"reduce_window_v1", "0.9.0", "current",
[SameVariadicOperandSize, DeclareOpInterfaceMethods<VHLO_VersionedOpConstraintInterface>]> {
let arguments = (ins
Variadic<VHLO_AnyType>:$inputs,
Variadic<VHLO_AnyType>:$init_values,
Expand Down Expand Up @@ -864,7 +879,8 @@ def VHLO_RsqrtOpV1 : VHLO_Op<"rsqrt_v1", "0.9.0", "current"> {
let results = (outs VHLO_AnyType:$result);
}

def VHLO_ScatterOpV1 : VHLO_Op<"scatter_v1", "0.9.0", "current", [SameVariadicOperandSize]> {
def VHLO_ScatterOpV1 : VHLO_Op<"scatter_v1", "0.9.0", "current",
[SameVariadicOperandSize, DeclareOpInterfaceMethods<VHLO_VersionedOpConstraintInterface>]> {
let arguments = (ins
Variadic<VHLO_AnyType>:$inputs,
VHLO_AnyType:$scatter_indices,
Expand All @@ -880,7 +896,8 @@ def VHLO_ScatterOpV1 : VHLO_Op<"scatter_v1", "0.9.0", "current", [SameVariadicOp
let results = (outs Variadic<VHLO_AnyType>:$results);
}

def VHLO_SelectAndScatterOpV1 : VHLO_Op<"select_and_scatter_v1", "0.9.0", "current"> {
def VHLO_SelectAndScatterOpV1 : VHLO_Op<"select_and_scatter_v1", "0.9.0", "current",
[DeclareOpInterfaceMethods<VHLO_VersionedOpConstraintInterface>]> {
let arguments = (ins
VHLO_AnyType:$operand,
VHLO_AnyType:$source,
Expand Down
Loading

0 comments on commit cd14e1a

Please sign in to comment.