Remove the quant-to-math pass limitation of Dot/Conv op result type (#2462)

[Parent PR](#2461)

The `quant-to-math` pass
[assumes](https://github.com/openxla/stablehlo/blob/eba821aa1c54a21d70331d7926dfc8b929f988f3/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp#L984)
that the result type of a quantized `dot_general` or `convolution` always has
`i32` as its storage type. Because of that assumption, the following program,
whose result storage type is `i8`, fails to materialize all the intermediate
converted values:


```
func.func @dot_general_per_tensor_quantization(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<2x3x5x!quant.uniform<i8:f32, 1.0:0>>) -> tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>> {
  // expected-error@+1 {{failed to legalize operation 'stablehlo.dot_general' that was explicitly marked illegal}}
  %0 = "stablehlo.dot_general"(%arg0, %arg1) {
    dot_dimension_numbers = #stablehlo.dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [1],
      rhs_contracting_dimensions = [1]
    >
  } : (tensor<2x3x4x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x3x5x!quant.uniform<i8:f32, 1.0:0>>) -> tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>>
  func.return %0 : tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>>
}
```

One option to fix this is to provide source/target
[materializations](https://mlir.llvm.org/docs/DialectConversion/#type-converter),
but for other ops
([e.g.](https://github.com/openxla/stablehlo/blob/eba821aa1c54a21d70331d7926dfc8b929f988f3/stablehlo/transforms/StablehloLegalizeQuantToInt.cpp#L510))
there is already a precedent of converting the math computed in `i32` back to
the result storage type. This PR implements the missing conversion for
`dot_general` and `convolution`, sketched below.
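
For illustration, here is a minimal sketch of the shape of the lowered IR after
this change. This is not verbatim pass output: the function name
`@lowered_sketch` is made up and the zero-point adjustment math is omitted; the
new test cases below pin down the exact expectations.

```
// Sketch only: operands are lowered to their i8 storage type, the dot is
// accumulated in i32, and a final convert casts the result back to the i8
// result storage type.
func.func @lowered_sketch(%lhs: tensor<2x3x4xi8>, %rhs: tensor<2x3x5xi8>) -> tensor<2x4x5xi8> {
  %acc = "stablehlo.dot_general"(%lhs, %rhs) {
    dot_dimension_numbers = #stablehlo.dot<
      lhs_batching_dimensions = [0],
      rhs_batching_dimensions = [0],
      lhs_contracting_dimensions = [1],
      rhs_contracting_dimensions = [1]
    >
  } : (tensor<2x3x4xi8>, tensor<2x3x5xi8>) -> tensor<2x4x5xi32>
  // The missing piece this PR adds: convert the i32 accumulation back to the
  // result storage type instead of leaving it as i32.
  %res = stablehlo.convert %acc : (tensor<2x4x5xi32>) -> tensor<2x4x5xi8>
  func.return %res : tensor<2x4x5xi8>
}
```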

Note to reviewers: you may focus on just the last commit of the chain; the
rest comes from the parent PR.
sdasgup3 authored Jul 26, 2024
1 parent 0673dd2 commit fbcf294
Showing 2 changed files with 51 additions and 2 deletions.
42 changes: 42 additions & 0 deletions stablehlo/tests/stablehlo_legalize_quant_to_int.mlir
@@ -2633,3 +2633,45 @@ func.func @while_per_tensor_quantization(%arg0: tensor<4x!quant.uniform<i8:f32,
func.return %while : tensor<?x!quant.uniform<i8:f32, 1.0:17>>
}

// -----

// CHECK-LABEL: func.func @dot_general_with_i8_result_element_type
func.func @dot_general_with_i8_result_element_type(%arg0: tensor<2x3x4x!quant.uniform<i8:f32, 1.0:17>>, %arg1: tensor<2x3x5x!quant.uniform<i8:f32, 1.0:0>>) -> tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>> {
// CHECK: stablehlo.dot_general{{.*}} : (tensor<2x3x4xi8>, tensor<2x3x5xi8>) -> tensor<2x4x5xi32>
// CHECK: %[[CONVERT:.*]] = stablehlo.convert {{.*}} : (tensor<2x4x5xi32>) -> tensor<2x4x5xi8>
// CHECK: return %[[CONVERT]] : tensor<2x4x5xi8>
%0 = "stablehlo.dot_general"(%arg0, %arg1) {
dot_dimension_numbers = #stablehlo.dot<
lhs_batching_dimensions = [0],
rhs_batching_dimensions = [0],
lhs_contracting_dimensions = [1],
rhs_contracting_dimensions = [1]
>
} : (tensor<2x3x4x!quant.uniform<i8:f32, 1.0:17>>, tensor<2x3x5x!quant.uniform<i8:f32, 1.0:0>>) -> tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>>
func.return %0 : tensor<2x4x5x!quant.uniform<i8:f32, 1.0:17>>
}

// -----

// CHECK-LABEL: func.func @convolution_with_i8_result_element_type
func.func @convolution_with_i8_result_element_type(
%arg0: tensor<128x28x28x1x!quant.uniform<i8:f32, 2.000000e+00:4>>,
%arg1: tensor<3x3x1x128x!quant.uniform<i8:f32, 3.000000e+00:0>>
) -> tensor<128x26x26x128x!quant.uniform<i8:f32, 1.000000e+00:5>> {
// CHECK: stablehlo.convolution{{.*}} : (tensor<128x28x28x1xi8>, tensor<3x3x1x128xi8>) -> tensor<128x26x26x128xi32>
// CHECK: %[[CONVERT:.*]] = stablehlo.convert {{.*}} : (tensor<128x26x26x128xi32>) -> tensor<128x26x26x128xi8>
// CHECK: return %[[CONVERT]] : tensor<128x26x26x128xi8>
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {
stride = [1, 1], pad = [[0, 0], [0, 0]],
lhs_dilate = [1, 1],
rhs_dilate = [1, 1]
}
{
batch_group_count = 1 : i64,
feature_group_count = 1 : i64
} : (tensor<128x28x28x1x!quant.uniform<i8:f32, 2.000000e+00:4>>, tensor<3x3x1x128x!quant.uniform<i8:f32, 3.000000e+00:0>>)
-> tensor<128x26x26x128x!quant.uniform<i8:f32, 1.000000e+00:5>>
return %0 : tensor<128x26x26x128x!quant.uniform<i8:f32, 1.000000e+00:5>>
}
11 changes: 9 additions & 2 deletions stablehlo/transforms/StablehloLegalizeQuantToMath.cpp
@@ -980,8 +980,15 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor,
     combinedZp = rewriter.create<chlo::BroadcastSubOp>(
         op->getLoc(), resInt32TensorType, combinedZp, zpOffset, nullptr);
   }
-  rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
-      op, resInt32TensorType, resI32, combinedZp, nullptr);
+  Value zpAdded = rewriter.create<chlo::BroadcastAddOp>(
+      op->getLoc(), resInt32TensorType, resI32, combinedZp, nullptr);
+
+  // Convert results back to result storage type.
+  auto resQuantType = getQuantType(getElementTypeOrSelf(op.getResult()));
+  auto resFinalTensorType =
+      resInt32TensorType.clone(getQuantStorageType(*resQuantType));
+  rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, resFinalTensorType,
+                                                    zpAdded);
   return success();
 }

