Add interpreter for ConvolutionOp #1314
Conversation
stablehlo/reference/Ops.cpp
```
SmallVector<Tensor> results;
for (auto [left, right] : llvm::zip(lhses, rhses)) {
  SmallVector<ShapedTypeComponents> inferredConvolutionType;
  auto convolutionStatus = hlo::inferConvolutionOp(
```
With the use of `inferConvolutionOp` and its interface (stablehlo/dialect/TypeInference.cpp, line 1733 in 63b6f5a: `LogicalResult inferConvolutionOp(`), `evalConvolutionOp` could share the same signature. The main motivation is: we are using a lot of boilerplate code to convert attributes from one type to another just to call `infer*Ops`. All of this boilerplate can be removed if we have `evalConvolutionOp` share the same interface with `inferConv*Op`. This will improve the readability.
Benefits:
- We can get rid of all the wrap/unwrap code for the various attributes (no `flattenPad`).
- The call site of `evalConvolutionOp` in `eval`, where we have the loop over the ops, would be simplified as well.
What do you think?
I think your overall idea makes sense. But the counterargument would be that it comes at the cost of adding boilerplate code to create default parameters under `eval`, and it would be an exception to the related issue #1031 of unwrapping MLIR-based classes out of `inferFooOps`. Let's hear from @burmako before I continue with this change.
> it would be an exception to the related issue #1031 of unwrapping MLIR-based classes out of `inferFooOps`
All we need here is a common interface for `inferOp*` and `evalOp*`, so as to remove the conversion code. I am perfectly fine with modifying the `inferOp*` interfaces, as we did earlier as well.
Sure, let's hear from Eugene first.
stablehlo/reference/Ops.cpp
```
  }
  return evalConcatenateOp(results, outputFeatureDimension, result.getType());
}
auto lhsWindowDimensions = concatAndPermute(
```
As per the spec:

* `lhs_shape(n, hw, c) = permute([n] + hw + [c], [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension])`.
* `result_shape(n1, hw, c1) = permute([n1] + hw + [c1], [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension])`.

All instances of `concatAndPermute` use only permutation orders. We can define them once, like the following, and pass them into `concatAndPermute`:

```
auto lhsPermutation = inputBatchDimension + inputSpatialDimensions + inputFeatureDimension;
```
> `auto lhsPermutation = inputBatchDimension + inputSpatialDimensions + inputFeatureDimension;`
Since we can't quite use the syntax sugar mentioned above, I moved the computation for permutation out of the helper function. This saves some compute and is also closer to the spec.
Yes, for that we would need to add an overloaded `operator+` in `Axes`? Let's get @burmako's opinion on this. Keeping this unresolved for now.
Right. The only problem is that the way we would overload `operator+` would be different in `Axes.h` and `Sizes.h`. One option is to write a `concat` function, but this does not necessarily make the code as simple as using `operator+`.
Finished adding my remaining set of comments! Sorry for taking a bit longer on the review. It took me some time to get the overall perspective. Also, I can see how much attention and hard work you have put in to get the implementation working. Thanks @ghpvnist!
This is part 3 of #1964 to implement the remaining parts of #1314. One notable change in TypeInference.cpp is (C27), whose verification differs depending on whether the element type is quantized.

We have the following constraints in the spec (excluding quantization-related constraints C28-C33):

