Skip to content

Commit

Permalink
Iter1: Address feedback comments
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Oct 10, 2023
1 parent ad6a64e commit a326fb5
Showing 1 changed file with 31 additions and 25 deletions.
56 changes: 31 additions & 25 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -4001,13 +4001,13 @@ doesn't hold for many popular reductions. E.g. floating-point addition for
`body` and zero for `init_values` don't actually form a monoid because
floating-point addition is not associative.

More formally, `results...[j0, ..., jR-1] = reduce_implicit_convert(reduce(
input_slices_converted), type(func_outputs(body)...), type(results...)))` where:
More formally, `results...[j0, ..., jR-1] = convert_with_quantized_type(reduce(
input_slices_converted), type(results...)))` where:

* `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted
at `dimensions`.
* `input_slices_converted = reduce_implicit_convert(input_slices...,
type(inputs...), type(func_inputs(body)...)`.
* `input_slices_converted = convert_with_quantized_type(input_slices...,
type(func_inputs(body)[:len(func_inputs(body))//2])...)`.
* `reduce(input_slices_converted) = exec(schedule)` for some binary tree
`schedule` where:
* `exec(node) = body(exec(node.left), exec(node.right))`.
Expand All @@ -4018,7 +4018,7 @@ input_slices_converted), type(func_outputs(body)...), type(results...)))` where:
`index_space(input_slices_converted)` in the ascending lexicographic order
of `index`.
* Interspersed with an implementation-defined amount of
`reduce_implicit_convert(init_values..., type(init_values...), type(func_inputs(body)[len(func_inputs(body)//2)]:)...)`
`convert_with_quantized_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)`
at implementation-defined positions.

#### Inputs
Expand All @@ -4045,11 +4045,7 @@ input_slices_converted), type(func_outputs(body)...), type(results...)))` where:
* (C5) `is_unique(dimensions)`.
* (C6) `body` has type `tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
`is_integer(element_type(inputs[i])) = is_integer(element_type(E[i]))` or
`is_float(element_type(inputs[i])) = is_float(element_type(E[i]))` or
`is_complex(element_type(inputs[i])) = is_complex(element_type(E[i]))` or
`(is_quantized(element_type(inputs[i])) = is_quantized(element_type(E[i])) and
expressed_type(element_type(inputs[i])) = expressed_type(element_type(E[i])))`.
`same_elementtype_ignoring_bitwidth(element_type(inputs[i]), element_type(E[i]))`.
* (C7) `shape(results...) = shape(inputs...)` except that the dimension
sizes of `inputs...` corresponding to `dimensions` are not included.
* (C8) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
Expand Down Expand Up @@ -4279,11 +4275,7 @@ where:
* (C12) `shape(padding) = [rank(inputs[0]), 2]`.
* (C13) `body` has type `tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
`is_integer(element_type(inputs[i])) = is_integer(element_type(E[i]))` or
`is_float(element_type(inputs[i])) = is_float(element_type(E[i]))` or
`is_complex(element_type(inputs[i])) = is_complex(element_type(E[i]))` or
`(is_quantized(element_type(inputs[i])) = is_quantized(element_type(E[i])) and
expressed_type(element_type(inputs[i])) = expressed_type(element_type(E[i])))`.
`same_elementtype_ignoring_bitwidth(element_type(inputs[i]), element_type(E[i]))`.
* (C14) `same(shape(results...))`.
* (C15) `shape(results[0]) = num_windows` where:
* `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`.
Expand Down Expand Up @@ -4977,11 +4969,7 @@ More formally:
* (C9) `select` has type `(tensor<E>, tensor<E>) -> tensor<i1>` where
`E = element_type(operand)`.
* (C10) `scatter` has type `(tensor<E>, tensor<E>) -> tensor<E>` where
`is_integer(element_type(operand)) = is_integer(element_type(E))` or
`is_float(element_type(operand)) = is_float(element_type(E))` or
`is_complex(element_type(operand)) = is_complex(element_type(E))` or
`(is_quantized(element_type(operand)) = is_quantized(element_type(E)) and
expressed_type(element_type(operand)) = expressed_type(element_type(E)))`.
`same_elementtype_ignoring_bitwidth(element_type(operand), element_type(E))`.
* (C11) `baseline_type(operand) = baseline_type(result)`.
<!-- markdownlint-enable line-length -->

Expand Down Expand Up @@ -6286,6 +6274,19 @@ If `x` is a value or placeholder, this function is a shortcut for
`member_name(type(x))`. If `x` is not a type that has an appropriate member, or
a value or a placeholder of such a type, returns `None`.

* `same_elementtype_ignoring_bitwidth(x: Type, y: Type) -> bool` checks for the
equality of `x` and `y`, ignoring the bitwidth, when they are of type
`TensorElementType`. When `x` and `y` are `QuantizedTensorElementType`s,
the function checks for the equality of `QuantizationExpressedType` component.

```python
def same_elementtype_ignoring_bitwidth(x: Type, y: Type) -> Value:
return is_integer(x) = is_integer(y) or
is_float(x) = is_float(y) or
is_complex(x) = is_complex(y) or
(is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
```

#### Construction of values

* `operation_name(*xs: Value | Type) -> Value`. Available for all operations.
Expand Down Expand Up @@ -6330,19 +6331,24 @@ function returns `true`. If `x` is not a tensor, returns `None`.
tensors and returns `num_results` slices of `x` along the axis `axis`.
If `x` is not a tensor or `dim(x, axis) % num_results != 0`, returns `None`.

* `reduce_implicit_convert(x: Value, source_type: Type, destination_type: Type)
* `convert_with_quantized_type(x: Value, destination_type: Type)
-> Value` is defined on tensors and returns the converted value of `x` based on
the `source_type` and `destination_type` as follows:
the `type(x)` and `destination_type` as follows:

```python
def reduce_implicit_convert(x: Value, source_type: Type, destination_type: Type) -> Value:
if source_type == destination_type:
def convert_with_quantized_type(x: Value, destination_type: Type) -> Value:
if type(x) == destination_type:
return x
if is_quantized(source_type) and is_quantized(destination_type):
if is_quantized(type(x)) and is_quantized(destination_type):
return quantize(x, destination_type)
return convert(x, destination_type)
```

There is plan to merge `convert`, `uniform_quantize` and `uniform_dequantize`
operations ([#1576](https://github.com/openxla/stablehlo/issues/1576)). After
the merge we do not need the above function and can use the operation name for
`convert` instead.

#### Shape computations

* `axes(x: Value | Placeholder | Type) -> Value` is a shortcut for
Expand Down

0 comments on commit a326fb5

Please sign in to comment.