From 3269dccdf7b3fba96172e28d0a64a9dca8197122 Mon Sep 17 00:00:00 2001 From: Gunhyun Park Date: Tue, 5 Mar 2024 11:50:59 -0800 Subject: [PATCH] Add shape refinement pass for DotOp (#2064) `DotOp` currently has no shape refinement, so when a downstream project exports StableHLO that contains `stablehlo.dot` with unbounded dynamic shapes, the shape of that op fails to be refined. This change enables refinement for `DotOp`, so users no longer have to transform `dot` back to `dot_general` whenever the canonicalization pass (which rewrites `dot_general` to `dot`) is run implicitly. --- stablehlo/dialect/StablehloOps.cpp | 4 ++-- stablehlo/dialect/TypeInference.cpp | 18 ++++++------------ stablehlo/dialect/TypeInference.h | 9 +++++---- stablehlo/tests/stablehlo_refine_shapes.mlir | 9 +++++++++ stablehlo/transforms/StablehloRefineShapes.cpp | 14 ++++++++++++++ 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index fc6f01c081f..a88f993fb5a 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -481,8 +481,8 @@ LogicalResult CholeskyOp::inferReturnTypeComponents( //===----------------------------------------------------------------------===// LogicalResult DotOp::verify() { - return hlo::verifyDotOp(getLoc(), getLhs(), getRhs(), getPrecisionConfig(), - getResult()); + return hlo::verifyDotOp(getLoc(), getLhs().getType(), getRhs().getType(), + getPrecisionConfig(), getResult()); } // PrecisionConfig - std::optional attribute, print the array as raw enums diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 46158aaa8d8..f7915996d7c 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -1884,19 +1884,12 @@ LogicalResult inferCreateTokenOp(HloDialectInterface* dialect, } LogicalResult inferDotOp( - std::optional location, Value lhs, Value rhs, - std::optional precisionConfig, + std::optional location, 
RankedTensorType lhsType, + RankedTensorType rhsType, std::optional precisionConfig, SmallVectorImpl& inferredReturnShapes) { if (failed(verifyPrecisionConfig(location, precisionConfig))) return failure(); - auto lhsType = lhs.getType().dyn_cast(); - auto rhsType = rhs.getType().dyn_cast(); - if (!lhsType || !rhsType) { - inferredReturnShapes.push_back({}); - return success(); - } - SmallVector dimensions; if (1 == lhsType.getRank() && 1 == rhsType.getRank() && // vector dot vector @@ -3403,11 +3396,12 @@ LogicalResult verifyConvolutionOp( return success(); } -LogicalResult verifyDotOp(std::optional location, Value lhs, - Value rhs, std::optional precisionConfig, +LogicalResult verifyDotOp(std::optional location, + RankedTensorType lhsType, RankedTensorType rhsType, + std::optional precisionConfig, Value result) { SmallVector inferredReturnShapes; - if (failed(inferDotOp(location, lhs, rhs, precisionConfig, + if (failed(inferDotOp(location, lhsType, rhsType, precisionConfig, inferredReturnShapes))) return failure(); diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index fa0123e3d19..7f9d7315032 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -194,8 +194,8 @@ LogicalResult inferCreateTokenOp(HloDialectInterface* dialect, SmallVectorImpl& inferredReturnTypes); LogicalResult inferDotOp( - std::optional location, Value lhs, Value rhs, - std::optional precisionConfig, + std::optional location, RankedTensorType lhsType, + RankedTensorType rhsType, std::optional precisionConfig, SmallVectorImpl& inferredReturnShapes); LogicalResult inferDotGeneralOp( @@ -412,8 +412,9 @@ LogicalResult verifyConvolutionOp( int64_t featureGroupCount, int64_t batchGroupCount, std::optional precisionConfig, Type resultType); -LogicalResult verifyDotOp(std::optional location, Value lhs, - Value rhs, std::optional precisionConfig, +LogicalResult verifyDotOp(std::optional location, + RankedTensorType lhsType, 
RankedTensorType rhsType, + std::optional precisionConfig, Value result); LogicalResult verifyDotGeneralOp(std::optional location, Value lhs, diff --git a/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/tests/stablehlo_refine_shapes.mlir index 23c9a402344..44c215b1019 100644 --- a/stablehlo/tests/stablehlo_refine_shapes.mlir +++ b/stablehlo/tests/stablehlo_refine_shapes.mlir @@ -562,6 +562,15 @@ func.func @refine_dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32> // ----- +// CHECK-LABEL: @refine_dot +func.func @refine_dot(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor { + // CHECK: stablehlo.dot{{.*}} -> tensor<3x5xf32> + %0 = stablehlo.dot %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: @refine_dynamic_broadcast_in_dim func.func @refine_dynamic_broadcast_in_dim(%arg0: tensor<4xf32>) -> tensor { // CHECK: stablehlo.dynamic_broadcast_in_dim{{.*}} -> tensor<3x4xf32> diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index b9921ae6856..178753c9001 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -740,6 +740,19 @@ struct RefineDotGeneralOpPattern : public OpRewritePattern { } }; +struct RefineDotOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DotOp op, + PatternRewriter& rewriter) const override { + SmallVector inferredReturnShapes; + if (failed(hlo::inferDotOp( + /*location=*/{}, op.getLhs().getType(), op.getRhs().getType(), + op.getPrecisionConfig(), inferredReturnShapes))) + return rewriter.notifyMatchFailure(op, "inferDotOp failed"); + return refineReturnTypes(rewriter, op, inferredReturnShapes); + } +}; + struct RefineDynamicBroadcastInDimOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1179,6 +1192,7 @@ void 
populateStablehloRefineShapesPatterns(RewritePatternSet* patterns, patterns->add(context); patterns->add(context); patterns->add(context); + patterns->add(context); patterns->add(context); patterns->add(context); patterns->add(context);