Support only the promotion use-cases
sdasgup3 committed Dec 6, 2023
1 parent fc84bc5 commit 400557f
Showing 1 changed file with 42 additions and 37 deletions.
79 changes: 42 additions & 37 deletions docs/spec.md
@@ -778,8 +778,8 @@ defined as follows:

Afterwards, within each `process_group`:

-* `result@process[result_index] = to_destination_type(exec(schedule),
-  type(result))` for some binary tree `schedule` where:
+* `result@process[result_index] = exec(schedule)` for some binary tree
+  `schedule` where:
* `exec(node)` = `computation(exec(node.left), exec(node.right))`.
* `exec(leaf)` = `leaf.value`.
* `schedule` is an implementation-defined binary tree whose in-order
@@ -800,7 +800,7 @@ Afterwards, within each `process_group`:

| Name | Type | Constraints |
|----------|---------------------------------------|-------------|
-| `result` | tensor or per-tensor quantized tensor | (C6) |
+| `result` | tensor or per-tensor quantized tensor | (C6-C7) |

#### Constraints

@@ -812,8 +812,9 @@ Afterwards, within each `process_group`:
* (C3) `0 <= replica_groups < size(replica_groups)`.
* (C4) If `use_global_device_ids = true`, then `channel_id > 0`.
* (C5) `computation` has type `(tensor<E>, tensor<E>) -> (tensor<E>)` where
-  `is_promotable(E, element_type(operand))`.
-* (C6) `is_promotable(type(result), type(operand))`.
+  `is_promotable(element_type(operand), E)`.
+* (C6) `shape(result) = shape(operand)`.
+* (C7) `element_type(result) = E`.

#### Examples
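
To make the revised semantics concrete, here is a minimal Python sketch of the
schedule-based reduction under the new C5-C7: operand values are promoted into
the computation's element type `E`, the reduction runs entirely in `E`, and
the result keeps element type `E` with no conversion back. All names are
illustrative (NumPy dtypes stand in for spec element types); this is not spec
pseudocode.

```python
import numpy as np

# Sketch of `result@process[result_index] = exec(schedule)`: interior nodes
# apply `computation`, leaves are operand values already promoted into E.
def exec_schedule(node, computation):
    if isinstance(node, tuple):  # interior node: (left, right)
        left, right = node
        return computation(exec_schedule(left, computation),
                           exec_schedule(right, computation))
    return node                  # leaf: leaf.value

# Operand element type f16 promotes to computation element type E = f32 (C5);
# the result is produced directly in E (C7), never converted back to f16.
operand_values = [np.float16(x) for x in (0.1, 0.2, 0.3, 0.4)]
leaves = [np.float32(v) for v in operand_values]  # promotion into E
schedule = ((leaves[0], leaves[1]), (leaves[2], leaves[3]))
result = exec_schedule(schedule, lambda a, b: a + b)
assert result.dtype == np.float32                 # element_type(result) = E
```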

@@ -4074,8 +4075,7 @@ doesn't hold for many popular reductions. E.g. floating-point addition for
`body` and zero for `init_values` don't actually form a monoid because
floating-point addition is not associative.

-More formally, `results...[j0, ..., jR-1] = to_destination_type(
-reduce(input_slices_converted), type(results...)))` where:
+More formally, `results...[j0, ..., jR-1] = reduce(input_slices_converted)` where:

* `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted
at `dimensions`.
@@ -4119,10 +4119,10 @@ reduce(input_slices_converted), type(results...)))` where:
* (C5) `is_unique(dimensions)`.
* (C6) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
-  `is_promotable(element_type(inputs[i]), element_type(E[i]))`.
+  `is_promotable(element_type(inputs[i]), Ei)`.
* (C7) `shape(results...) = shape(inputs...)` except that the dimension
sizes of `inputs...` corresponding to `dimensions` are not included.
-* (C8) `is_promotable(element_type(inputs...), element_type(results...))`.
+* (C8) `element_type(results[i]) = Ei` for all `i` in `[0,N)`.

#### Examples
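
As an illustration of the revised `reduce` semantics, the sketch below
converts each input slice into the body's element type `Ei` before folding and
leaves the result in `Ei`, matching the new C6 and C8. Hypothetical helper
with NumPy dtypes as stand-ins; not spec pseudocode.

```python
import numpy as np

# `results...[j0, ..., jR-1] = reduce(input_slices_converted)`: the slice is
# converted to the body's element type Ei up front and never converted back.
def reduce_slice(input_slice, init_value, body, Ei):
    acc = Ei(init_value)
    for x in input_slice:        # input_slices_converted
        acc = body(acc, Ei(x))
    return acc                   # element_type(result) = Ei (C8)

# f16 inputs accumulated in Ei = f32: is_promotable(f16, f32) holds (C6).
xs = np.array([1.0, 2.0, 3.0], dtype=np.float16)
out = reduce_slice(xs, 0.0, lambda a, b: a + b, np.float32)
assert out == np.float32(6.0) and out.dtype == np.float32
```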

@@ -4244,7 +4244,7 @@ Afterwards, within each `process_group`:

| Name | Type | Constraints |
|----------|---------------------------------------|-------------|
-| `result` | tensor or per-tensor quantized tensor | (C8) |
+| `result` | tensor or per-tensor quantized tensor | (C8-C9) |

#### Constraints

@@ -4258,10 +4258,11 @@ Afterwards, within each `process_group`:
* (C5) `0 <= replica_groups < size(replica_groups)`.
* (C6) If `use_global_device_ids = true`, then `channel_id > 0`.
* (C7) `computation` has type `(tensor<E>, tensor<E>) -> (tensor<E>)` where
-  `is_promotable(E, element_type(operand))`.
-* (C8) `is_promotable(type(result), type(operand))` except:
+  `is_promotable(element_type(operand), E)`.
+* (C8) `shape(result) = shape(operand)` except:
* `dim(result, scatter_dimension) = dim(operand, scatter_dimension) /
dim(process_groups, 1)`.
+* (C9) `element_type(result) = E`.

#### Examples
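
A small sketch of the revised result constraints: the result keeps the
operand's shape except along `scatter_dimension`, which is divided by the
process-group size (C8), and its element type is the computation's `E` (C9).
Ad hoc helper names, not spec pseudocode.

```python
# Result shape under the revised C8: operand shape with scatter_dimension
# divided by dim(process_groups, 1); the element type is E per C9.
def reduce_scatter_result_shape(operand_shape, scatter_dimension, group_size):
    assert operand_shape[scatter_dimension] % group_size == 0
    shape = list(operand_shape)
    shape[scatter_dimension] //= group_size
    return shape

# A [2, 8] operand scattered across a group of 2 yields a [2, 4] result.
assert reduce_scatter_result_shape([2, 8], scatter_dimension=1, group_size=2) == [2, 4]
```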

@@ -4318,7 +4319,7 @@ where:
| Label | Name | Type | Constraints |
|-------|---------------------|--------------------------------------------------------------------------|-------------------------------------------------|
| (I1) | `inputs` | variadic number of tensors or per-tensor quantized tensors | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) |
-| (I2) | `init_values` | variadic number of 0-dimensional tensors or per-tensor quantized tensors | (C1), (C13), (C16) |
+| (I2) | `init_values` | variadic number of 0-dimensional tensors or per-tensor quantized tensors | (C1), (C13) |
| (I3) | `window_dimensions` | 1-dimensional tensor constant of type `si64` | (C4), (C5), (C15) |
| (I4) | `window_strides` | 1-dimensional tensor constant of type `si64` | (C6), (C7), (C15) |
| (I5) | `base_dilations` | 1-dimensional tensor constant of type `si64` | (C8), (C9), (C15) |
@@ -4349,15 +4350,15 @@ where:
* (C12) `shape(padding) = [rank(inputs[0]), 2]`.
* (C13) `body` has type `(tensor<E0>, ..., tensor<EN-1>, tensor<E0>, ...,`
`tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)` where
-  `is_promotable(element_type(inputs[i]), element_type(E[i]))`.
+  `is_promotable(element_type(inputs[i]), Ei)`.
* (C14) `same(shape(results...))`.
* (C15) `shape(results[0]) = num_windows` where:
* `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`.
* `padded_input_shape = padding[:, 0] + dilated_input_shape + padding[:, 1]`.
* `dilated_window_shape = (window_dimensions - 1) * window_dilations + 1`.
* `is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape`.
* `num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1`.
-* (C16) `is_promotable(element_type(results...), element_type(init_values...))`.
+* (C16) `element_type(results[i]) = Ei` for all `i` in `[0,N)`.
<!-- markdownlint-enable line-length -->
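
The window-count formula in C15 can be checked dimension by dimension with a
short helper like the one below (hypothetical names, not spec pseudocode):

```python
# Per-dimension evaluation of C15: dilate the input, pad it, dilate the
# window, then count how many strided window positions fit.
def num_windows(input_dim, base_dilation, pad_lo, pad_hi,
                window_dim, window_dilation, window_stride):
    dilated_input = 0 if input_dim == 0 else (input_dim - 1) * base_dilation + 1
    padded_input = pad_lo + dilated_input + pad_hi
    dilated_window = (window_dim - 1) * window_dilation + 1
    if padded_input == 0 or dilated_window > padded_input:
        return 0  # is_empty_window
    return (padded_input - dilated_window) // window_stride + 1

# A size-4 dimension, window 2, stride 2, no dilation or padding: 2 windows.
assert num_windows(4, 1, 0, 0, 2, 1, 2) == 2
```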

#### Examples
@@ -4816,17 +4817,12 @@ Given that, `results = exec(schedule, inputs)`, where:
`index_space(updates[0])`.
* `exec([update_index, ...], results) = exec([...], updated_results)` where:
* If `result_index` is in bounds for `shape(results...)`
-* `results_converted = to_destination_type(
-  results...[result_index], type(func_inputs(update_computation)
-  [:len(func_inputs(update_computation))//2])... )`
* `updates_converted = to_destination_type(
updates...[update_index], type(func_inputs(update_computation)
[len(func_inputs(update_computation))//2:])... )`
-* `updated_values = update_computation(result_converted, updates_converted)`
-* `updated_values_converted = to_destination_type(
-  updated_values, type(results...))`
+* `updated_values = update_computation(results...[result_index], updates_converted)`
* `updated_results` is a copy of `results` with `results...[result_index]`
-  set to `updated_values_converted...`.
+  set to `updated_values...`.
* Otherwise
* `updated_results = results`.
* `exec([], results) = results`.
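
To make the revised update step concrete: updates are still converted into the
`update_computation`'s element type `Ei`, but the accumulator is now read from
and written back to `results` directly, with no conversions around the call. A
minimal sketch (illustrative names, NumPy dtypes as stand-ins):

```python
import numpy as np

# One `exec` step: apply update_computation to results...[result_index]
# (already of element type Ei) and the converted update, then store back.
def apply_update(results, result_index, update, update_computation, Ei):
    updates_converted = Ei(update)              # promotion into Ei
    results[result_index] = update_computation(results[result_index],
                                               updates_converted)
    return results                              # element type stays Ei

results = np.zeros(4, dtype=np.float32)         # Ei = f32
apply_update(results, 2, np.float16(1.5), lambda r, u: r + u, np.float32)
assert results[2] == np.float32(1.5)
```
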
@@ -4860,7 +4856,7 @@ undefined.

| Name | Type | Constraints |
|-----------|------------------------------------------------------------|-------------|
-| `results` | variadic number of tensors or per-tensor quantized tensors | (C15) |
+| `results` | variadic number of tensors or per-tensor quantized tensors | (C15-C17) |

#### Constraints

@@ -4893,8 +4889,9 @@ undefined.
* (C14) `0 <= index_vector_dim <= rank(scatter_indices)`.
* (C15) `update_computation` has type `(tensor<E0>, ..., tensor<EN-1>,
tensor<E0>, ..., tensor<EN-1>) -> (tensor<E0>, ..., tensor<EN-1>)`,
-  where `is_promotable(Ei, element_type(inputs[i]))`.
-* (C16) `is_promotable(type(inputs...), type(results...))`.
+  where `is_promotable(element_type(inputs[i]), Ei)`.
+* (C16) `shape(inputs...) = shape(results...)`.
+* (C17) `element_type(results[i]) = Ei` for all `i` in `[0,N)`.

#### Examples

@@ -5031,7 +5028,7 @@ More formally:

| Name | Type | Constraints |
|----------|---------------------------------------|-------------|
-| `result` | tensor or per-tensor quantized tensor | (C11) |
+| `result` | tensor or per-tensor quantized tensor | (C11-C12) |

#### Constraints

@@ -5050,8 +5047,9 @@ More formally:
* (C9) `select` has type `(tensor<E>, tensor<E>) -> tensor<i1>` where
`E = element_type(operand)`.
* (C10) `scatter` has type `(tensor<E>, tensor<E>) -> tensor<E>` where
-  `is_promotable(element_type(operand), element_type(E))`.
-* (C11) `is_promotable(type(operand), type(result))`.
+  `is_promotable(element_type(operand), E)`.
+* (C11) `shape(operand) = shape(result)`.
+* (C12) `element_type(result) = E`.
<!-- markdownlint-enable line-length -->
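
Under the revised typing, `select` compares values in the operand's element
type, `scatter` accumulates in an element type `E` that the operand's element
type must promote to (C10), and the result is produced exactly in `E`
(C11-C12). A bitwidth-based sanity check under these assumptions (same type
class only; ad hoc names):

```python
# Same-class promotion modeled by bitwidth, mirroring is_promotable below.
def check_select_and_scatter_types(operand_bits, scatter_E_bits, result_bits):
    promotable = operand_bits <= scatter_E_bits   # C10
    result_is_E = result_bits == scatter_E_bits   # C12
    return promotable and result_is_E

assert check_select_and_scatter_types(16, 32, 32)      # f16 operand, E = f32
assert not check_select_and_scatter_types(32, 16, 16)  # demotion is rejected
```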

#### Examples
@@ -6342,15 +6340,22 @@ currently used in context of reduction computation (refer to

```python
def is_promotable(x: Type, y: Type) -> Value:
-  if x == Type and y == Type:
-    return shape(x) == shape(y) and is_promotable(element_type(x), element_type(y))
-
-  if x != Type and y != Type:
-    return (is_bool(x) and is_bool(y)) or
-      (is_integer(x) and is_integer(y)) or
-      (is_float(x) and is_float(y)) or
-      (is_complex(x) and is_complex(y)) or
-      (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
+  is_same_type = (is_bool(x) and is_bool(y)) or
+    (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or
+    (is_complex(x) and is_complex(y)) or
+    (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y))
+
+  if is_same_type == False:
+    return False
+
+  if is_integer(x) or is_float(x):
+    return bitwidth(x) <= bitwidth(y)
+
+  if is_complex(x):
+    return bitwidth(element_type(x)) <= bitwidth(element_type(y))
+
+  if is_quantized(x):
+    return bitwidth(storage_type(x)) <= bitwidth(storage_type(y))

  return false
```
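
For intuition, here is a runnable approximation of the definition above,
restricted to integer and floating-point element types modeled as NumPy dtypes
(complex and quantized types follow the same bitwidth rule on their
element/storage types and are omitted for brevity):

```python
import numpy as np

def is_promotable(x: np.dtype, y: np.dtype) -> bool:
    same_class = ((np.issubdtype(x, np.integer) and np.issubdtype(y, np.integer)) or
                  (np.issubdtype(x, np.floating) and np.issubdtype(y, np.floating)))
    # Widening within the same type class is allowed; anything else is not.
    return same_class and x.itemsize <= y.itemsize

assert is_promotable(np.dtype(np.float16), np.dtype(np.float32))      # widening
assert is_promotable(np.dtype(np.int32), np.dtype(np.int32))          # same type
assert not is_promotable(np.dtype(np.float32), np.dtype(np.float16))  # demotion
assert not is_promotable(np.dtype(np.int32), np.dtype(np.float32))    # cross-class
```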
