Skip to content

Commit

Permalink
Specification of reduce/reduce_window/select_and_scatter ops
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Oct 10, 2023
1 parent e1e306a commit ad6a64e
Showing 1 changed file with 80 additions and 49 deletions.
129 changes: 80 additions & 49 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -4001,49 +4001,58 @@ doesn't hold for many popular reductions. E.g. floating-point addition for
`body` and zero for `init_values` don't actually form a monoid because
floating-point addition is not associative.

More formally, `results...[j0, ..., jR-1] = reduce(input_slices)` where:
More formally, `results...[j0, ..., jR-1] = reduce_implicit_convert(reduce(
input_slices_converted), type(func_outputs(body)...), type(results...)))` where:

* `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted
at `dimensions`.
* `reduce(input_slices) = exec(schedule)` for some binary tree `schedule`
where:
* `input_slices_converted = reduce_implicit_convert(input_slices...,
type(inputs...), type(func_inputs(body)...)`.
* `reduce(input_slices_converted) = exec(schedule)` for some binary tree
`schedule` where:
* `exec(node) = body(exec(node.left), exec(node.right))`.
* `exec(leaf) = leaf.value`.
* `schedule` is an implementation-defined full binary tree whose in-order
traversal consists of:
* `input_slices...[index]` values, for all `index` in
`index_space(input_slices)` in the ascending lexicographic order of `index`.
* Interspersed with an implementation-defined amount of `init_values`
* `input_slices_converted...[index]` values, for all `index` in
`index_space(input_slices_converted)` in the ascending lexicographic order
of `index`.
* Interspersed with an implementation-defined amount of
`reduce_implicit_convert(init_values..., type(init_values...), type(func_inputs(body)[len(func_inputs(body)//2)]:)...)`
at implementation-defined positions.

#### Inputs

| Label | Name | Type | Constraints |
|-------|---------------|----------------------------------------------|---------------------|
| (I1) | `inputs` | variadic number of tensors | (C1-C4), (C6), (C7) |
| (I2) | `init_values` | variadic number of 0-dimensional tensors | (C2), (C3) |
| (I3) | `dimensions` | 1-dimensional tensor constant of type `si64` | (C4), (C5), (C7) |
| (I4) | `body` | function | (C6) |
| Label | Name | Type | Constraints |
|-------|---------------|--------------------------------------------------------------------------|---------------------|
| (I1) | `inputs` | variadic number of tensors or per-tensor quantized tensors | (C1-C4), (C6), (C7) |
| (I2) | `init_values` | variadic number of 0-dimensional tensors or per-tensor quantized tensors | (C2), (C3) |
| (I3) | `dimensions` | 1-dimensional tensor constant of type `si64` | (C4), (C5), (C7) |
| (I4) | `body` | function | (C6) |

#### Outputs

| Name | Type | Constraints |
|-----------|----------------------------|------------------|
| `results` | variadic number of tensors | (C2), (C3), (C7) |
| Name | Type | Constraints |
|-----------|------------------------------------------------------------|------------------|
| `results` | variadic number of tensors or per-tensor quantized tensors | (C3), (C7), (C8) |

#### Constraints

* (C1) `same(shape(inputs...))`.
* (C2) `element_type(inputs...) = element_type(init_values...) =
element_type(results...)`.
* (C2) `element_type(inputs...) = element_type(init_values...)`.
* (C3) `0 < size(inputs) = size(init_values) = size(results) = N`.
* (C4) `0 <= dimensions < rank(inputs[0])`.
* (C5) `is_unique(dimensions)`.
* (C6) `body` has type `tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
`Ei = element_type(inputs[i])`.
`is_integer(element_type(inputs[i])) = is_integer(element_type(E[i]))` or
`is_float(element_type(inputs[i])) = is_float(element_type(E[i]))` or
`is_complex(element_type(inputs[i])) = is_complex(element_type(E[i]))` or
`(is_quantized(element_type(inputs[i])) = is_quantized(element_type(E[i])) and
expressed_type(element_type(inputs[i])) = expressed_type(element_type(E[i])))`.
* (C7) `shape(results...) = shape(inputs...)` except that the dimension
sizes of `inputs...` corresponding to `dimensions` are not included.
* (C8) `baseline_element_type(inputs...) = baseline_element_type(results...)`.

#### Examples

Expand Down Expand Up @@ -4236,22 +4245,22 @@ where:

#### Inputs

| Label | Name | Type | Constraints |
|-------|---------------------|----------------------------------------------|-------------------------------------------------|
| (I1) | `inputs` | variadic number of tensors | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
| (I2) | `init_values` | variadic number of 0-dimensional tensors | (C1), (C13), (C16) |
| (I3) | `window_dimensions` | 1-dimensional tensor constant of type `si64` | (C4), (C5), (C15) |
| (I4) | `window_strides` | 1-dimensional tensor constant of type `si64` | (C6), (C7), (C15) |
| (I5) | `base_dilations` | 1-dimensional tensor constant of type `si64` | (C8), (C9), (C15) |
| (I6) | `window_dilations` | 1-dimensional tensor constant of type `si64` | (C10), (C11), (C15) |
| (I7) | `padding` | 2-dimensional tensor constant of type `si64` | (C12), (C15) |
| (I8) | `body` | function | (C13) |
| Label | Name | Type | Constraints |
|-------|---------------------|--------------------------------------------------------------------------|-------------------------------------------------|
| (I1) | `inputs` | variadic number of tensors or per-tensor quantized tensors | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
| (I2) | `init_values` | variadic number of 0-dimensional tensors or per-tensor quantized tensors | (C1), (C13), (C16) |
| (I3) | `window_dimensions` | 1-dimensional tensor constant of type `si64` | (C4), (C5), (C15) |
| (I4) | `window_strides` | 1-dimensional tensor constant of type `si64` | (C6), (C7), (C15) |
| (I5) | `base_dilations` | 1-dimensional tensor constant of type `si64` | (C8), (C9), (C15) |
| (I6) | `window_dilations` | 1-dimensional tensor constant of type `si64` | (C10), (C11), (C15) |
| (I7) | `padding` | 2-dimensional tensor constant of type `si64` | (C12), (C15) |
| (I8) | `body` | function | (C13) |

#### Outputs

| Name | Type | Constraints |
|-----------|----------------------------|-----------------|
| `results` | variadic number of tensors | (C1), (C14-C16) |
| Name | Type | Constraints |
|-----------|------------------------------------------------------------|-----------------|
| `results` | variadic number of tensors or per-tensor quantized tensors | (C1), (C14-C16) |

#### Constraints

Expand All @@ -4268,16 +4277,21 @@ where:
* (C10) `size(window_dilations) = rank(inputs[0])`.
* (C11) `0 < window_dilations`.
* (C12) `shape(padding) = [rank(inputs[0]), 2]`.
* (C13) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)`
where `Ei = element_type(inputs[i])`.
* (C13) `body` has type `tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
`is_integer(element_type(inputs[i])) = is_integer(element_type(E[i]))` or
`is_float(element_type(inputs[i])) = is_float(element_type(E[i]))` or
`is_complex(element_type(inputs[i])) = is_complex(element_type(E[i]))` or
`(is_quantized(element_type(inputs[i])) = is_quantized(element_type(E[i])) and
expressed_type(element_type(inputs[i])) = expressed_type(element_type(E[i])))`.
* (C14) `same(shape(results...))`.
* (C15) `shape(results[0]) = num_windows` where:
* `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`.
* `padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]`.
* `dilated_window_shape = (window_dimensions - 1) * window_dilations + 1`.
* `is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape`.
* `num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1`.
* (C16) `element_type(results...) = element_type(init_values...)`.
* (C16) `baseline_element_type(results...) = baseline_element_type(init_values...)`.
<!-- markdownlint-enable line-length -->

#### Examples
Expand Down Expand Up @@ -4929,22 +4943,22 @@ More formally:

#### Inputs

| Label | Name | Type | Constraints |
|-------|---------------------|----------------------------------------------|-------------------------|
| (I1) | `operand` | tensor | (C1-C4), (C6), (C8-C11) |
| (I2) | `source` | tensor | (C1), (C2) |
| (I3) | `init_value` | 0-dimensional tensor | (C3) |
| (I4) | `window_dimensions` | 1-dimensional tensor constant of type `si64` | (C2), (C4), (C5) |
| (I5) | `window_strides` | 1-dimensional tensor constant of type `si64` | (C2), (C6), (C7) |
| (I6) | `padding` | 2-dimensional tensor constant of type `si64` | (C2), (C8) |
| (I7) | `select` | function | (C9) |
| (I8) | `scatter` | function | (C10) |
| Label | Name | Type | Constraints |
|-------|---------------------|-----------------------------------------------------|-------------------------|
| (I1) | `operand` | tensor or per-tensor quantized tensor | (C1-C4), (C6), (C8-C11) |
| (I2) | `source` | tensor or per-tensor quantized tensor | (C1), (C2) |
| (I3) | `init_value` | 0-dimensional tensor or per-tensor quantized tensor | (C3) |
| (I4) | `window_dimensions` | 1-dimensional tensor constant of type `si64` | (C2), (C4), (C5) |
| (I5) | `window_strides` | 1-dimensional tensor constant of type `si64` | (C2), (C6), (C7) |
| (I6) | `padding` | 2-dimensional tensor constant of type `si64` | (C2), (C8) |
| (I7) | `select` | function | (C9) |
| (I8) | `scatter` | function | (C10) |

#### Outputs

| Name | Type | Constraints |
|----------|--------|-------------|
| `result` | tensor | (C11) |
| Name | Type | Constraints |
|----------|---------------------------------------|-------------|
| `result` | tensor or per-tensor quantized tensor | (C11) |

#### Constraints

Expand All @@ -4963,8 +4977,12 @@ More formally:
* (C9) `select` has type `(tensor<E>, tensor<E>) -> tensor<i1>` where
`E = element_type(operand)`.
* (C10) `scatter` has type `(tensor<E>, tensor<E>) -> tensor<E>` where
`E = element_type(operand)`.
* (C11) `type(operand) = type(result)`.
`is_integer(element_type(operand)) = is_integer(element_type(E))` or
`is_float(element_type(operand)) = is_float(element_type(E))` or
`is_complex(element_type(operand)) = is_complex(element_type(E))` or
`(is_quantized(element_type(operand)) = is_quantized(element_type(E)) and
expressed_type(element_type(operand)) = expressed_type(element_type(E)))`.
* (C11) `baseline_type(operand) = baseline_type(result)`.
<!-- markdownlint-enable line-length -->

#### Examples
Expand Down Expand Up @@ -6312,6 +6330,19 @@ function returns `true`. If `x` is not a tensor, returns `None`.
tensors and returns `num_results` slices of `x` along the axis `axis`.
If `x` is not a tensor or `dim(x, axis) % num_results != 0`, returns `None`.

* `reduce_implicit_convert(x: Value, source_type: Type, destination_type: Type)
-> Value` is defined on tensors and returns the converted value of `x` based on
the `source_type` and `destination_type` as follows:

```python
def reduce_implicit_convert(x: Value, source_type: Type, destination_type: Type) -> Value:
if source_type == destination_type:
return x
if is_quantized(source_type) and is_quantized(destination_type):
return quantize(x, destination_type)
return convert(x, destination_type)
```

#### Shape computations

* `axes(x: Value | Placeholder | Type) -> Value` is a shortcut for
Expand Down

0 comments on commit ad6a64e

Please sign in to comment.