diff --git a/stablehlo/dialect/ChloOps.td b/stablehlo/dialect/ChloOps.td index 5e3385618f0..bbc21620131 100644 --- a/stablehlo/dialect/ChloOps.td +++ b/stablehlo/dialect/ChloOps.td @@ -59,13 +59,14 @@ def CHLO_Dialect : Dialect { class CHLO_Op traits> : Op { - let extraClassDeclaration = [{ + string commonClassDeclaration = [{ // Relax the strict default implementation with one that allows // for StableHLO-specific differences. static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { return mlir::hlo::isCompatibleForHloTypeInference(l, r); } }]; + let extraClassDeclaration = commonClassDeclaration; } include "stablehlo/dialect/ChloEnums.td" @@ -111,11 +112,6 @@ class CHLO_BroadcastBinaryElementwiseOp< `(` type($lhs) `,` type($rhs) `)` `->` type(results) }]; - let extraClassDeclaration = [{ - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { - return mlir::hlo::isCompatibleForHloTypeInference(l, r); - } - }]; } def CHLO_BroadcastAddOp : CHLO_BroadcastBinaryElementwiseOp<"broadcast_add", @@ -468,15 +464,12 @@ class CHLO_UnaryElementwiseOp traits, $operand attr-dict `:` type($operand) `->` type($result) }]; - let extraClassDeclaration = [{ + let extraClassDeclaration = commonClassDeclaration # [{ LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(), operands.front(), &reifiedReturnShapes); } - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { - return mlir::hlo::isCompatibleForHloTypeInference(l, r); - } }]; } diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 8c5c7dc7369..30d561da32c 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -42,13 +42,14 @@ def StableHLO_Dialect : Dialect { class StableHLO_Op traits> : Op { - let extraClassDeclaration = [{ + string commonClassDeclaration = [{ // Relax the strict default implementation with one that allows // for StableHLO-specific differences. static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { return mlir::hlo::isCompatibleForHloTypeInference(l, r); } }]; + let extraClassDeclaration = commonClassDeclaration; } include "stablehlo/dialect/StablehloEnums.td" @@ -57,13 +58,7 @@ include "stablehlo/dialect/StablehloTypes.td" class StableHLO_ShapedInterfaceOp traits> : StableHLO_Op]> { - let extraClassDeclaration = [{ - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { - return mlir::hlo::isCompatibleForHloTypeInference(l, r); - } - }]; -} + ["reifyReturnTypeShapes"]>]> {} //===----------------------------------------------------------------------===// // StableHLO nullary op definitions. @@ -181,7 +176,7 @@ class StableHLO_UnaryElementwiseOp traits, InferShapedTypeOpInterface, SameOperandsAndResultShape]> { let arguments = (ins OperandType:$operand); let results = (outs ResultType:$result); - let extraClassDeclaration = [{ + let extraClassDeclaration = commonClassDeclaration # [{ LogicalResult reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { @@ -189,9 +184,6 @@ class StableHLO_UnaryElementwiseOp traits, operands.front(), &reifiedReturnShapes); } - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { - return mlir::hlo::isCompatibleForHloTypeInference(l, r); - } }]; let assemblyFormat = [{ @@ -667,7 +659,7 @@ class StableHLO_BinaryElementwiseOp traits, OperandType:$rhs ); - let extraClassDeclaration = [{ + let extraClassDeclaration = commonClassDeclaration # [{ LogicalResult reifyReturnTypeShapes( OpBuilder& builder, ValueRange operands, SmallVectorImpl& reifiedReturnShapes) { @@ -675,9 +667,6 @@ class StableHLO_BinaryElementwiseOp traits, operands.front(), &reifiedReturnShapes); } - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { - return mlir::hlo::isCompatibleForHloTypeInference(l, r); - } }]; let results = (outs ResultType:$result); @@ -1275,7 +1264,7 @@ def StableHLO_WhileOp: StableHLO_Op<"while", [ let hasVerifier = 1; - let extraClassDeclaration = [{ + let extraClassDeclaration = commonClassDeclaration # [{ // Method of OpAsmOpInterface used during custom printing to name the block // arguments in the nested regions. We name both the condition and the body // regions entry arguments the same way, with a `iterArg` prefix. Since the @@ -1286,10 +1275,6 @@ def StableHLO_WhileOp: StableHLO_Op<"while", [ for (BlockArgument arg : region.getArguments()) setNameFn(arg, "iterArg"); } - - static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { - return mlir::hlo::isCompatibleForHloTypeInference(l, r); - } }]; let hasCustomAssemblyFormat = 1; }