```
(I1) `lhs` tensor.
(I2) `rhs` tensor.
(I3) `window_strides` 1-dimensional tensor constant of type `si64`.
(I4) `padding` 2-dimensional tensor constant of type `si64`.
(I5) `lhs_dilation` 1-dimensional tensor constant of type `si64`.
(I6) `rhs_dilation` 1-dimensional tensor constant of type `si64`.
(I7) `window_reversal` 1-dimensional tensor constant of type `i1`.
(I8) `input_batch_dimension` constant of type `si64`.
(I9) `input_feature_dimension` constant of type `si64`.
(I10) `input_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I11) `kernel_input_feature_dimension` constant of type `si64`.
(I12) `kernel_output_feature_dimension` constant of type `si64`.
(I13) `kernel_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I14) `output_batch_dimension` constant of type `si64`.
(I15) `output_feature_dimension` constant of type `si64`.
(I16) `output_spatial_dimensions` 1-dimensional tensor constant of type `si64`.
(I17) `feature_group_count` constant of type `si64`.
(I18) `batch_group_count` constant of type `si64`.
(I19) `precision_config` variadic number of enums of `DEFAULT`, `HIGH`, and `HIGHEST`.
(C1) `N = rank(lhs) = rank(rhs)`.
(C2) `size(window_strides) = N - 2`.
(C3) `0 < window_strides`.
(C4) `shape(padding) = [N - 2, 2]`.
(C5) `size(lhs_dilation) = N - 2`.
(C6) `0 < lhs_dilation`.
(C7) `size(rhs_dilation) = N - 2`.
(C8) `0 < rhs_dilation`.
(C9) `size(window_reversal) = N - 2`.
(C10) `dim(lhs, input_batch_dimension) % batch_group_count = 0`.
(C11) `dim(lhs, input_feature_dimension) % feature_group_count = 0`.
(C12) `size(input_spatial_dimensions) = N - 2`.
(C13) Given `input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]`:
  * `is_unique(input_dimensions)`.
  * `0 <= input_dimensions < N`.
(C14) `dim(rhs, kernel_input_feature_dimension) = dim(lhs, input_feature_dimension) / feature_group_count`.
(C15) `dim(rhs, kernel_output_feature_dimension) % batch_group_count = 0`.
(C16) `dim(rhs, kernel_output_feature_dimension) % feature_group_count = 0`.
(C17) `size(kernel_spatial_dimensions) = N - 2`.
(C18) Given `kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
  * `is_unique(kernel_dimensions)`.
  * `0 <= kernel_dimensions < N`.
(C19) `size(output_spatial_dimensions) = N - 2`.
(C20) Given `output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]`:
  * `is_unique(output_dimensions)`.
  * `0 <= output_dimensions < N`.
(C21) `0 < feature_group_count`.
(C22) `0 < batch_group_count`.
(C23) `feature_group_count = 1 or batch_group_count = 1`.
(C24) `size(precision_config) = 2`.
(C25) `dim(result, result_dim)` is defined as:
  * `dim(lhs, input_batch_dimension) / batch_group_count` if `result_dim = output_batch_dimension`.
  * `dim(rhs, kernel_output_feature_dimension)` if `result_dim = output_feature_dimension`.
  * `num_windows` otherwise, where:
    * `output_spatial_dimensions[spatial_dim] = result_dim`.
    * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
    * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
    * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) = 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
    * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
    * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) = 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
    * `is_empty_window[lhs_dim] = padded_input_shape[lhs_dim] = 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]`.
    * `num_windows = is_empty_window[lhs_dim] ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
(C26) `rank(result) = N`.
(C27) `element_type(lhs) = element_type(rhs) = element_type(result)`.
```

These constraints will be comprehensively covered by the following tests:

