Skip to content

Commit

Permalink
Avoid duplicating extra class declarations (#1990)
Browse files Browse the repository at this point in the history
I noticed some duplication and wondered if it was possible to remove it.
Then I stumbled upon this thread which provided a solution:
https://discourse.llvm.org/t/how-to-add-new-content-to-extraclassdeclaration-in-subclasses/71537

The idea is to declare the common class declaration in the base class
and concatenate it with the child-specific declarations.

In some cases, the field was being overriden only to be set to the same
value as the parent. In such cases, I simply removed the override.
  • Loading branch information
mlevesquedion authored Feb 6, 2024
1 parent 8fbf991 commit aa9a196
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 31 deletions.
13 changes: 3 additions & 10 deletions stablehlo/dialect/ChloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ def CHLO_Dialect : Dialect {

class CHLO_Op<string mnemonic, list<Trait> traits> :
Op<CHLO_Dialect, mnemonic, traits> {
let extraClassDeclaration = [{
string commonClassDeclaration = [{
// Relax the strict default implementation with one that allows
// for StableHLO-specific differences.
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return mlir::hlo::isCompatibleForHloTypeInference(l, r);
}
}];
let extraClassDeclaration = commonClassDeclaration;
}

include "stablehlo/dialect/ChloEnums.td"
Expand Down Expand Up @@ -111,11 +112,6 @@ class CHLO_BroadcastBinaryElementwiseOp<
`(` type($lhs) `,` type($rhs) `)` `->` type(results)
}];

let extraClassDeclaration = [{
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return mlir::hlo::isCompatibleForHloTypeInference(l, r);
}
}];
}

def CHLO_BroadcastAddOp : CHLO_BroadcastBinaryElementwiseOp<"broadcast_add",
Expand Down Expand Up @@ -468,15 +464,12 @@ class CHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits,
$operand attr-dict `:` type($operand) `->` type($result)
}];

let extraClassDeclaration = [{
let extraClassDeclaration = commonClassDeclaration # [{
LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
operands.front(), &reifiedReturnShapes);
}
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return mlir::hlo::isCompatibleForHloTypeInference(l, r);
}
}];
}

Expand Down
27 changes: 6 additions & 21 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ def StableHLO_Dialect : Dialect {

class StableHLO_Op<string mnemonic, list<Trait> traits> :
Op<StableHLO_Dialect, mnemonic, traits> {
let extraClassDeclaration = [{
string commonClassDeclaration = [{
// Relax the strict default implementation with one that allows
// for StableHLO-specific differences.
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return mlir::hlo::isCompatibleForHloTypeInference(l, r);
}
}];
let extraClassDeclaration = commonClassDeclaration;
}

include "stablehlo/dialect/StablehloEnums.td"
Expand All @@ -57,13 +58,7 @@ include "stablehlo/dialect/StablehloTypes.td"

class StableHLO_ShapedInterfaceOp<string mnemonic, list<Trait> traits> :
StableHLO_Op<mnemonic, traits # [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapes"]>]> {
let extraClassDeclaration = [{
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return mlir::hlo::isCompatibleForHloTypeInference(l, r);
}
}];
}
["reifyReturnTypeShapes"]>]> {}

//===----------------------------------------------------------------------===//
// StableHLO nullary op definitions.
Expand Down Expand Up @@ -181,17 +176,14 @@ class StableHLO_UnaryElementwiseOp<string mnemonic, list<Trait> traits,
InferShapedTypeOpInterface, SameOperandsAndResultShape]> {
let arguments = (ins OperandType:$operand);
let results = (outs ResultType:$result);
let extraClassDeclaration = [{
let extraClassDeclaration = commonClassDeclaration # [{
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
operands.front(),
&reifiedReturnShapes);
}
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return mlir::hlo::isCompatibleForHloTypeInference(l, r);
}
}];

let assemblyFormat = [{
Expand Down Expand Up @@ -667,17 +659,14 @@ class StableHLO_BinaryElementwiseOp<string mnemonic, list<Trait> traits,
OperandType:$rhs
);

let extraClassDeclaration = [{
let extraClassDeclaration = commonClassDeclaration # [{
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
operands.front(),
&reifiedReturnShapes);
}
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return mlir::hlo::isCompatibleForHloTypeInference(l, r);
}
}];

let results = (outs ResultType:$result);
Expand Down Expand Up @@ -1275,7 +1264,7 @@ def StableHLO_WhileOp: StableHLO_Op<"while", [

let hasVerifier = 1;

let extraClassDeclaration = [{
let extraClassDeclaration = commonClassDeclaration # [{
// Method of OpAsmOpInterface used during custom printing to name the block
// arguments in the nested regions. We name both the condition and the body
// regions entry arguments the same way, with a `iterArg` prefix. Since the
Expand All @@ -1286,10 +1275,6 @@ def StableHLO_WhileOp: StableHLO_Op<"while", [
for (BlockArgument arg : region.getArguments())
setNameFn(arg, "iterArg");
}

static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return mlir::hlo::isCompatibleForHloTypeInference(l, r);
}
}];
let hasCustomAssemblyFormat = 1;
}
Expand Down

0 comments on commit aa9a196

Please sign in to comment.