diff --git a/docs/spec.md b/docs/spec.md index 9640c43d054..bf434b0bc9b 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -4001,12 +4001,12 @@ doesn't hold for many popular reductions. E.g. floating-point addition for `body` and zero for `init_values` don't actually form a monoid because floating-point addition is not associative. -More formally, `results...[j0, ..., jR-1] = convert_with_quantized_type(reduce( +More formally, `results...[j0, ..., jR-1] = convert_or_quantize(reduce( input_slices_converted), type(results...)))` where: * `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted at `dimensions`. -* `input_slices_converted = convert_with_quantized_type(input_slices..., +* `input_slices_converted = convert_or_quantize(input_slices..., type(func_inputs(body)[:len(func_inputs(body))//2])...)`. * `reduce(input_slices_converted) = exec(schedule)` for some binary tree `schedule` where: @@ -4018,7 +4018,7 @@ input_slices_converted), type(results...)))` where: `index_space(input_slices_converted)` in the ascending lexicographic order of `index`. * Interspersed with an implementation-defined amount of - `convert_with_quantized_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)` + `convert_or_quantize(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)` at implementation-defined positions. #### Inputs @@ -4045,7 +4045,7 @@ input_slices_converted), type(results...)))` where: * (C5) `is_unique(dimensions)`. * (C6) `body` has type `tensor, ..., tensor, tensor, ...,` `tensor) -> (tensor, ..., tensor)` where - `same_elementtype_ignoring_bitwidth(element_type(inputs[i]), element_type(E[i]))`. + `is_convertible_or_quantizable(element_type(inputs[i]), element_type(E[i]))`. * (C7) `shape(results...) = shape(inputs...)` except that the dimension sizes of `inputs...` corresponding to `dimensions` are not included. * (C8) `baseline_element_type(inputs...) = baseline_element_type(results...)`. @@ -4275,7 +4275,7 @@ where: * (C12) `shape(padding) = [rank(inputs[0]), 2]`. * (C13) `body` has type `(tensor, ..., tensor, tensor, ...,` `tensor) -> (tensor, ..., tensor)` where - `same_elementtype_ignoring_bitwidth(element_type(inputs[i]), element_type(E[i]))`. + `is_convertible_or_quantizable(element_type(inputs[i]), element_type(E[i]))`. * (C14) `same(shape(results...))`. * (C15) `shape(results[0]) = num_windows` where: * `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`. @@ -4969,7 +4969,7 @@ More formally: * (C9) `select` has type `(tensor, tensor) -> tensor` where `E = element_type(operand)`. * (C10) `scatter` has type `(tensor, tensor) -> tensor` where - `same_elementtype_ignoring_bitwidth(element_type(operand), element_type(E))`. + `is_convertible_or_quantizable(element_type(operand), element_type(E))`. * (C11) `baseline_type(operand) = baseline_type(result)`. @@ -6274,13 +6274,13 @@ If `x` is a value or placeholder, this function is a shortcut for `member_name(type(x))`. If `x` is not a type that has an appropriate member, or a value or a placeholder of such a type, returns `None`. -* `same_elementtype_ignoring_bitwidth(x: Type, y: Type) -> bool` checks for the +* `is_convertible_or_quantizable(x: Type, y: Type) -> bool` checks for the equality of `x` and `y`, ignoring the bitwidth, when they are of type `TensorElementType`. When `x` and `y` are `QuantizedTensorElementType`s, the function checks for the equality of `QuantizationExpressedType` component. ```python -def same_elementtype_ignoring_bitwidth(x: Type, y: Type) -> Value: +def is_convertible_or_quantizable(x: Type, y: Type) -> Value: return is_integer(x) = is_integer(y) or is_float(x) = is_float(y) or is_complex(x) = is_complex(y) or @@ -6304,6 +6304,33 @@ and [slicing](https://docs.python.org/3/reference/expressions.html#slicings) notations from Python are available to index into tensors, quantized tensors and tuples. +* `convert_or_quantize(x: Value, destination_type: Type) -> Value` is defined on +tensors and returns the converted value of `x` based on the `type(x)` and +`destination_type` as follows: + +```python +def convert_or_quantize(x: Value, destination_type: Type) -> Value: + if type(x) == destination_type: + return x + + if is_quantized(destination_type): + if is_quantized(type(x)): + return quantize(x, destination_type) + assert is_float(type(x)) + return quantize(x, destination_type) + + if is_quantized(type(x)): + assert destination_type = expressed_type(type(x)) + return dequantize(type(x)) + + return convert(x, destination_type) +``` + +There is plan to merge `convert`, `uniform_quantize` and `uniform_dequantize` +operations ([#1576](https://github.com/openxla/stablehlo/issues/1576)). After +the merge we do not need the above function and can use the operation name for +`convert` instead. + * `is_nan(x: Value) -> Value` is defined on tensors and returns `true` if all elements of `x` are `NaN` or `false` otherwise. If `x` is not a tensor, returns `None`. @@ -6331,24 +6358,6 @@ function returns `true`. If `x` is not a tensor, returns `None`. tensors and returns `num_results` slices of `x` along the axis `axis`. If `x` is not a tensor or `dim(x, axis) % num_results != 0`, returns `None`. -* `convert_with_quantized_type(x: Value, destination_type: Type) --> Value` is defined on tensors and returns the converted value of `x` based on -the `type(x)` and `destination_type` as follows: - -```python -def convert_with_quantized_type(x: Value, destination_type: Type) -> Value: - if type(x) == destination_type: - return x - if is_quantized(type(x)) and is_quantized(destination_type): - return quantize(x, destination_type) - return convert(x, destination_type) -``` - -There is plan to merge `convert`, `uniform_quantize` and `uniform_dequantize` -operations ([#1576](https://github.com/openxla/stablehlo/issues/1576)). After -the merge we do not need the above function and can use the operation name for -`convert` instead. - #### Shape computations * `axes(x: Value | Placeholder | Type) -> Value` is a shortcut for