```
I1: a) `lhs` tensor. (Covered by ODS).
I2: a) `rhs` tensor. (Covered by ODS).
I3: a) `window_strides` is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(`window_strides`) != `si64`. (Covered by ODS).
I4: a) `padding` is not a 2-dimensional tensor.
    b) element_type(`padding`) != `si64`. (Covered by ODS).
I5: a) `lhs_dilation` is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(`lhs_dilation`) != `si64`. (Covered by ODS).
I6: a) `rhs_dilation` is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(`rhs_dilation`) != `si64`. (Covered by ODS).
I7: a) `window_reversal` is not a 1-dimensional tensor. (Covered by ODS).
    b) element_type(`window_reversal`) != `i1`. (Covered by ODS).
I8: a) element_type(`input_batch_dimension`) != `si64`. (Covered by ODS).
I9: a) element_type(`input_feature_dimension`) != `si64`. (Covered by ODS).
I10: a) `input_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`input_spatial_dimensions`) != `si64`. (Covered by ODS).
I11: a) element_type(`kernel_input_feature_dimension`) != `si64`. (Covered by ODS).
I12: a) element_type(`kernel_output_feature_dimension`) != `si64`. (Covered by ODS).
I13: a) `kernel_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`kernel_spatial_dimensions`) != `si64`. (Covered by ODS).
I14: a) element_type(`output_batch_dimension`) != `si64`. (Covered by ODS).
I15: a) element_type(`output_feature_dimension`) != `si64`. (Covered by ODS).
I16: a) `output_spatial_dimensions` is not a 1-dimensional tensor. (Covered by ODS).
     b) element_type(`output_spatial_dimensions`) != `si64`. (Covered by ODS).
I17: a) element_type(`feature_group_count`) != `si64`. (Covered by ODS).
I18: a) element_type(`batch_group_count`) != `si64`. (Covered by ODS).
I19: a) `precision_config` does not have variadic number of enums of `DEFAULT`, `HIGH`, and `HIGHEST`. (Covered by ODS).
C1: a) N = rank(`lhs`) != rank(`rhs`).
C2: a) size(`window_strides`) != N - 2.
C3: a) `window_strides[i]` <= 0 for any i in [0, size(`window_strides`)).
C4: a) dim(`padding`, 0) != N - 2.
    b) dim(`padding`, 1) != 2.
C5: a) size(`lhs_dilation`) != N - 2.
C6: a) `lhs_dilation[i]` <= 0 for any i in [0, size(`lhs_dilation`)).
C7: a) size(`rhs_dilation`) != N - 2.
C8: a) `rhs_dilation[i]` <= 0 for any i in [0, size(`rhs_dilation`)).
C9: a) size(`window_reversal`) != N - 2.
C10: a) `dim(lhs, input_batch_dimension) % batch_group_count != 0`.
C11: a) `dim(lhs, input_feature_dimension) % feature_group_count != 0`.
C12: a) size(`input_spatial_dimensions`) != N - 2.
C13: a) Given `input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]`:
        * Any dimensions in `input_dimensions` are not unique.
     b) Given `input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]`:
        * For any i in `input_dimensions`, i < 0.
     c) Given `input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]`:
        * For any i in `input_dimensions`, i >= N.
C14: a) `dim(rhs, kernel_input_feature_dimension) != dim(lhs, input_feature_dimension) / feature_group_count`.
C15: a) `dim(rhs, kernel_output_feature_dimension) % batch_group_count != 0`.
C16: a) `dim(rhs, kernel_output_feature_dimension) % feature_group_count != 0`.
C17: a) size(`kernel_spatial_dimensions`) != N - 2.
C18: a) Given `kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
        * Any dimensions in `kernel_dimensions` are not unique.
     b) Given `kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
        * For any i in `kernel_dimensions`, i < 0.
     c) Given `kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
        * For any i in `kernel_dimensions`, i >= N.
C19: a) size(`output_spatial_dimensions`) != N - 2.
C20: a) Given `output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]`:
        * Any dimensions in `output_dimensions` are not unique.
     b) Given `output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]`:
        * For any i in `output_dimensions`, i < 0.
     c) Given `output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]`:
        * For any i in `output_dimensions`, i >= N.
C21: a) `feature_group_count <= 0`.
C22: a) `batch_group_count <= 0`.
C23: a) `feature_group_count` != 1 and `batch_group_count` != 1.
C24: a) size(`precision_config`) != 2.
C25: a) For result_dim in [0, N): `dim(result, result_dim)` != `dim(lhs, input_batch_dimension) / batch_group_count`, if `result_dim = output_batch_dimension`.
     b) For result_dim in [0, N): `dim(result, result_dim)` != `dim(rhs, kernel_output_feature_dimension)`, if `result_dim = output_feature_dimension`.
     c) For result_dim in [0, N): `dim(result, result_dim)` != `num_windows` otherwise, where:
        * `output_spatial_dimensions[spatial_dim] = result_dim`.
        * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
        * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
        * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) == 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
        * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
        * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) == 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
        * `num_windows = (padded_input_shape[lhs_dim] == 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]) ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
C26: a) rank(result) != N.
C27: a) element_type(`lhs`) != element_type(`rhs`).
```

If we drop the "Covered by ODS" pieces, this will leave us with the following test cases:

