From a8f44d0becc773e0ec01703235559827bd16f68f Mon Sep 17 00:00:00 2001
From: Abhinav Gunjal
Date: Fri, 8 Mar 2024 11:13:54 -0800
Subject: [PATCH] Updates to ConvolutionOp verifier to support quantization
 constraints (#2079)

---
 stablehlo/dialect/TypeInference.cpp          | 74 ++++++++++++++++++++
 stablehlo/tests/ops_stablehlo_quantized.mlir | 72 +++++++++++++++++++
 2 files changed, 146 insertions(+)

diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp
index 48cfb4d80b0..8ba072fa4de 100644
--- a/stablehlo/dialect/TypeInference.cpp
+++ b/stablehlo/dialect/TypeInference.cpp
@@ -65,6 +65,25 @@ limitations under the License.
 
 namespace mlir {
 namespace hlo {
+namespace {
+//===----------------------------------------------------------------------===//
+// Utils for quantization specific verifications
+//===----------------------------------------------------------------------===//
+template <typename T>
+bool allQuantized(ArrayRef<Type> typeRange) {
+  return llvm::all_of(typeRange, [&](Type val) {
+    return val.cast<ShapedType>().getElementType().isa<T>();
+  });
+}
+
+template <typename T>
+bool noneQuantized(ArrayRef<Type> typeRange) {
+  return llvm::all_of(typeRange, [&](Type val) {
+    return !val.cast<ShapedType>().getElementType().isa<T>();
+  });
+}
+
+}  // namespace
 
 //===----------------------------------------------------------------------===//
 // Utils for shape functions.
@@ -3453,6 +3472,61 @@ LogicalResult verifyConvolutionOp(
                              "is incompatible with return type of operation ",
                              shapedResultType, "");
 
+  llvm::SmallVector<Type> typeEntries{lhsType, rhsType, resultType};
+  if (noneQuantized<quant::QuantizedType>(typeEntries)) return success();
+  // convolution_c28
+  if (!allQuantized<quant::QuantizedType>(typeEntries)) {
+    return emitOptionalError(location,
+                             "not all of operands and result are quantized");
+  }
+
+  auto lhsQType =
+      getElementTypeOrSelf(lhsType).dyn_cast<quant::QuantizedType>();
+  auto rhsQType =
+      getElementTypeOrSelf(rhsType).dyn_cast<quant::QuantizedType>();
+  auto resultQType =
+      getElementTypeOrSelf(resultType).dyn_cast<quant::QuantizedType>();
+  // convolution_c29
+  if (lhsQType.getStorageType() != rhsQType.getStorageType())
+    return emitOptionalError(location, "mismatched operand storage types ",
+                             lhsQType.getStorageType(), " and ",
+                             rhsQType.getStorageType());
+  // convolution_c30
+  auto expressedType = lhsQType.getExpressedType();
+  if (expressedType != rhsQType.getExpressedType() ||
+      expressedType != resultQType.getExpressedType())
+    return emitOptionalError(location,
+                             "mismatched operands and result expressed types");
+
+  llvm::SmallVector<Type> typeEntriesPerAxis{rhsType, resultType};
+  if (noneQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis))
+    return success();
+  // convolution_c31
+  if (!allQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis)) {
+    return emitOptionalError(location,
+                             "rhs and result are of mixed per_tensor and "
+                             "per_axis quantized tensor type ",
+                             rhsType, " and ", resultType);
+  }
+
+  auto rhsQPAType = rhsQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
+  auto resultQPAType =
+      resultQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
+  // convolution_c32
+  if (rhsQPAType &&
+      rhsQPAType.getQuantizedDimension() != kernelOutputFeatureDimension)
+    return emitOptionalError(
+        location, "mismatched kernel_output_feature_dimension ",
+        kernelOutputFeatureDimension, " and rhs quantized dimension ",
+        rhsQPAType.getQuantizedDimension());
+  // convolution_c33
+  if (resultQPAType &&
+      resultQPAType.getQuantizedDimension() != outputFeatureDimension)
+    return emitOptionalError(location, "mismatched output_feature_dimension ",
+                             outputFeatureDimension,
+                             " and result quantized dimension ",
+                             resultQPAType.getQuantizedDimension());
+
   return success();
 }
 
diff --git a/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/tests/ops_stablehlo_quantized.mlir
index cb54cd55e6d..0423afffda9 100644
--- a/stablehlo/tests/ops_stablehlo_quantized.mlir
+++ b/stablehlo/tests/ops_stablehlo_quantized.mlir
@@ -821,3 +821,75 @@ func.func @illegal_storage_type_for_quantized_element_type(%arg0: tensor<4x!quant.uniform>) -> tensor<4xf32> {
   %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor<4x!quant.uniform>) -> tensor<4xf32>
   func.return %0 : tensor<4xf32>
 }
+
+// -----
+
+func.func @convolution_c28(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> {
+  // expected-error@+1 {{not all of operands and result are quantized}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} :
+       (tensor<1x8x8x207xf32>, tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform>
+}
+
+// -----
+
+func.func @convolution_c29(%arg0: tensor<1x8x8x207x!quant.uniform>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> {
+  // expected-error@+1 {{mismatched operand storage types 'i16' and 'i8'}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} :
+       (tensor<1x8x8x207x!quant.uniform>, tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform>
+}
+
+// -----
+
+func.func @convolution_c30(%arg0: tensor<1x8x8x207x!quant.uniform>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> {
+  // expected-error@+1 {{mismatched operands and result expressed types}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} :
+       (tensor<1x8x8x207x!quant.uniform>, tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform>
+}
+
+// -----
+
+func.func @convolution_c31(%arg0: tensor<1x8x8x207x!quant.uniform>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> {
+  // expected-error@+1 {{rhs and result are of mixed per_tensor and per_axis quantized tensor type 'tensor<3x3x207x16x!quant.uniform>' and 'tensor<1x8x8x16x!quant.uniform>'}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} :
+       (tensor<1x8x8x207x!quant.uniform>, tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform>
+}
+
+// -----
+
+func.func @convolution_c32(%arg0: tensor<1x8x8x207x!quant.uniform>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> {
+  // expected-error@+1 {{mismatched kernel_output_feature_dimension 3 and rhs quantized dimension 0}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} :
+       (tensor<1x8x8x207x!quant.uniform>, tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform>
+}
+
+// -----
+
+func.func @convolution_c33(%arg0: tensor<1x8x8x207x!quant.uniform>, %arg1: tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform> {
+  // expected-error@+1 {{mismatched output_feature_dimension 3 and result quantized dimension 0}}
+  %0 = stablehlo.convolution(%arg0, %arg1)
+         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+         window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+         {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} :
+       (tensor<1x8x8x207x!quant.uniform>, tensor<3x3x207x16x!quant.uniform>) -> tensor<1x8x8x16x!quant.uniform>
+  func.return %0 : tensor<1x8x8x16x!quant.uniform>
+}
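
For reference, below is a minimal sketch (not part of the patch) of a convolution the updated verifier should accept, mirroring the shapes used in the negative tests above: lhs, rhs, and result are all per-tensor quantized with the same i8 storage type and f32 expressed type, so convolution_c28-c30 hold and c31-c33 are trivially satisfied because neither rhs nor result is per-axis quantized. The function name and the scale/zero-point values are illustrative assumptions, not values taken from the patch.

// Sketch of a well-formed quantized convolution under the new constraints.
// The quantization parameters (scales and zero points) are illustrative only.
func.func @convolution_quantized_ok(
    %arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>,
    %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>)
    -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
  %0 = stablehlo.convolution(%arg0, %arg1)
         dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
         window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
         {batch_group_count = 1 : i64, feature_group_count = 1 : i64,
          precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
       (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>)
         -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
}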