Commit
Updates to ConvolutionOP verifier to support quantization constraints (
abhigunj authored Mar 8, 2024
1 parent d978408 commit a8f44d0
Showing 2 changed files with 146 additions and 0 deletions.
74 changes: 74 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
@@ -65,6 +65,25 @@ limitations under the License.

namespace mlir {
namespace hlo {
namespace {
//===----------------------------------------------------------------------===//
// Utils for quantization specific verifications
//===----------------------------------------------------------------------===//
// Returns true if every type in `typeRange` is a shaped type whose element
// type is a quantized type of kind `T`.
template <typename T>
bool allQuantized(ArrayRef<Type> typeRange) {
  return llvm::all_of(typeRange, [&](Type val) {
    return val.cast<ShapedType>().getElementType().isa<T>();
  });
}

// Returns true if none of the types in `typeRange` has an element type that
// is a quantized type of kind `T`.
template <typename T>
bool noneQuantized(ArrayRef<Type> typeRange) {
  return llvm::all_of(typeRange, [&](Type val) {
    return !val.cast<ShapedType>().getElementType().isa<T>();
  });
}

} // namespace

//===----------------------------------------------------------------------===//
// Utils for shape functions.
@@ -3453,6 +3472,61 @@ LogicalResult verifyConvolutionOp(
"is incompatible with return type of operation ",
shapedResultType, "");

  llvm::SmallVector<Type, 3> typeEntries{lhsType, rhsType, resultType};
  if (noneQuantized<quant::QuantizedType>(typeEntries)) return success();
  // convolution_c28
  if (!allQuantized<quant::QuantizedType>(typeEntries)) {
    return emitOptionalError(location,
                             "not all of operands and result are quantized");
  }

  auto lhsQType =
      getElementTypeOrSelf(lhsType).dyn_cast<quant::QuantizedType>();
  auto rhsQType =
      getElementTypeOrSelf(rhsType).dyn_cast<quant::QuantizedType>();
  auto resultQType =
      getElementTypeOrSelf(resultType).dyn_cast<quant::QuantizedType>();
  // convolution_c29
  if (lhsQType.getStorageType() != rhsQType.getStorageType())
    return emitOptionalError(location, "mismatched operand storage types ",
                             lhsQType.getStorageType(), " and ",
                             rhsQType.getStorageType());
  // convolution_c30
  auto expressedType = lhsQType.getExpressedType();
  if (expressedType != rhsQType.getExpressedType() ||
      expressedType != resultQType.getExpressedType())
    return emitOptionalError(location,
                             "mismatched operands and result expressed types");

  llvm::SmallVector<Type, 2> typeEntriesPerAxis{rhsType, resultType};
  if (noneQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis))
    return success();
  // convolution_c31
  if (!allQuantized<quant::UniformQuantizedPerAxisType>(typeEntriesPerAxis)) {
    return emitOptionalError(location,
                             "rhs and result are of mixed per_tensor and "
                             "per_axis quantized tensor type ",
                             rhsType, " and ", resultType);
  }

  auto rhsQPAType = rhsQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
  auto resultQPAType =
      resultQType.dyn_cast<quant::UniformQuantizedPerAxisType>();
  // convolution_c32
  if (rhsQPAType &&
      rhsQPAType.getQuantizedDimension() != kernelOutputFeatureDimension)
    return emitOptionalError(
        location, "mismatched kernel_output_feature_dimension ",
        kernelOutputFeatureDimension, " and rhs quantized dimension ",
        rhsQPAType.getQuantizedDimension());
  // convolution_c33
  if (resultQPAType &&
      resultQPAType.getQuantizedDimension() != outputFeatureDimension)
    return emitOptionalError(location, "mismatched output_feature_dimension ",
                             outputFeatureDimension,
                             " and result quantized dimension ",
                             resultQPAType.getQuantizedDimension());

  return success();
}

72 changes: 72 additions & 0 deletions stablehlo/tests/ops_stablehlo_quantized.mlir
@@ -821,3 +821,75 @@ func.func @illegal_storage_type_for_quantized_element_type(%arg0: tensor<4x!quan
%0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor<4x!quant.uniform<si8:f32, 1.000000e+00>>) -> tensor<4xf32>
func.return %0 : tensor<4xf32>
}

// -----

func.func @convolution_c28(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
  // expected-error@+1 {{not all of operands and result are quantized}}
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207xf32>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
}

// -----

func.func @convolution_c29(%arg0: tensor<1x8x8x207x!quant.uniform<i16:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
  // expected-error@+1 {{mismatched operand storage types 'i16' and 'i8'}}
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207x!quant.uniform<i16:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
}

// -----

func.func @convolution_c30(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f64, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
  // expected-error@+1 {{mismatched operands and result expressed types}}
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207x!quant.uniform<i8:f64, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
}

// -----

func.func @convolution_c31(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
  // expected-error@+1 {{rhs and result are of mixed per_tensor and per_axis quantized tensor type 'tensor<3x3x207x16x!quant.uniform<i8:f32:0, {1.000000e-01:-30}>>' and 'tensor<1x8x8x16x!quant.uniform<i8:f32, 1.000000e+01:50>>'}}
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
}

// -----

func.func @convolution_c32(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>> {
  // expected-error@+1 {{mismatched kernel_output_feature_dimension 3 and rhs quantized dimension 0}}
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:0, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32:0, {0.1:-30}>>
}

// -----

func.func @convolution_c33(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32:3, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>> {
  // expected-error@+1 {{mismatched output_feature_dimension 3 and result quantized dimension 0}}
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32:3, {0.1:-30}>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32:0, {2.0:-30}>>
}
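
For contrast with the negative cases above, the following sketch (not part of this commit) shows a convolution that is expected to satisfy the new quantization constraints c28-c33: lhs, rhs, and result are all per-tensor quantized with the same i8 storage type and f32 expressed type, so the per-axis checks (c31-c33) do not apply. The function name and quantization parameters are illustrative assumptions, not taken from the test file.

func.func @convolution_quantized_ok(%arg0: tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
  // All element types are quantized (c28), both operands use i8 storage (c29),
  // the expressed type is f32 everywhere (c30), and nothing is per-axis
  // quantized, so the per-axis checks (c31-c33) are skipped.
  %0 = stablehlo.convolution(%arg0, %arg1)
    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
    window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
    {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
    (tensor<1x8x8x207x!quant.uniform<i8:f32, 2.0:15>>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
  func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
}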
