Skip to content

Commit

Permalink
Add shape refinement pass for DotOp (#2064)
Browse files Browse the repository at this point in the history
`DotOp` currently doesn't have shape refinement, so if downstream
exports stablehlo with dot, and it contains unbounded dynamic shapes
with `stablehlo.dot`, it fails to refine shape.

This change enables refinement, and prevents users from having to
transform `dot` back to `dot_general` whenever canonicalization pass
(for `dot_general` to `dot`) is run implicitly.
  • Loading branch information
ghpvnist authored Mar 5, 2024
1 parent ab9ea3e commit 3269dcc
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 18 deletions.
4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,8 @@ LogicalResult CholeskyOp::inferReturnTypeComponents(
//===----------------------------------------------------------------------===//

LogicalResult DotOp::verify() {
return hlo::verifyDotOp(getLoc(), getLhs(), getRhs(), getPrecisionConfig(),
getResult());
return hlo::verifyDotOp(getLoc(), getLhs().getType(), getRhs().getType(),
getPrecisionConfig(), getResult());
}

// PrecisionConfig - std::optional attribute, print the array as raw enums
Expand Down
18 changes: 6 additions & 12 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,19 +1884,12 @@ LogicalResult inferCreateTokenOp(HloDialectInterface* dialect,
}

LogicalResult inferDotOp(
std::optional<Location> location, Value lhs, Value rhs,
std::optional<ArrayAttr> precisionConfig,
std::optional<Location> location, RankedTensorType lhsType,
RankedTensorType rhsType, std::optional<ArrayAttr> precisionConfig,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
if (failed(verifyPrecisionConfig(location, precisionConfig)))
return failure();

auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
auto rhsType = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhsType || !rhsType) {
inferredReturnShapes.push_back({});
return success();
}

SmallVector<int64_t> dimensions;
if (1 == lhsType.getRank() && 1 == rhsType.getRank() &&
// vector dot vector
Expand Down Expand Up @@ -3403,11 +3396,12 @@ LogicalResult verifyConvolutionOp(
return success();
}

LogicalResult verifyDotOp(std::optional<Location> location, Value lhs,
Value rhs, std::optional<ArrayAttr> precisionConfig,
LogicalResult verifyDotOp(std::optional<Location> location,
RankedTensorType lhsType, RankedTensorType rhsType,
std::optional<ArrayAttr> precisionConfig,
Value result) {
SmallVector<ShapedTypeComponents> inferredReturnShapes;
if (failed(inferDotOp(location, lhs, rhs, precisionConfig,
if (failed(inferDotOp(location, lhsType, rhsType, precisionConfig,
inferredReturnShapes)))
return failure();

Expand Down
9 changes: 5 additions & 4 deletions stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ LogicalResult inferCreateTokenOp(HloDialectInterface* dialect,
SmallVectorImpl<Type>& inferredReturnTypes);

LogicalResult inferDotOp(
std::optional<Location> location, Value lhs, Value rhs,
std::optional<ArrayAttr> precisionConfig,
std::optional<Location> location, RankedTensorType lhsType,
RankedTensorType rhsType, std::optional<ArrayAttr> precisionConfig,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes);

LogicalResult inferDotGeneralOp(
Expand Down Expand Up @@ -412,8 +412,9 @@ LogicalResult verifyConvolutionOp(
int64_t featureGroupCount, int64_t batchGroupCount,
std::optional<ArrayAttr> precisionConfig, Type resultType);

LogicalResult verifyDotOp(std::optional<Location> location, Value lhs,
Value rhs, std::optional<ArrayAttr> precisionConfig,
LogicalResult verifyDotOp(std::optional<Location> location,
RankedTensorType lhsType, RankedTensorType rhsType,
std::optional<ArrayAttr> precisionConfig,
Value result);

LogicalResult verifyDotGeneralOp(std::optional<Location> location, Value lhs,
Expand Down
9 changes: 9 additions & 0 deletions stablehlo/tests/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,15 @@ func.func @refine_dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>

// -----

// CHECK-LABEL: @refine_dot
func.func @refine_dot(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<?x?xf32> {
// CHECK: stablehlo.dot{{.*}} -> tensor<3x5xf32>
%0 = stablehlo.dot %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: @refine_dynamic_broadcast_in_dim
func.func @refine_dynamic_broadcast_in_dim(%arg0: tensor<4xf32>) -> tensor<?x?xf32> {
// CHECK: stablehlo.dynamic_broadcast_in_dim{{.*}} -> tensor<3x4xf32>
Expand Down
14 changes: 14 additions & 0 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,19 @@ struct RefineDotGeneralOpPattern : public OpRewritePattern<DotGeneralOp> {
}
};

struct RefineDotOpPattern : public OpRewritePattern<DotOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(DotOp op,
PatternRewriter& rewriter) const override {
SmallVector<ShapedTypeComponents> inferredReturnShapes;
if (failed(hlo::inferDotOp(
/*location=*/{}, op.getLhs().getType(), op.getRhs().getType(),
op.getPrecisionConfig(), inferredReturnShapes)))
return rewriter.notifyMatchFailure(op, "inferDotOp failed");
return refineReturnTypes(rewriter, op, inferredReturnShapes);
}
};

struct RefineDynamicBroadcastInDimOpPattern
: public OpRewritePattern<DynamicBroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -1179,6 +1192,7 @@ void populateStablehloRefineShapesPatterns(RewritePatternSet* patterns,
patterns->add<RefineConvolutionOpPattern>(context);
patterns->add<RefineCustomCallOpPattern>(context);
patterns->add<RefineDotGeneralOpPattern>(context);
patterns->add<RefineDotOpPattern>(context);
patterns->add<RefineDynamicBroadcastInDimOpPattern>(context);
patterns->add<RefineDynamicConvOpPattern>(context);
patterns->add<RefineDynamicIotaOpPattern>(context);
Expand Down

0 comments on commit 3269dcc

Please sign in to comment.