```
I4a: `padding` is not a 2-dimensional tensor.
C1a: rank(`lhs`) != rank(`rhs`) != N.
C2a: size(`window_strides`) != N - 2.
C3a: `window_strides[i]` <= 0 for any i in [0, size(`window_strides`)).
C4a: dim(`padding`, 0) != N - 2.
C4b: dim(`padding`, 1) != 2.
C5a: size(`lhs_dilation`) != N - 2.
C6a: `lhs_dilation[i]` <= 0 for any i in [0, size(`lhs_dilation`)).
C7a: size(`rhs_dilation`) != N - 2.
C8a: `rhs_dilation[i]` <= 0 for any i in [0, size(`rhs_dilation`)).
C9a: size(`window_reversal`) != N - 2.
C10a: `dim(lhs, input_batch_dimension) % batch_group_count != 0`.
C11a: `dim(lhs, input_feature_dimension) % feature_group_count != 0`.
C12a: size(`input_spatial_dimensions`) != N - 2.
C13a: Given `input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]`:
      * Any dimensions in `input_dimensions` are not unique.
C13b: Given `input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]`:
      * For any i in `input_dimensions`, i < 0.
C13c: Given `input_dimensions = [input_batch_dimension] + input_spatial_dimensions + [input_feature_dimension]`:
      * For any i in `input_dimensions`, i >= N.
C14a: `dim(rhs, kernel_input_feature_dimension) != dim(lhs, input_feature_dimension) / feature_group_count`.
C15a: `dim(rhs, kernel_output_feature_dimension) % batch_group_count != 0`.
C16a: `dim(rhs, kernel_output_feature_dimension) % feature_group_count != 0`.
C17a: size(`kernel_spatial_dimensions`) != N - 2.
C18a: Given `kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * Any dimensions in `kernel_dimensions` are not unique.
C18b: Given `kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * For any i in `kernel_dimensions`, i < 0.
C18c: Given `kernel_dimensions = kernel_spatial_dimensions + [kernel_input_feature_dimension] + [kernel_output_feature_dimension]`:
      * For any i in `kernel_dimensions`, i >= N.
C19a: size(`output_spatial_dimensions`) != N - 2.
C20a: Given `output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]`:
      * Any dimensions in `output_dimensions` are not unique.
C20b: Given `output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]`:
      * For any i in `output_dimensions`, i < 0.
C20c: Given `output_dimensions = [output_batch_dimension] + output_spatial_dimensions + [output_feature_dimension]`:
      * For any i in `output_dimensions`, i >= N.
C21a: `feature_group_count <= 0`.
C22a: `batch_group_count <= 0`.
C23a: `feature_group_count` != 1 and `batch_group_count` != 1.
C24a: size(`precision_config`) != 2.
C25a: For result_dim in [0, N): `dim(result, result_dim)` != `dim(lhs, input_batch_dimension) / batch_group_count`, if `result_dim = output_batch_dimension`.
C25b: For result_dim in [0, N): `dim(result, result_dim)` != `dim(rhs, kernel_output_feature_dimension)`, if `result_dim = output_feature_dimension`.
C25c: For result_dim in [0, N): `dim(result, result_dim)` != `num_windows` otherwise, where:
      * `output_spatial_dimensions[spatial_dim] = result_dim`.
      * `lhs_dim = input_spatial_dimensions[spatial_dim]`.
      * `rhs_dim = kernel_spatial_dimensions[spatial_dim]`.
      * `dilated_input_shape[lhs_dim] = dim(lhs, lhs_dim) == 0 ? 0 : (dim(lhs, lhs_dim) - 1) * lhs_dilation[spatial_dim] + 1`.
      * `padded_input_shape[lhs_dim] = padding[spatial_dim, 0] + dilated_input_shape[lhs_dim] + padding[spatial_dim, 1]`.
      * `dilated_window_shape[lhs_dim] = dim(rhs, rhs_dim) == 0 ? 0 : (dim(rhs, rhs_dim) - 1) * rhs_dilation[spatial_dim] + 1`.
      * `num_windows = (padded_input_shape[lhs_dim] == 0 || dilated_window_shape[lhs_dim] > padded_input_shape[lhs_dim]) ? 0 : floor((padded_input_shape[lhs_dim] - dilated_window_shape[lhs_dim]) / window_strides[spatial_dim]) + 1`.
C26a: rank(result) != N.
C27a: element_type(`lhs`) != element_type(`rhs`).
```

Notes:
* (new C24) is left untouched as there is still a pending action item regarding the number of precision config values allowed in #879.

closes #2092
closes #970