diff --git a/stablehlo/tests/stablehlo_legalize_quant_to_int.mlir b/stablehlo/tests/stablehlo_legalize_quant_to_int.mlir index 5ea203c9da8..31ed1edf935 100644 --- a/stablehlo/tests/stablehlo_legalize_quant_to_int.mlir +++ b/stablehlo/tests/stablehlo_legalize_quant_to_int.mlir @@ -2633,3 +2633,45 @@ func.func @while_per_tensor_quantization(%arg0: tensor<4x!quant.uniform> } +// ----- + +// CHECK-LABEL: func.func @dot_general_with_i8_result_element_type +func.func @dot_general_with_i8_result_element_type(%arg0: tensor<2x3x4x!quant.uniform>, %arg1: tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> { + // CHECK: stablehlo.dot_general{{.*}} : (tensor<2x3x4xi8>, tensor<2x3x5xi8>) -> tensor<2x4x5xi32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert {{.*}} : (tensor<2x4x5xi32>) -> tensor<2x4x5xi8> + // CHECK: return %[[CONVERT]] : tensor<2x4x5xi8> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [0], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1] + > + } : (tensor<2x3x4x!quant.uniform>, tensor<2x3x5x!quant.uniform>) -> tensor<2x4x5x!quant.uniform> + func.return %0 : tensor<2x4x5x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_with_i8_result_element_type +func.func @convolution_with_i8_result_element_type( + %arg0: tensor<128x28x28x1x!quant.uniform>, + %arg1: tensor<3x3x1x128x!quant.uniform> + ) -> tensor<128x26x26x128x!quant.uniform> { + // CHECK: stablehlo.convolution{{.*}} : (tensor<128x28x28x1xi8>, tensor<3x3x1x128xi8>) -> tensor<128x26x26x128xi32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert {{.*}} : (tensor<128x26x26x128xi32>) -> tensor<128x26x26x128xi8> + // CHECK: return %[[CONVERT]] : tensor<128x26x26x128xi8> + %0 = stablehlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return %0 : tensor<128x26x26x128x!quant.uniform> +} diff --git a/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp b/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp index 510c5bffbd6..207e712d042 100644 --- a/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp +++ b/stablehlo/transforms/StablehloLegalizeQuantToMath.cpp @@ -980,8 +980,15 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor, combinedZp = rewriter.create( op->getLoc(), resInt32TensorType, combinedZp, zpOffset, nullptr); } - rewriter.replaceOpWithNewOp( - op, resInt32TensorType, resI32, combinedZp, nullptr); + Value zpAdded = rewriter.create( + op->getLoc(), resInt32TensorType, resI32, combinedZp, nullptr); + + // Convert results back to result storage type. + auto resQuantType = getQuantType(getElementTypeOrSelf(op.getResult())); + auto resFinalTensorType = + resInt32TensorType.clone(getQuantStorageType(*resQuantType)); + rewriter.replaceOpWithNewOp(op, resFinalTensorType, + zpAdded); return success(); }