Specification for UniformQuantizeOp and UniformDequantizeOp #1496

Merged
merged 7 commits into from Jun 19, 2023
127 changes: 119 additions & 8 deletions docs/spec.md
@@ -335,11 +335,6 @@ in StableHLO programs. In the meanwhile, here is the list of these operations:
`dynamic_gather`, `dynamic_iota`, `dynamic_pad`, `dynamic_reshape`,
`real_dynamic_slice`, `set_dimension_size`
([#8](https://github.com/openxla/stablehlo/issues/8)).
* "Quantization" category of StableHLO operations - they were bootstrapped from
MHLO, but we haven't specced them yet: `uniform_quantize`
([#531](https://github.com/openxla/stablehlo/issues/531)) and
`uniform_dequantize`
([#530](https://github.com/openxla/stablehlo/issues/530)).
* Shape computations, including `arith`, `shape` and `tensor` operations
([#8](https://github.com/openxla/stablehlo/issues/8)).

@@ -5535,6 +5530,87 @@ Produces a `result` tuple from values `val`.
// %result: ([1.0, 2.0], (3))
```

### uniform_dequantize

#### Semantics

Performs element-wise conversion of quantized tensor `operand` to a
floating-point tensor `result` according to the quantization parameters defined
by the `operand` type.

More formally, `result = dequantize(operand)`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|-----------|------------------|-------------|
| (I1) | `operand` | quantized tensor | (C1), (C2) |

#### Outputs

| Name | Type | Constraints |
|----------|-------------------------------|-------------|
| `result` | tensor of floating-point type | (C1), (C2) |

#### Constraints

* (C1) `shape(operand) = shape(result)`.
* (C2) `element_type(result) = expressed_type(operand)`.

#### Examples

```mlir
// %operand: [10, 10]
%result = "stablehlo.uniform_dequantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
// %result: [4.0, 15.0]
```
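
Working through this example by hand (a minimal sketch, not normative text, assuming the per-axis parameters `{0.1:-30,0.5:-20}` are read as plain lists of scales and zero points along dimension 0):

```python
# result[i] = (operand[i] - zero_point[i]) * scale[i]
scales, zero_points = [0.1, 0.5], [-30, -20]
operand = [10, 10]
result = [(x - z) * s for x, s, z in zip(operand, scales, zero_points)]
assert result == [4.0, 15.0]  # (10 - -30) * 0.1 and (10 - -20) * 0.5
```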

### uniform_quantize

#### Semantics

Performs element-wise conversion of floating-point tensor or quantized tensor
`operand` to a quantized tensor `result` according to the quantization
parameters defined by the `result` type.

More formally,

* If `is_float(operand)`:
* `result = quantize(operand, type(result))`.
* If `is_quantized(operand)`:
* `float_result = dequantize(operand)`.
* `result = quantize(float_result, type(result))`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|-----------|--------------------------------------------|-------------|
| (I1) | `operand` | tensor of floating-point or quantized type | (C1), (C2) |

#### Outputs

| Name | Type | Constraints |
|----------|------------------|-------------|
| `result` | quantized tensor | (C1), (C2) |

#### Constraints

* (C1) `shape(operand) = shape(result)`.
* (C2) `expressed_type(result) = is_float(operand) ? element_type(operand) :
expressed_type(operand)`.

#### Examples

```mlir
// %operand: [4.0, 15.0]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
// %result: [10, 10]

// %operand: [10, 10]
%result = "stablehlo.uniform_quantize"(%operand) : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-20,0.2:-30}>>
// %result: [20, 45]
```
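
Both examples can be verified with the same element-wise arithmetic (a sketch, not normative text, assuming i8 storage so the clamp range is [-128, 127]; Python's built-in `round` rounds half to even, matching `round_nearest_even`):

```python
def quantize_elem(f, scale, zp, lo=-128, hi=127):
  # round(f / scale) + zp, clamped to the storage range.
  return max(lo, min(hi, round(f / scale) + zp))

# First example: quantize [4.0, 15.0] with scales [0.1, 0.5], zero points [-30, -20].
assert [quantize_elem(4.0, 0.1, -30), quantize_elem(15.0, 0.5, -20)] == [10, 10]

# Second example: requantize [10, 10] by first dequantizing with the operand's
# parameters (yielding [4.0, 15.0]), then quantizing with the result's
# parameters, scales [0.1, 0.2] and zero points [-20, -30].
assert [quantize_elem(4.0, 0.1, -20), quantize_elem(15.0, 0.2, -30)] == [20, 45]
```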

### while

#### Semantics
@@ -6136,6 +6212,41 @@ def baseline_type(x: Value | Placeholder | Type) -> Type:
  return baseline_element_type(type(x))
```

* `dequantize` is defined on quantized tensor types and turns them into
floating-point tensor types. This happens by converting quantized elements,
which represent integer values of the storage type, into corresponding
floating-point values of the expressed type, using the zero point and scale
associated with the quantized element type. At the moment, this function only
works for per-tensor quantization. Per-axis quantization is work in progress
([#1574](https://github.com/openxla/stablehlo/issues/1574)).

```python
def dequantize(x: Value) -> Value:
  assert is_quantized(x)
  # Reinterpret the quantized elements as integers of the storage type.
  x_storage = bitcast_convert(x, storage_type(x))
  # Subtract the zero point, then scale into the expressed type.
  x_storage_sub = x_storage - zero_point(x)
  x_expressed_sub = convert(x_storage_sub, expressed_type(x))
  return x_expressed_sub * scale(x)
```

* `quantize` is defined on floating-point tensor types and turns them into
quantized tensor types. This happens by converting floating-point values of
the expressed type into corresponding integer values of the storage type,
using the zero point and scale associated with the quantized element type.
At the moment, this function only works for per-tensor quantization. Per-axis
quantization is work in progress
([#1574](https://github.com/openxla/stablehlo/issues/1574)).

```python
def quantize(x: Value, type: Type) -> Value:
  assert is_float(x) and is_quantized(type)
  # Scale into storage units and round half to even.
  x_expressed_rounded = round_nearest_even(x / scale(type))
  x_storage_rounded = convert(x_expressed_rounded, storage_type(type))
  # Add the zero point and clamp to the storage range.
  x_storage_add = x_storage_rounded + zero_point(type)
  x_storage = clamp(storage_min(type), x_storage_add, storage_max(type))
  return bitcast_convert(x_storage, type)
```
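
Composing the two definitions illustrates the round trip on a per-tensor example (a runnable sketch, not part of the spec; the i8 storage with scale 0.5 and zero point 16 is an arbitrary choice):

```python
def quantize_pt(xs, scale, zp, lo=-128, hi=127):
  # Per-tensor mirror of `quantize`: scale, round half to even, shift, clamp.
  return [max(lo, min(hi, round(x / scale) + zp)) for x in xs]

def dequantize_pt(qs, scale, zp):
  # Per-tensor mirror of `dequantize`: subtract the zero point, then scale.
  return [(q - zp) * scale for q in qs]

q = quantize_pt([4.0, 15.0], scale=0.5, zp=16)    # [24, 46]
assert dequantize_pt(q, scale=0.5, zp=16) == [4.0, 15.0]
```

The round trip is exact here because the inputs happen to be representable; in general quantization is lossy.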

* `dequantize_op_quantize` is used to specify element-wise computations on
quantized tensors. It dequantizes, i.e. turns quantized elements into their
expressed types, then performs an operation, and then quantizes, i.e. turns
the results back into their storage types.

@@ -6147,10 +6258,10 @@ works for per-tensor quantization. Per-axis quantization is work in progress

```python
def dequantize_op_quantize(op, *inputs_and_output_type):
  inputs = inputs_and_output_type[:-1]
  output_type = inputs_and_output_type[-1]
  float_inputs = [(x - zero_point(x)) * scale(x) for x in inputs]
  float_inputs = map(dequantize, inputs)
  float_result = op(*float_inputs)
  rounded_result = round_nearest_even(float_result / scale(output_type))
  return clamp(storage_min(output_type), rounded_result, storage_max(output_type))
  return quantize(float_result, output_type)
```
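
As an illustration of how this helper is meant to be used, the quantized semantics of an element-wise operation can be phrased entirely in terms of it (a hypothetical sketch; `quantized_add` is an invented name, not part of the spec):

```python
# Hypothetical: quantized addition defined via its floating-point semantics.
def quantized_add(lhs, rhs, result_type):
  return dequantize_op_quantize(lambda l, r: l + r, lhs, rhs, result_type)
```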

#### Grid computations
4 changes: 2 additions & 2 deletions docs/status.md
@@ -153,7 +153,7 @@ one of the following tracking labels.
| triangular_solve | yes | revisit | yes | no | revisit |
| tuple | yes | yes | yes | yes | no |
| unary_einsum | no | revisit | no | yes | revisit |
| uniform_dequantize | no | yes\* | yes\* | yes | no |
| uniform_quantize | no | yes\* | infeasible | yes | no |
| uniform_dequantize | yes | yes | yes | yes | no |
| uniform_quantize | yes | revisit | infeasible | yes | no |
| while | yes | revisit | yes | revisit | yes |
| xor | yes | yes | yes | yes | yes |
31 changes: 16 additions & 15 deletions stablehlo/dialect/StablehloOps.td
@@ -3055,38 +3055,39 @@ def StableHLO_RngBitGeneratorOp : StableHLO_Op<"rng_bit_generator", [Pure]> {

// TODO(b/230662142): Implement unknown scales/zero_point cases.
def StableHLO_UniformQuantizeOp : StableHLO_UnaryElementwiseOp<"uniform_quantize",
[Pure], TensorOf<[F32, BF16, HLO_QuantizedInt]>,
HLO_QuantizedIntTensor> {
[Pure], TensorOf<[HLO_Float, HLO_QuantizedInt]> /*uniform_quantize_i1*/,
HLO_QuantizedIntTensor> { /*uniform_quantize_c1*/
let summary = "UniformQuantize operation";
let description = [{
This operation is a work in progress, so it is not yet included in
the StableHLO specification: https://github.com/openxla/stablehlo/issues/588.
Performs element-wise conversion of floating-point tensor or quantized
tensor `operand` to a quantized tensor `result` according to the
quantization parameters defined by the `result` type.

Informally, this operation converts floating point tensors or uniform
quantized tensors to uniform quantized tensors according to the quantization
parameters defined by the result type.
See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_quantize

Example:
```mlir
%result = stablehlo.uniform_quantize %operand : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
%result = stablehlo.uniform_quantize %operand : (tensor<2xf32>) -> tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>
```
}];
}

def StableHLO_UniformDequantizeOp : StableHLO_UnaryElementwiseOp<"uniform_dequantize",
[InferTensorType, Pure], HLO_QuantizedIntTensor, TensorOf<[F32, BF16]>> {
[InferTensorType, Pure], HLO_QuantizedIntTensor /*uniform_dequantize_i1*/,
HLO_FpTensor> { /*uniform_dequantize_c1, uniform_dequantize_c2*/
let summary = "UniformDequantize operation";
let description = [{
This operation is a work in progress, so it is not yet included in
the StableHLO specification: https://github.com/openxla/stablehlo/issues/588.
Performs element-wise conversion of quantized tensor `operand` to a
floating-point tensor `result` according to the quantization parameters
defined by the `operand` type.

Informally, this operation converts uniform quantized tensors to floating
point tensors according to the quantization parameters defined by the
operand type.
See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#uniform_dequantize

Example:
```mlir
%result = stablehlo.uniform_dequantize %operand : (tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32>
%result = stablehlo.uniform_dequantize %operand : (tensor<2x!quant.uniform<i8:f32:0, {0.1:-30,0.5:-20}>>) -> tensor<2xf32>
```
}];
}
2 changes: 2 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
@@ -3009,6 +3009,7 @@ LogicalResult inferUniformDequantizeOp(
// Trait HLO_QuantizedIntTensor in ODS guarantees QuantizedType;
auto quantType = operandType.getElementType().cast<quant::QuantizedType>();
auto shape = operandType.cast<ShapedType>().getShape();
// uniform_dequantize_c1, uniform_dequantize_c2
inferredReturnShapes.emplace_back(shape, quantType.getExpressedType());
return success();
}
@@ -3017,6 +3018,7 @@ LogicalResult inferUniformQuantizeOp(
std::optional<Location> location, Value operand,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
auto operandType = operand.getType().cast<ShapedType>();
// uniform_quantize_c1
inferredReturnShapes.emplace_back(
operandType.hasRank() ? operandType.getShape() : ArrayRef<int64_t>{});
return success();
4 changes: 2 additions & 2 deletions stablehlo/tests/infer_stablehlo.mlir
@@ -234,8 +234,8 @@ func.func @clamp(%arg0: tensor<1xi32>) -> tensor<1xindex> {

// -----

// CHECK: func @uniform_dequantize
func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xindex> {
// CHECK: func @uniform_dequantize_c2
func.func @uniform_dequantize_c2(%arg: tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xindex> {
%0 = stablehlo.uniform_dequantize %arg : (tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32>
// CHECK: types0 = tensor<16x16xf32>
%1 = "hlo_test_infer.get_return_types"(%0) : (tensor<16x16xf32>) -> tensor<16x16xindex>
13 changes: 2 additions & 11 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -5090,7 +5090,7 @@ func.func @quantized_dot_general(%arg0: tensor<2x16x32x!quant.uniform<i8:f32, 2.

// -----

// CHECK: func @uniform_quantize
// CHECK-LABEL: func @uniform_quantize
func.func @uniform_quantize(%arg: tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>> {
%0 = stablehlo.uniform_quantize %arg : (tensor<16x16xf32>) -> tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
func.return %0 : tensor<16x16x!quant.uniform<ui8:f32, 34.0:16>>
@@ -5106,30 +5106,21 @@ func.func @uniform_requantize(%arg: tensor<16x16x!quant.uniform<i8:f32, 5.0:20>>

// -----

// CHECK: func @uniform_dequantize
// CHECK-LABEL: func @uniform_dequantize
func.func @uniform_dequantize(%arg: tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32> {
%0 = stablehlo.uniform_dequantize %arg : (tensor<16x16x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<16x16xf32>
func.return %0 : tensor<16x16xf32>
}

// -----

// CHECK: func @uniform_dequantize_unranked
func.func @uniform_dequantize_unranked(%arg: tensor<*x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<*xf32> {
%0 = stablehlo.uniform_dequantize %arg : (tensor<*x!quant.uniform<i8:f32, 34.0:16>>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @uniform_dequantize_not_quantize(%arg: tensor<16x16xf32>) -> tensor<16x16xf32> {
// expected-error@+1 {{operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer values, but got 'tensor<16x16xf32>'}}
%0 = stablehlo.uniform_dequantize %arg : (tensor<16x16xf32>) -> tensor<16x16xf32>
func.return %0 : tensor<16x16xf32>
}

// -----

// CHECK-LABEL: func @quantized_constants
func.func @quantized_constants() -> (tensor<2x!quant.uniform<i8:f32, 2.0:15>>, tensor<2x!quant.uniform<ui8:f32, 34.0:16>>, tensor<2x!quant.uniform<i8:f32, 2.0:15>>) {
%0 = stablehlo.constant() {value = dense<[1, 2]> : tensor<2xi8>} : () -> tensor<2x!quant.uniform<i8:f32, 2.000000e+00:15>>