diff --git a/docs/spec.md b/docs/spec.md index de009fd7fd0..ce4c5d296fc 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -778,8 +778,8 @@ defined as follows: Afterwards, within each `process_group`: -* `result@process[result_index] = to_destination_type(exec(schedule), - type(result))` for some binary tree `schedule` where: +* `result@process[result_index] = exec(schedule)` for some binary tree + `schedule` where: * `exec(node)` = `computation(exec(node.left), exec(node.right))`. * `exec(leaf)` = `leaf.value`. * `schedule` is an implementation-defined binary tree whose in-order @@ -800,7 +800,7 @@ Afterwards, within each `process_group`: | Name | Type | Constraints | |----------|---------------------------------------|-------------| -| `result` | tensor or per-tensor quantized tensor | (C6) | +| `result` | tensor or per-tensor quantized tensor | (C6-C7) | #### Constraints @@ -812,8 +812,9 @@ Afterwards, within each `process_group`: * (C3) `0 <= replica_groups < size(replica_groups)`. * (C4) If `use_global_device_ids = true`, then `channel_id > 0`. * (C5) `computation` has type `(tensor, tensor) -> (tensor)` where - `is_promotable(E, element_type(operand))`. -* (C6) `is_promotable(type(result), type(operand))`. + `is_promotable(element_type(operand), E)`. +* (C6) `shape(result) = shape(operand)`. +* (C7) `element_type(result) = E`. #### Examples @@ -4012,8 +4013,7 @@ doesn't hold for many popular reductions. E.g. floating-point addition for `body` and zero for `init_values` don't actually form a monoid because floating-point addition is not associative. -More formally, `results...[j0, ..., jR-1] = to_destination_type( -reduce(input_slices_converted), type(results...)))` where: +More formally, `results...[j0, ..., jR-1] = reduce(input_slices_converted)` where: * `input_slices = inputs...[j0, ..., :, ..., jR-1]`, where `:` are inserted at `dimensions`. @@ -4057,10 +4057,10 @@ reduce(input_slices_converted), type(results...)))` where: * (C5) `is_unique(dimensions)`. 
* (C6) `body` has type `(tensor, ..., tensor, tensor, ...,` `tensor) -> (tensor, ..., tensor)` where - `is_promotable(element_type(inputs[i]), element_type(E[i]))`. + `is_promotable(element_type(inputs[i]), Ei)`. * (C7) `shape(results...) = shape(inputs...)` except that the dimension sizes of `inputs...` corresponding to `dimensions` are not included. -* (C8) `is_promotable(element_type(inputs...), element_type(results...))`. +* (C8) `element_type(results[i]) = Ei` for all `i` in `[0,N)`. #### Examples @@ -4182,7 +4182,7 @@ Afterwards, within each `process_group`: | Name | Type | Constraints | |----------|---------------------------------------|-------------| -| `result` | tensor or per-tensor quantized tensor | (C8) | +| `result` | tensor or per-tensor quantized tensor | (C8-C9) | #### Constraints @@ -4196,10 +4196,11 @@ Afterwards, within each `process_group`: * (C5) `0 <= replica_groups < size(replica_groups)`. * (C6) If `use_global_device_ids = true`, then `channel_id > 0`. * (C7) `computation` has type `(tensor, tensor) -> (tensor)` where - `is_promotable(E, element_type(operand))`. -* (C8) `is_promotable(type(result), type(operand))` except: + `is_promotable(element_type(operand), E)`. +* (C8) `shape(result) = shape(operand)` except: * `dim(result, scatter_dimension) = dim(operand, scatter_dimension) / dim(process_groups, 1)`. +* (C9) `element_type(result) = E`. 
#### Examples @@ -4256,7 +4257,7 @@ where: | Label | Name | Type | Constraints | |-------|---------------------|--------------------------------------------------------------------------|-------------------------------------------------| | (I1) | `inputs` | variadic number of tensors or per-tensor quantized tensors | (C1-C4), (C6), (C8), (C10), (C12), (C13), (C15) | -| (I2) | `init_values` | variadic number of 0-dimensional tensors or per-tensor quantized tensors | (C1), (C13), (C16) | +| (I2) | `init_values` | variadic number of 0-dimensional tensors or per-tensor quantized tensors | (C1), (C13) | | (I3) | `window_dimensions` | 1-dimensional tensor constant of type `si64` | (C4), (C5), (C15) | | (I4) | `window_strides` | 1-dimensional tensor constant of type `si64` | (C6), (C7), (C15) | | (I5) | `base_dilations` | 1-dimensional tensor constant of type `si64` | (C8), (C9), (C15) | @@ -4287,7 +4288,7 @@ where: * (C12) `shape(padding) = [rank(inputs[0]), 2]`. * (C13) `body` has type `(tensor, ..., tensor, tensor, ...,` `tensor) -> (tensor, ..., tensor)` where - `is_promotable(element_type(inputs[i]), element_type(E[i]))`. + `is_promotable(element_type(inputs[i]), Ei)`. * (C14) `same(shape(results...))`. * (C15) `shape(results[0]) = num_windows` where: * `dilated_input_shape = shape(inputs[0]) = 0 ? 0 : (shape(inputs[0]) - 1) * base_dilations + 1`. @@ -4295,7 +4296,7 @@ where: * `dilated_window_shape = (window_dimensions - 1) * window_dilations + 1`. * `is_empty_window = padded_input_shape = 0 || dilated_window_shape > padded_input_shape`. * `num_windows = is_empty_window ? 0 : floor((padded_input_shape - dilated_window_shape) / window_strides) + 1`. -* (C16) `is_promotable(element_type(results...), element_type(init_values...))`. +* (C16) `element_type(results[i]) = Ei` for all `i` in `[0,N)`. #### Examples @@ -4754,17 +4755,12 @@ Given that, `results = exec(schedule, inputs)`, where: `index_space(updates[0])`. 
* `exec([update_index, ...], results) = exec([...], updated_results)` where: * If `result_index` is in bounds for `shape(results...)` - * `results_converted = to_destination_type( - results...[result_index], type(func_inputs(update_computation) - [:len(func_inputs(update_computation))//2])... )` * `updates_converted = to_destination_type( updates...[update_index], type(func_inputs(update_computation) [len(func_inputs(update_computation))//2:])... )` - * `updated_values = update_computation(result_converted, updates_converted)` - * `updated_values_converted = to_destination_type( - updated_values, type(results...))` + * `updated_values = update_computation(results...[result_index], updates_converted)` * `updated_results` is a copy of `results` with `results...[result_index]` - set to `updated_values_converted...`. + set to `updated_values...`. * Otherwise * `updated_results = results`. * `exec([], results) = results`. @@ -4798,7 +4794,7 @@ undefined. | Name | Type | Constraints | |-----------|------------------------------------------------------------|-------------| -| `results` | variadic number of tensors or per-tensor quantized tensors | (C15) | +| `results` | variadic number of tensors or per-tensor quantized tensors | (C15-C17) | #### Constraints @@ -4831,8 +4827,9 @@ undefined. * (C14) `0 <= index_vector_dim <= rank(scatter_indices)`. * (C15) `update_computation` has type `(tensor, ..., tensor, tensor, ..., tensor) -> (tensor, ..., tensor)`, - where `is_promotable(Ei, element_type(inputs[i]))`. -* (C16) `is_promotable(type(inputs...), type(results...))`. + where `is_promotable(element_type(inputs[i]), Ei)`. +* (C16) `shape(inputs...) = shape(results...)`. +* (C17) `element_type(results[i]) = Ei` for all `i` in `[0,N)`. 
#### Examples @@ -4969,7 +4966,7 @@ More formally: | Name | Type | Constraints | |----------|---------------------------------------|-------------| -| `result` | tensor or per-tensor quantized tensor | (C11) | +| `result` | tensor or per-tensor quantized tensor | (C11-C12) | #### Constraints @@ -4988,8 +4985,9 @@ More formally: * (C9) `select` has type `(tensor, tensor) -> tensor` where `E = element_type(operand)`. * (C10) `scatter` has type `(tensor, tensor) -> tensor` where - `is_promotable(element_type(operand), element_type(E))`. -* (C11) `is_promotable(type(operand), type(result))`. + `is_promotable(element_type(operand), E)`. +* (C11) `shape(operand) = shape(result)`. +* (C12) `element_type(result) = E`. #### Examples @@ -6280,15 +6278,22 @@ currently used in context of reduction computation (refer to ```python def is_promotable(x: Type, y: Type) -> Value: - if x == Type and y == Type: - return shape(x) == shape(y) and is_promotable(element_type(x), element_type(y)) - - if x != Type and y != Type: - return (is_bool(x) and is_bool(y)) or - (is_integer(x) and is_integer(y)) or - (is_float(x) and is_float(y)) or - (is_complex(x) and is_complex(y)) or - (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y)) + is_same_type = (is_bool(x) and is_bool(y)) or + (is_integer(x) and is_integer(y)) or (is_float(x) and is_float(y)) or + (is_complex(x) and is_complex(y)) or + (is_quantized(x) and is_quantized(y) and expressed_type(x) = expressed_type(y)) + + if is_same_type == False: + return False + + if is_integer(x) or is_float(x): + return bitwidth(x) <= bitwidth(y) + + if is_complex(x): + return bitwidth(element_type(x)) <= bitwidth(element_type(y)) + + if is_quantized(x): + return bitwidth(storage_type(x)) <= bitwidth(storage_type(y)) return false ```
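The revised `is_promotable` in the last hunk is spec pseudocode, not runnable Python. Below is a minimal runnable sketch of the same decision procedure, for checking how the new constraints behave on concrete element types. The `ElementType` dataclass and its fields (`kind`, `bitwidth`, `expressed`) are hypothetical stand-ins for the spec's `is_*`, `bitwidth`, `storage_type`, and `expressed_type` helpers; they are not part of the spec or of any library.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ElementType:
    """Hypothetical model of an element type (an assumption, not spec API).

    kind:      "bool", "integer", "float", "complex", or "quantized".
    bitwidth:  width of the type itself for integer/float, of each
               component for complex, and of the storage type for quantized.
    expressed: name of the expressed type; only used for quantized types.
    """
    kind: str
    bitwidth: int
    expressed: Optional[str] = None


def is_promotable(x: ElementType, y: ElementType) -> bool:
    """Sketch of the revised check: same category and no narrowing from x to y."""
    # Both types must belong to the same category.
    if x.kind != y.kind:
        return False
    # Quantized types must additionally share their expressed type.
    if x.kind == "quantized" and x.expressed != y.expressed:
        return False
    # Integer, float, complex (component width) and quantized (storage width)
    # all reduce to a width comparison under this model.
    if x.kind in ("integer", "float", "complex", "quantized"):
        return x.bitwidth <= y.bitwidth
    # Read literally, the pseudocode has no branch for bool, so a bool pair
    # falls through to the final `return false`. This sketch mirrors that.
    return False


f16 = ElementType("float", 16)
f32 = ElementType("float", 32)
i32 = ElementType("integer", 32)
c64 = ElementType("complex", 32)    # two 32-bit components
c128 = ElementType("complex", 64)   # two 64-bit components
q8 = ElementType("quantized", 8, expressed="f32")
q16 = ElementType("quantized", 16, expressed="f32")
q16_f16 = ElementType("quantized", 16, expressed="f16")
boolean = ElementType("bool", 1)

print(is_promotable(f16, f32))   # widening within float
print(is_promotable(f32, f16))   # narrowing within float
```

Under this model the argument order matters: `is_promotable(element_type(operand), E)` asks whether the operand's element type can be widened to the computation's element type `E`. That matches the direction the diff switches to in the constraints that previously read `is_promotable(E, element_type(operand))`.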