Skip to content

Commit

Permalink
Prevent OOB indexing in StableHLO/MHLO ops. (#1846)
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK authored Nov 20, 2023
1 parent 202cb2b commit 458b691
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 3 deletions.
12 changes: 12 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,23 @@ cc_library(
":linalg_pass_inc_gen",
":stablehlo_ops",
":chlo_ops",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:LinalgUtils",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:SparseTensorDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorDialect",
],
Expand Down Expand Up @@ -1426,3 +1437,4 @@ test_suite(
"//stablehlo/conversions/linalg/tests:stablehlo_linalg_tests",
],
)

17 changes: 17 additions & 0 deletions stablehlo/conversions/linalg/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,24 @@ add_mlir_library(StablehloLinalgTransforms
Core

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRBufferizationDialect
MLIRComplexDialect
MLIRFuncDialect
MLIRIR
MLIRLinalgDialect
MLIRLinalgTransforms
MLIRLinalgUtils
MLIRMathDialect
MLIRMemRefDialect
MLIRPass
MLIRPass
MLIRSCFDialect
MLIRShapeDialect
MLIRSparseTensorDialect
MLIRSupport
MLIRTensorDialect
MLIRTransforms
MLIRTransforms
MLIRVectorDialect
)
16 changes: 16 additions & 0 deletions stablehlo/conversions/linalg/transforms/PassDetail.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/* Copyright 2022 The IREE Authors
Copyright 2023 OpenXLA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSDETAIL_H
#define STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSDETAIL_H

Expand Down
16 changes: 16 additions & 0 deletions stablehlo/conversions/linalg/transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/* Copyright 2022 The IREE Authors
Copyright 2023 OpenXLA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSES_H
#define STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSES_H

Expand Down
16 changes: 16 additions & 0 deletions stablehlo/conversions/linalg/transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/* Copyright 2022 The IREE Authors
Copyright 2023 OpenXLA Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef STABLEHLO_TO_LINALG_PASSES
#define STABLEHLO_TO_LINALG_PASSES

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
/* Copyright 2022 The IREE Authors
Copyright 2023 OpenXLA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "PassDetail.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
Expand Down
13 changes: 11 additions & 2 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3510,12 +3510,21 @@ LogicalResult verifyDynamicBroadcastInDimOp(
}

LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
Value outputShape, int64_t outputDimension,
Value outputShape, int64_t iotaDimension,
Value result) {
if (!isCompatibleForHloTypeInference(outputShape, result.getType()))
auto shape = result.getType().cast<ShapedType>();
if (!isCompatibleForHloTypeInference(outputShape, shape))
return emitOptionalError(
location, "output_shape is incompatible with return type of operation ",
result.getType());

if (!shape.hasRank()) return success();

if (iotaDimension >= shape.getRank() || iotaDimension < 0)
return emitOptionalError(
location,
"iota dimension cannot go beyond the output rank or be negative.");

return success();
}

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/TypeInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ LogicalResult verifyDynamicBroadcastInDimOp(
Value result);

LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
Value outputShape, int64_t outputDimension,
Value outputShape, int64_t iotaDimension,
Value result);

LogicalResult verifyDynamicPadOp(std::optional<Location> location,
Expand Down
26 changes: 26 additions & 0 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5648,6 +5648,32 @@ func.func @dynamic_iota_unranked() -> tensor<*xf32> {

// -----

func.func @dynamic_iota_unranked_large() -> tensor<*xf32> {
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 3 : (tensor<1xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
}

// -----

func.func @dynamic_iota_invalid_iota_dimension_negative() -> tensor<?xf32> {
// expected-error@+2 {{iota dimension cannot go beyond the output rank or be negative}}
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = -1 : (tensor<1xi64>) -> tensor<?xf32>
func.return %1 : tensor<?xf32>
}

// -----

func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor<?xf32> {
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
// expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}}
%1 = stablehlo.dynamic_iota %0, dim = 2 : (tensor<1xi64>) -> tensor<?xf32>
func.return %1 : tensor<?xf32>
}

// -----

func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> {
// @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}}
%0 = stablehlo.constant dense<[-1]> : tensor<1xi64>
Expand Down

0 comments on commit 458b691

Please sign in to comment.