Iter2: Address feedback comments
sdasgup3 committed Oct 11, 2023
1 parent 9ff5c6b commit e0051f5
Showing 1 changed file with 35 additions and 26 deletions.
61 changes: 35 additions & 26 deletions docs/spec.md
@@ -4001,12 +4001,12 @@ doesn't hold for many popular reductions. E.g. floating-point addition for
`body` and zero for `init_values` don't actually form a monoid because
floating-point addition is not associative.

More formally, `results...[j0, ..., jR-1] = convert_with_quantized_type(reduce(
More formally, `results...[j0, ..., jR-1] = convert_or_quantize(reduce(
input_slices_converted), type(results...)))` where:

* `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted
at `dimensions`.
* `input_slices_converted = convert_with_quantized_type(input_slices...,
* `input_slices_converted = convert_or_quantize(input_slices...,
type(func_inputs(body)[:len(func_inputs(body))//2])...)`.
* `reduce(input_slices_converted) = exec(schedule)` for some binary tree
`schedule` where:
@@ -4018,7 +4018,7 @@ input_slices_converted), type(results...)))` where:
`index_space(input_slices_converted)` in the ascending lexicographic order
of `index`.
* Interspersed with an implementation-defined amount of
`convert_with_quantized_type(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)`
`convert_or_quantize(init_values..., type(func_inputs(body)[len(func_inputs(body))//2:])...)`
at implementation-defined positions.
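
Because the schedule is implementation-defined, two conforming implementations may produce different results when `body` is not associative. The following is a minimal Python sketch of that effect and is not part of the specification; `reduce_left` and `reduce_tree` are hypothetical helpers that model two possible binary-tree schedules over a one-dimensional slice, with addition standing in for `body` and no `init_values` interspersed.

```python
from operator import add

def reduce_left(values, body):
    """Hypothetical schedule 1: left-leaning tree, ((v0 op v1) op v2) op ..."""
    acc = values[0]
    for v in values[1:]:
        acc = body(acc, v)
    return acc

def reduce_tree(values, body):
    """Hypothetical schedule 2: balanced binary tree over the same in-order leaves."""
    if len(values) == 1:
        return values[0]
    mid = len(values) // 2
    return body(reduce_tree(values[:mid], body), reduce_tree(values[mid:], body))

floats = [0.1, 0.2, 0.3]
ints = [1, 2, 3]

# Floating-point addition is not associative, so the two schedules may disagree.
left_f = reduce_left(floats, add)
tree_f = reduce_tree(floats, add)

# Integer addition is associative, so every schedule gives the same answer.
left_i = reduce_left(ints, add)
tree_i = reduce_tree(ints, add)
```

Both helpers visit the leaves in the same order; only the grouping differs, which is the freedom the definition above leaves to implementations.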

#### Inputs
@@ -4045,7 +4045,7 @@ input_slices_converted), type(results...)))` where:
* (C5) `is_unique(dimensions)`.
* (C6) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
`same_elementtype_ignoring_bitwidth(element_type(inputs[i]), element_type(E[i]))`.
`is_convertible_or_quantizable(element_type(inputs[i]), element_type(E[i]))`.
* (C7) `shape(results...) = shape(inputs...)` except that the dimension
sizes of `inputs...` corresponding to `dimensions` are not included.
* (C8) `baseline_element_type(inputs...) = baseline_element_type(results...)`.
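
As an illustration of (C5) and (C7), the result shape can be computed by dropping the reduced dimensions from the input shape. This is a hedged sketch only: `reduce_result_shape` is a hypothetical helper, and the bounds check on `dimensions` is an assumption, since the constraint that covers it is not shown in this hunk.

```python
def reduce_result_shape(input_shape, dimensions):
    """Hypothetical helper: shape of `results...` per (C7)."""
    # (C5) the reduced dimensions must be unique.
    assert len(set(dimensions)) == len(dimensions)
    # Assumption: every dimension index refers to an existing axis.
    assert all(0 <= d < len(input_shape) for d in dimensions)
    # (C7) keep every axis size that is not being reduced.
    return tuple(size for axis, size in enumerate(input_shape)
                 if axis not in dimensions)

shape_one_axis = reduce_result_shape((2, 3, 4), [1])
shape_two_axes = reduce_result_shape((2, 3, 4), [0, 2])
shape_all_axes = reduce_result_shape((2, 3, 4), [0, 1, 2])
```

Reducing over every axis leaves an empty shape, that is, a zero-dimensional tensor.
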
@@ -4275,7 +4275,7 @@ where:
* (C12) `shape(padding) = [rank(inputs[0]), 2]`.
* (C13) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
`same_elementtype_ignoring_bitwidth(element_type(inputs[i]), element_type(E[i]))`.
`is_convertible_or_quantizable(element_type(inputs[i]), element_type(E[i]))`.
* (C14) `same(shape(results...))`.
* (C15) `shape(results[0]) = num_windows` where:
* `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`.
@@ -4969,7 +4969,7 @@ More formally:
* (C9) `select` has type `(tensor<E>, tensor<E>) -> tensor<i1>` where
`E = element_type(operand)`.
* (C10) `scatter` has type `(tensor<E>, tensor<E>) -> tensor<E>` where
`same_elementtype_ignoring_bitwidth(element_type(operand), element_type(E))`.
`is_convertible_or_quantizable(element_type(operand), element_type(E))`.
* (C11) `baseline_type(operand) = baseline_type(result)`.
<!-- markdownlint-enable line-length -->

@@ -6274,13 +6274,13 @@ If `x` is a value or placeholder, this function is a shortcut for
`member_name(type(x))`. If `x` is not a type that has an appropriate member, or
a value or a placeholder of such a type, returns `None`.

* `same_elementtype_ignoring_bitwidth(x: Type, y: Type) -> bool` checks for the
* `is_convertible_or_quantizable(x: Type, y: Type) -> bool` checks for the
equality of `x` and `y`, ignoring the bitwidth, when they are of type
`TensorElementType`. When `x` and `y` are `QuantizedTensorElementType`s,
the function checks for the equality of their `QuantizationExpressedType`
components.

```python
def same_elementtype_ignoring_bitwidth(x: Type, y: Type) -> Value:
def is_convertible_or_quantizable(x: Type, y: Type) -> Value:
  return is_integer(x) == is_integer(y) or
         is_float(x) == is_float(y) or
         is_complex(x) == is_complex(y) or
@@ -6304,6 +6304,33 @@ and [slicing](https://docs.python.org/3/reference/expressions.html#slicings)
notations from Python are available to index into tensors, quantized tensors
and tuples.

* `convert_or_quantize(x: Value, destination_type: Type) -> Value` is defined on
tensors and returns the converted value of `x` based on the `type(x)` and
`destination_type` as follows:

```python
def convert_or_quantize(x: Value, destination_type: Type) -> Value:
  # Same type: nothing to do.
  if type(x) == destination_type:
    return x

  # Destination is quantized: requantize, or quantize a float value.
  if is_quantized(destination_type):
    if is_quantized(type(x)):
      return quantize(x, destination_type)
    assert is_float(type(x))
    return quantize(x, destination_type)

  # Source is quantized: dequantize to its expressed type.
  if is_quantized(type(x)):
    assert destination_type == expressed_type(type(x))
    return dequantize(x)

  return convert(x, destination_type)
```

There is a plan to merge the `convert`, `uniform_quantize` and
`uniform_dequantize` operations
([#1576](https://github.com/openxla/stablehlo/issues/1576)). After the merge,
the above function will no longer be needed and the operation name `convert`
can be used instead.
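
To make the dispatch order concrete, here is a small runnable sketch that only decides which operation would apply to a given pair of types. It is an illustration under stated assumptions, not the specification's definition: `Plain`, `Quantized` and `pick_operation` are hypothetical names, and types are modeled as simple dataclasses rather than real StableHLO types.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Plain:
    """Hypothetical stand-in for a non-quantized element type, e.g. "f32"."""
    name: str

@dataclass(frozen=True)
class Quantized:
    """Hypothetical stand-in for a quantized element type."""
    storage: str
    expressed: Plain
    scale: float
    zero_point: int

def is_quantized(t) -> bool:
    return isinstance(t, Quantized)

def is_float(t) -> bool:
    return isinstance(t, Plain) and t.name.startswith("f")

def pick_operation(src, dst) -> str:
    """Return the name of the branch that `convert_or_quantize` would take."""
    if src == dst:
        return "identity"
    if is_quantized(dst):
        # Either requantization, or quantization of a float value.
        if not is_quantized(src):
            assert is_float(src), "only float values can be quantized"
        return "quantize"
    if is_quantized(src):
        # Dequantization must target the source's expressed type.
        assert dst == src.expressed, "must dequantize to the expressed type"
        return "dequantize"
    return "convert"

f32 = Plain("f32")
i32 = Plain("i32")
q8 = Quantized("i8", f32, 0.5, 0)
q16 = Quantized("i16", f32, 0.25, 0)
```

For example, `pick_operation(q8, f32)` selects the dequantize branch, while `pick_operation(i32, q8)` fails the float assertion, mirroring the `assert is_float(type(x))` check above.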

* `is_nan(x: Value) -> Value` is defined on tensors and returns `true` if
all elements of `x` are `NaN` or `false` otherwise. If `x` is not a tensor,
returns `None`.
@@ -6331,24 +6358,6 @@ function returns `true`. If `x` is not a tensor, returns `None`.
tensors and returns `num_results` slices of `x` along the axis `axis`.
If `x` is not a tensor or `dim(x, axis) % num_results != 0`, returns `None`.

* `convert_with_quantized_type(x: Value, destination_type: Type)
-> Value` is defined on tensors and returns the converted value of `x` based on
the `type(x)` and `destination_type` as follows:

```python
def convert_with_quantized_type(x: Value, destination_type: Type) -> Value:
  if type(x) == destination_type:
    return x
  if is_quantized(type(x)) and is_quantized(destination_type):
    return quantize(x, destination_type)
  return convert(x, destination_type)
```

There is a plan to merge the `convert`, `uniform_quantize` and
`uniform_dequantize` operations
([#1576](https://github.com/openxla/stablehlo/issues/1576)). After the merge,
the above function will no longer be needed and the operation name `convert`
can be used instead.

#### Shape computations

* `axes(x: Value | Placeholder | Type) -> Value` is a shortcut for
