From 458b691ea89a862a9d72529583fc922a8f39439b Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 20 Nov 2023 13:14:50 -0600 Subject: [PATCH] Prevent OOB indexing in StableHLO/MHLO ops. (#1846) Backport of https://github.com/tensorflow/mlir-hlo/commit/ff93081bde146359869caff42bf6346383f9eb8b --- BUILD.bazel | 12 +++++++++ .../linalg/transforms/CMakeLists.txt | 17 ++++++++++++ .../linalg/transforms/PassDetail.h | 16 ++++++++++++ .../conversions/linalg/transforms/Passes.h | 16 ++++++++++++ .../conversions/linalg/transforms/Passes.td | 16 ++++++++++++ .../transforms/StablehloLegalizeToLinalg.cpp | 16 ++++++++++++ stablehlo/dialect/TypeInference.cpp | 13 ++++++++-- stablehlo/dialect/TypeInference.h | 2 +- stablehlo/tests/ops_stablehlo.mlir | 26 +++++++++++++++++++ 9 files changed, 131 insertions(+), 3 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index d1b8bd0c891..9bb1e03b266 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -375,12 +375,23 @@ cc_library( ":linalg_pass_inc_gen", ":stablehlo_ops", ":chlo_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BufferizationDialect", + "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgUtils", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorDialect", ], @@ -1426,3 +1437,4 @@ test_suite( "//stablehlo/conversions/linalg/tests:stablehlo_linalg_tests", ], ) + diff --git a/stablehlo/conversions/linalg/transforms/CMakeLists.txt b/stablehlo/conversions/linalg/transforms/CMakeLists.txt index a148f8a22fe..4731573be85 100644 --- a/stablehlo/conversions/linalg/transforms/CMakeLists.txt +++ b/stablehlo/conversions/linalg/transforms/CMakeLists.txt @@ -21,7 +21,24 @@ add_mlir_library(StablehloLinalgTransforms Core LINK_LIBS PUBLIC + MLIRArithDialect + MLIRBufferizationDialect + MLIRComplexDialect + MLIRFuncDialect MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRLinalgUtils + MLIRMathDialect + MLIRMemRefDialect MLIRPass + MLIRPass + MLIRSCFDialect + MLIRShapeDialect + MLIRSparseTensorDialect + MLIRSupport + MLIRTensorDialect + MLIRTransforms MLIRTransforms + MLIRVectorDialect ) diff --git a/stablehlo/conversions/linalg/transforms/PassDetail.h b/stablehlo/conversions/linalg/transforms/PassDetail.h index d7d51094e8d..a9bbb9641ba 100644 --- a/stablehlo/conversions/linalg/transforms/PassDetail.h +++ b/stablehlo/conversions/linalg/transforms/PassDetail.h @@ -1,3 +1,19 @@ +/* Copyright 2022 The IREE Authors + Copyright 2023 OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSDETAIL_H #define STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSDETAIL_H diff --git a/stablehlo/conversions/linalg/transforms/Passes.h b/stablehlo/conversions/linalg/transforms/Passes.h index 3aec8ba3c81..25737bcbb06 100644 --- a/stablehlo/conversions/linalg/transforms/Passes.h +++ b/stablehlo/conversions/linalg/transforms/Passes.h @@ -1,3 +1,19 @@ +/* Copyright 2022 The IREE Authors + Copyright 2023 OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSES_H #define STABLEHLO_CONVERSIONS_LINALG_TRANSFORMS_PASSES_H diff --git a/stablehlo/conversions/linalg/transforms/Passes.td b/stablehlo/conversions/linalg/transforms/Passes.td index c6c27521f45..1bff08accc3 100644 --- a/stablehlo/conversions/linalg/transforms/Passes.td +++ b/stablehlo/conversions/linalg/transforms/Passes.td @@ -1,3 +1,19 @@ +/* Copyright 2022 The IREE Authors + Copyright 2023 OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #ifndef STABLEHLO_TO_LINALG_PASSES #define STABLEHLO_TO_LINALG_PASSES diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 43864861cdd..0a08b8e6ac5 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -1,3 +1,19 @@ +/* Copyright 2022 The IREE Authors + Copyright 2023 OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 038e670625b..7f023191114 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3510,12 +3510,21 @@ LogicalResult verifyDynamicBroadcastInDimOp( } LogicalResult verifyDynamicIotaOp(std::optional location, - Value outputShape, int64_t outputDimension, + Value outputShape, int64_t iotaDimension, Value result) { - if (!isCompatibleForHloTypeInference(outputShape, result.getType())) + auto shape = result.getType().cast(); + if (!isCompatibleForHloTypeInference(outputShape, shape)) return emitOptionalError( location, "output_shape is incompatible with return type of operation ", result.getType()); + + if (!shape.hasRank()) return success(); + + if (iotaDimension >= shape.getRank() || iotaDimension < 0) + return emitOptionalError( + location, + "iota dimension cannot go beyond the output rank or be negative."); + return success(); } diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index 015cae63759..b83573fb759 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -427,7 +427,7 @@ LogicalResult verifyDynamicBroadcastInDimOp( Value result); LogicalResult verifyDynamicIotaOp(std::optional location, - Value outputShape, int64_t outputDimension, + Value outputShape, int64_t iotaDimension, Value result); LogicalResult verifyDynamicPadOp(std::optional location, diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 86d803d2542..2d4875409c6 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -5648,6 +5648,32 @@ func.func @dynamic_iota_unranked() -> tensor<*xf32> { // ----- +func.func @dynamic_iota_unranked_large() -> tensor<*xf32> { + %0 = stablehlo.constant dense<[4]> : tensor<1xi64> + %1 = stablehlo.dynamic_iota %0, dim = 3 : (tensor<1xi64>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> +} + +// ----- + +func.func @dynamic_iota_invalid_iota_dimension_negative() -> tensor { + // expected-error@+2 {{iota dimension cannot go beyond the output rank or be negative}} + %0 = stablehlo.constant dense<[4]> : tensor<1xi64> + %1 = stablehlo.dynamic_iota %0, dim = -1 : (tensor<1xi64>) -> tensor + func.return %1 : tensor +} + +// ----- + +func.func @dynamic_iota_invalid_iota_dimension_too_big() -> tensor { + %0 = stablehlo.constant dense<[4]> : tensor<1xi64> + // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} + %1 = stablehlo.dynamic_iota %0, dim = 2 : (tensor<1xi64>) -> tensor + func.return %1 : tensor +} + +// ----- + func.func @dynamic_iota_output_shape_negative_size() -> tensor<4xf32> { // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<4xf32>'}} %0 = stablehlo.constant dense<[-1]> : tensor<1xi64>