From 9ff5c6be04ebd81a793c26b3d1f1092d0e117d8e Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Tue, 10 Oct 2023 19:01:53 +0000 Subject: [PATCH] Iter1: Address feedback comments --- docs/spec.md | 58 +++++++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/docs/spec.md b/docs/spec.md index 6d9d4e555b1..9640c43d054 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -4001,13 +4001,13 @@ doesn't hold for many popular reductions. E.g. floating-point addition for `body` and zero for `init_values` don't actually form a monoid because floating-point addition is not associative. -More formally, `results...[j0, ..., jR-1] = reduce_implicit_convert(reduce( -input_slices_converted), type(func_outputs(body)...), type(results...)))` where: +More formally, `results...[j0, ..., jR-1] = convert_with_quantized_type(reduce( +input_slices_converted), type(results...)))` where: * `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted at `dimensions`. -* `input_slices_converted = reduce_implicit_convert(input_slices..., - type(inputs...), type(func_inputs(body)...)`. +* `input_slices_converted = convert_with_quantized_type(input_slices..., + type(func_inputs(body)[:len(func_inputs(body))//2])...)`. * `reduce(input_slices_converted) = exec(schedule)` for some binary tree `schedule` where: * `exec(node) = body(exec(node.left), exec(node.right))`. @@ -4018,7 +4018,7 @@ input_slices_converted), type(func_outputs(body)...), type(results...)))` where: `index_space(input_slices_converted)` in the ascending lexicographic order of `index`. * Interspersed with an implementation-defined amount of - `reduce_implicit_convert(init_values..., type(init_values...), type(func_inputs(body)[len(func_inputs(body)//2)]:)...)` + `convert_with_quantized_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)` at implementation-defined positions. #### Inputs @@ -4045,11 +4045,7 @@ input_slices_converted), type(func_outputs(body)...), type(results...)))` where: * (C5) `is_unique(dimensions)`. * (C6) `body` has type `tensor, ..., tensor, tensor, ...,` `tensor) -> (tensor, ..., tensor)` where - `is_integer(element_type(inputs[i])) = is_integer(element_type(E[i]))` or - `is_float(element_type(inputs[i])) = is_float(element_type(E[i]))` or - `is_complex(element_type(inputs[i])) = is_complex(element_type(E[i]))` or - `(is_quantized(element_type(inputs[i])) = is_quantized(element_type(E[i])) and - expressed_type(element_type(inputs[i])) = expressed_type(element_type(E[i])))`. + `same_elementtype_ignoring_bitwidth(element_type(inputs[i]), element_type(E[i]))`. * (C7) `shape(results...) = shape(inputs...)` except that the dimension sizes of `inputs...` corresponding to `dimensions` are not included. * (C8) `baseline_element_type(inputs...) = baseline_element_type(results...)`. @@ -4277,13 +4273,9 @@ where: * (C10) `size(window_dilations) = rank(inputs[0])`. * (C11) `0 < window_dilations`. * (C12) `shape(padding) = [rank(inputs[0]), 2]`. -* (C13) `body` has type `tensor, ..., tensor, tensor, ...,` +* (C13) `body` has type `(tensor, ..., tensor, tensor, ...,` `tensor) -> (tensor, ..., tensor)` where - `is_integer(element_type(inputs[i])) = is_integer(element_type(E[i]))` or - `is_float(element_type(inputs[i])) = is_float(element_type(E[i]))` or - `is_complex(element_type(inputs[i])) = is_complex(element_type(E[i]))` or - `(is_quantized(element_type(inputs[i])) = is_quantized(element_type(E[i])) and - expressed_type(element_type(inputs[i])) = expressed_type(element_type(E[i])))`. + `same_elementtype_ignoring_bitwidth(element_type(inputs[i]), element_type(E[i]))`. * (C14) `same(shape(results...))`. * (C15) `shape(results[0]) = num_windows` where: * `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`. @@ -4977,11 +4969,7 @@ More formally: * (C9) `select` has type `(tensor, tensor) -> tensor` where `E = element_type(operand)`. * (C10) `scatter` has type `(tensor, tensor) -> tensor` where - `is_integer(element_type(operand)) = is_integer(element_type(E))` or - `is_float(element_type(operand)) = is_float(element_type(E))` or - `is_complex(element_type(operand)) = is_complex(element_type(E))` or - `(is_quantized(element_type(operand)) = is_quantized(element_type(E)) and - expressed_type(element_type(operand)) = expressed_type(element_type(E)))`. + `same_elementtype_ignoring_bitwidth(element_type(operand), element_type(E))`. * (C11) `baseline_type(operand) = baseline_type(result)`. @@ -6286,6 +6274,19 @@ If `x` is a value or placeholder, this function is a shortcut for `member_name(type(x))`. If `x` is not a type that has an appropriate member, or a value or a placeholder of such a type, returns `None`. +* `same_elementtype_ignoring_bitwidth(x: Type, y: Type) -> bool` checks for the +equality of `x` and `y`, ignoring the bitwidth, when they are of type +`TensorElementType`. When `x` and `y` are `QuantizedTensorElementType`s, +the function checks for the equality of `QuantizationExpressedType` component. + +```python +def same_elementtype_ignoring_bitwidth(x: Type, y: Type) -> Value: + return is_integer(x) = is_integer(y) or + is_float(x) = is_float(y) or + is_complex(x) = is_complex(y) or + (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y)) +``` + #### Construction of values * `operation_name(*xs: Value | Type) -> Value`. Available for all operations. @@ -6330,19 +6331,24 @@ function returns `true`. If `x` is not a tensor, returns `None`. tensors and returns `num_results` slices of `x` along the axis `axis`. If `x` is not a tensor or `dim(x, axis) % num_results != 0`, returns `None`. -* `reduce_implicit_convert(x: Value, source_type: Type, destination_type: Type) +* `convert_with_quantized_type(x: Value, destination_type: Type) -> Value` is defined on tensors and returns the converted value of `x` based on -the `source_type` and `destination_type` as follows: +the `type(x)` and `destination_type` as follows: ```python -def reduce_implicit_convert(x: Value, source_type: Type, destination_type: Type) -> Value: - if source_type == destination_type: +def convert_with_quantized_type(x: Value, destination_type: Type) -> Value: + if type(x) == destination_type: return x - if is_quantized(source_type) and is_quantized(destination_type): + if is_quantized(type(x)) and is_quantized(destination_type): return quantize(x, destination_type) return convert(x, destination_type) ``` +There is plan to merge `convert`, `uniform_quantize` and `uniform_dequantize` +operations ([#1576](https://github.com/openxla/stablehlo/issues/1576)). After +the merge we do not need the above function and can use the operation name for +`convert` instead. + #### Shape computations * `axes(x: Value | Placeholder | Type) -> Value` is a shortcut for