Skip to content

Commit

Permalink
GetDimensionSizeOp : make arguments quantized type (#2059)
Browse files Browse the repository at this point in the history
~~DynamicBroadcastInDimOp: is yet to be a part of spec.~~ (see comments)
GetDimensionSizeOp: will be changed to quantized types in spec. 
Back porting from the integration as file check tests were failing for
both.
  • Loading branch information
abhigunj authored Mar 5, 2024
1 parent b5a199f commit ab9ea3e
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
11 changes: 6 additions & 5 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -3040,14 +3040,15 @@ behavior is undefined. More formally, for all `i1 < i2` from `indices(result)`,
#### Semantics

Produces the size of the given `dimension` of the `operand`. More formally,
`result = dim(operand, dimension)`.
`result = dim(operand, dimension)`. The Semantics concerns only with the shape
component of the type. The element-type could be anything.

#### Inputs

| Label | Name | Type | Constraints |
|-------|-------------|-------------------------|-------------|
| (I1) | `operand` | tensor | (C1) |
| (I2) | `dimension` | constant of type `si64` | (C1) |
| Label | Name | Type | Constraints |
|-------|-------------|----------------------------|-------------|
| (I1) | `operand` | tensor or quantized tensor | (C1) |
| (I2) | `dimension` | constant of type `si64` | (C1) |

#### Outputs

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2413,7 +2413,7 @@ def StableHLO_GetDimensionSizeOp: StableHLO_Op<"get_dimension_size",
```
}];
let arguments = (ins
HLO_Tensor:$operand, /*get_dimension_size_i1*/
HLO_TensorOrPerAxisQuantizedTensor:$operand, /*get_dimension_size_i1*/
I64Attr:$dimension /*get_dimension_size_i2*/
);
// TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the
Expand Down
5 changes: 5 additions & 0 deletions stablehlo/tests/ops_stablehlo_quantized.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
func.func @ops_per_axis_quantization(
%arg0: tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>,
%arg1: tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30}>>,
%shape: tensor<3xi64>,
%token0: !stablehlo.token) {
%bitcast_convert = "stablehlo.bitcast_convert"(%arg0) : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>
%broadcast_in_dim_1 = "stablehlo.broadcast_in_dim" (%arg0) {broadcast_dimensions = array<i64: 0, 1, 3>} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x3x2x!quant.uniform<i8<-128:127>:f32:3, {0.1:-30, 0.5:-20}>>
%broadcast_in_dim_2 = "stablehlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = array<i64: 0, 1, 2>} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30}>>) -> tensor<2x2x2x!quant.uniform<i8<-128:127>:f32:0, {0.1:-30, 0.1:-30}>>
%custom_call = "stablehlo.custom_call" (%arg0) {call_target_name = "foo"} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>
%get_dimension_size = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<i32>
%outfeed = "stablehlo.outfeed"(%arg0, %token0) {outfeed_config = ""} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>, !stablehlo.token) -> !stablehlo.token
%reshape = "stablehlo.reshape" (%arg0) : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>) -> tensor<2x2x!quant.uniform<i8<-128:127>:f32:1, {0.1:-30, 0.5:-20}>>
%send = "stablehlo.send"(%arg0, %token0) {channel_handle = #stablehlo.channel_handle<handle = 5, type = 2>, is_host_transfer = true} : (tensor<1x2x2x!quant.uniform<i8<-128:127>:f32:2, {0.1:-30, 0.5:-20}>>, !stablehlo.token) -> !stablehlo.token
Expand Down Expand Up @@ -44,6 +46,7 @@ func.func @ops_per_tensor_quantization(
%arg1: tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>,
%arg2: tensor<!quant.uniform<i8:f32, 1.0:17>>,
%arg3: tensor<2x4x!quant.uniform<i8:f32, 1.0:17>>,
%shape: tensor<3xi64>,
%token0: !stablehlo.token) {
%abs = "stablehlo.abs"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%add = "stablehlo.add"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
Expand All @@ -61,9 +64,11 @@ func.func @ops_per_tensor_quantization(
%cosine = "stablehlo.cosine"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%custom_call = "stablehlo.custom_call" (%arg0) {call_target_name = "foo"} : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%divide = "stablehlo.divide"(%arg0, %arg1) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%dynamic_broadcast_in_dim = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 0, 1, 2>} : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>, tensor<3xi64>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%exponential = "stablehlo.exponential"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%exponential_minus_one = "stablehlo.exponential_minus_one"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%floor = "stablehlo.floor"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%get_dimension_size = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<i32>
%is_finite = "stablehlo.is_finite"(%arg0) {} : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2xi1>
%log = "stablehlo.log"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
%log_plus_one = "stablehlo.log_plus_one"(%arg0) : (tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x2x2x!quant.uniform<i8:f32, 1.0:17>>
Expand Down

0 comments on commit ab9ea3e

Please sign in to comment.