-
Notifications
You must be signed in to change notification settings - Fork 120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce PerAxis Quantized constraint for StableHLO Quantized OPs #2007
Conversation
Delegating my review to @sdasgup3. I took a scan over the code / tests and from a high level everything LGTM. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again for the work. Here is some of my initial reviews which I want to publish soon. I need to check the test files in my follow up review.
Here is the remaining list of ops which are uncovered and the following support need to be added.
constant
: is current usingHLO_StaticShapeTensor
which per this change has per-tensor output type. It should be both.iota
usesHLO_StaticShapeIntFpOrComplexTensor
which is not updated. iota is expected to supportper-tensor
infeed
: Accepts both per-tensor and per-axis.recv
: should take both per-tensor and per-axis.If
,Case
,While
: The output supports both.GetTupleElement
/Tuple
: supports both.broadcast
: update the type coinstraints treating the in/out types similar tobrodcast_in_dim
custom-call
: both -fft
: fix the result type as per spec.transpose
: both are supported.
Nit:
- dynamic_iota
, create_token
, dynamic_broadcast_in_dim
, cross-replica-sum
, einsum
, unary_einsum
, dynamic_reshape
, set_dimension_size
, trace
, return
, torch_index_select
, real_dynamic_slice
, dynamic_pad
, dynamic_gather
, dynamic_conv
, dynamic_reshape_shape
, cstr_reshapable
: Let's mention in the description that we are excluding the above stablehlo ops as they are not specced.
As follow up PRs:
- collective_broadcast should also support per-tensor type like other distribution ops.
get_dimension_size: fix the spec
Please free to open PRs for the above items.
Thanks for the thorough review, these OPs fall into result Quantized category, which I incorrectly ignored during audit. Updated the OP def. Updated the audit sheet.
Already taken care of?
Done
Yes, will create separate PR as it involves changes to the spec. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm with some minor comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll think more on naming, but that shouldn't block this from going in since tablegen variable names don't impact the generated code / impl.
and make Quantized test compatible. This was missed during resolving merged conflict for #2007
* made `isCompatibleElementTypeForHloTypeInference` stricter to return error for {not Quantize, Quantize}, {per-axis Quantized, per-tensor Quantized} cases * `AddOp` VHLO Test failures : addressed test failures because {not Quantize, Quantize} is not allowed * CorrectedTraits for `CholeskyOp` and `ClampOp` to match it with the spec ~~Note: This PR is based on in review PR #2007 Follow up PR will add/update OP verifiers for OPs which need special handling
StableHLO OPs supporting Quantization can be categorized into following three types
a. Only PerTensor Quantized Tensors
b. Only PerAxis Quantized Tensors (no OP in this category for now)
c. Both PerTensor and PerAxis Quantized Tensors
PerAxis constraint from the PR allow only PerTensorQuantized Tensor inputs to type (a) OPs and don't allow PerAxis Quantized Tensors.
Also,
Added negative test cases to validate this behavior
Added positive test cases for PerTensor , PerAxis Quantized Tensor support
Excluded OPs
BroadCast
andCall
dynamic_iota
,create_token
,dynamic_broadcast_in_dim
,cross-replica-sum
,einsum
,unary_einsum
,dynamic_reshape
,set_dimension_size
,trace
,return
,torch_index_select
,real_dynamic_slice
,dynamic_pad
,dynamic_gather
,dynamic_conv
,dynamic_reshape_shape
,cstr_reshapable