Async RFC #2551
# [RFC] Add async support to the StableHLO specification

Status: Under Review<br/>
Initial version: 09/17/2024<br/>
Last updated: 09/17/2024<br/>
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/2551)
## Motivation

Today, StableHLO ops are designed to be executed sequentially, and any async dispatch or scheduling is left to the compiler to define.

However, getting XLA to generate optimized schedules has proven to be very challenging.
Users have found that this leaves a lot of performance on the table, and have voiced a desire for more control over scheduling.

There is an [excellent write-up](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)
from Yifu Wang that details the performance benefits of async tensor parallelism.

There is already existing async infrastructure in XLA that we use to create collective matmuls, so the main goal
is to expose this in StableHLO and make it accessible in JAX's `shard_map`.
## Proposed Specification Changes

### Types

```ebnf
AsyncType ::= 'async' '<' ValueType '>'
```

*Async types* represent tensor values that must be awaited before the underlying values can be used. Async operations
allow multiple operations to be running at once, as described in the Async Execution section.

Add `AsyncType` to `ValueType`:

```ebnf
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | AsyncType
```
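To build intuition for the proposed type, here is a small Python sketch (purely illustrative, not part of the spec; `AsyncValue` and its methods are hypothetical names) of a value wrapper whose payload cannot be read until it has been awaited, the analogue of passing it through `async_done`:

```python
from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class AsyncValue(Generic[T]):
    """Hypothetical stand-in for `async<T>`: payload is gated on awaiting."""
    _payload: T
    _awaited: bool = False

    def await_done(self) -> T:
        # Analogue of `async_done`: finalizes the value and exposes the payload.
        self._awaited = True
        return self._payload

    def payload(self) -> T:
        # Reading without awaiting first is an error, mirroring the rule that
        # an `async_start` result must first be processed by `async_done`.
        if not self._awaited:
            raise RuntimeError("async value used before async_done")
        return self._payload
```

The wrapper makes the spec's "must be awaited" rule a runtime invariant rather than a convention.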
### Ops

### async_start

#### Semantics

Produces the output from executing the `body` function, but runs all operations on a stream separate from the main compute stream.

The output of an `async_start` computation must first be processed by an `async_done` operation.

> **Review discussion:**
>
> *Reviewer:* Do you expect to be able to pipe this through control flow? This is significantly more complicated if you want to pipe this through `while` and conds.
>
> *Reply:* I think in MLIR-land that should be straightforward, as StableHLO has straightforward structured control flow, and I think builtin MLIR dataflow analysis will work (needs fact-checking!).
>
> *Reply:* Yes, for StableHLO this should be simple. I was rather wondering about the lowerings through XLA.
>
> *Reply:* For now I think we would be OK with this not needing to cross control-flow boundaries, but I could see the need eventually.
#### Inputs

| Label | Name      | Type                                                     | Constraints |
|-------|-----------|----------------------------------------------------------|-------------|
| (I1)  | `operand` | variadic number of tensors, quantized tensors, or tokens | (C1)        |
| (I2)  | `body`    | function                                                 | (C1)        |

#### Outputs

| Name      | Type                            | Constraints |
|-----------|---------------------------------|-------------|
| `results` | variadic number of async values | (C1)        |

#### Constraints

* (C1) `body` has type `(T0, ..., TN-1) -> (R0, ..., RM-1)`, where
  `type(operand[i]) == Ti` and `type(results[i]) == async<Ri>`
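The semantics above parallel familiar future-based APIs. As a rough analogy (plain Python, not StableHLO): `async_start` corresponds to submitting the `body` to an executor off the main thread, and `async_done` to awaiting the resulting future, using the same operands as the MLIR example below:

```python
from concurrent.futures import ThreadPoolExecutor, Future


# Analogue of the `body` region in the example below: add the two operands.
def body(init_i: int, init_sum: int) -> int:
    return init_i + init_sum


executor = ThreadPoolExecutor(max_workers=1)

# Analogue of `async_start`: launch body off the "main stream";
# the returned Future plays the role of an `async<tensor<i64>>` value.
future: Future = executor.submit(body, 2, 3)

# ... other work can run here, overlapping with body ...

# Analogue of `async_done`: wait for and extract the finalized result.
result = future.result()
print(result)  # -> 5
executor.shutdown()
```

Like a `Future`, an async value is only a handle; the payload becomes usable once the done/await step completes.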
#### Examples

```mlir
// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
    %init_i as %arg0: tensor<i64>,
    %init_sum as %arg1: tensor<i64>)
{
  %new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
  stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>

%result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64>
// %result: 5
```

> **Review discussion:**
>
> *Reviewer:* I like when the type is defined on the same line as the op name; I think this should be parseable in MLIR. For example:
>
> ```mlir
> // HLO async-start
> %f0, %f1 = stablehlo.async_start(...) -> async<tensor<f32>>, async<tensor<f32>> {
>   %0 = ... : tensor<f32>
>   %1 = ... : tensor<f32>
>   stablehlo.return %0, %1 : tensor<f32>, tensor<f32>
> }
> // HLO async-done
> %t0, %t1 = stablehlo.async_done %f0, %f1 : async<tensor<f32>>, async<tensor<f32>>
> ```
>
> *Reply:* So should the return type of `async_start` be a variadic list of async values?
>
> *Reply:* Yes, because conceptually it should be possible to await on just one result. On GPU that should have a straightforward lowering to streams and events.
>
> *Reply:* That makes sense. I'll update the spec then.
### async_done

#### Semantics

Waits for the values created by an `async_start` operation to be finalized. All tensors given to `async_done` must have type `async<T>`.

#### Inputs

| Label | Name      | Type                            | Constraints |
|-------|-----------|---------------------------------|-------------|
| (I1)  | `operand` | variadic number of async values | (C1)        |

#### Outputs

| Name      | Type                                                     | Constraints |
|-----------|----------------------------------------------------------|-------------|
| `results` | variadic number of tensors, quantized tensors, or tokens | (C1)        |

#### Constraints

* (C1) `type(operand) == (async<T0>, ..., async<TN>)` and `type(results) == (T0, ..., TN)`
#### Examples

```mlir
// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
    %init_i as %arg0: tensor<i64>,
    %init_sum as %arg1: tensor<i64>)
{
  %new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
  stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>

%result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64>
// %result: 5
```
## Execution

### Async Execution

StableHLO programs have dataflow semantics, and each backend is free to dispatch execution in any
pattern as long as it respects data dependencies. In practice, this usually means that only one operation runs at a time.
The ops `async_start` and `async_done`, together with barriers or `token` management, let you define data dependencies that force one
operation to start before another finishes. This can help you better utilize your hardware or define your own communication schedule.
Async operations are an advanced tool that should only be used when you know what you are doing.
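To illustrate the scheduling freedom this enables (plain Python standing in for a backend; the function names are invented stand-ins), two operations that share no data dependency can overlap instead of running back to back, improving hardware utilization:

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_collective(x: int) -> int:
    time.sleep(0.2)  # stand-in for a slow communication op
    return x * 2


def fake_matmul(y: int) -> int:
    time.sleep(0.2)  # stand-in for an independent compute op
    return y + 1


start = time.monotonic()
with ThreadPoolExecutor(max_workers=2) as ex:
    # Dataflow semantics: these two ops share no dependency, so a backend
    # (or explicit async_start/async_done pairs) may run them concurrently.
    f0 = ex.submit(fake_collective, 10)
    f1 = ex.submit(fake_matmul, 10)
    a, b = f0.result(), f1.result()
elapsed = time.monotonic() - start

print(a, b)  # -> 20 11
# Overlapped execution takes ~0.2 s rather than the ~0.4 s of
# sequential dispatch.
```

The same principle applies to hiding collective communication behind compute, as in the collective-matmul use case from the motivation section.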
> **Review discussion:**
>
> *Reviewer:* Why do we need an async data type here? Within XLA, we encode this through tuples that forward operands and interim results from async-start to async-done ops. Would the same work at the StableHLO level?
> Where I think this will be a bit special is when you want to guarantee that the values are not copied on loop boundaries, or generally around control flow. An alternative would be to introduce non-copyable values (buffers?).
>
> *Reply:* In HLO we use tuples because it's very hard to add new types to HLO. In MLIR, adding types is very easy and natural. If extending HLO were not that hard, I'd vote for adding an async type to it as well.
> Update: in HLO we use tuples for async ops and rely on a bunch of implicit assumptions about how scheduling and buffer assignment work, and I'm not a big fan of it, because if you don't know the implementation details, it's very hard to tell from the HLO alone what's going on. HLO starts with value semantics but at some point switches to buffer semantics, and nothing in the printed HLO tells you which semantics applies. Keeping StableHLO value-based with types is, imho, a lot easier for a human to parse and to understand from reading the IR.
>
> *Reply:* Makes sense. We will have to lower it to the tuples anyway, no? So this is really some form of syntactic sugar.
>
> *Reply:* Yes, I think the async type and async bundle (tuple) representations are isomorphic and can always be converted back and forth:
>
> * `stablehlo.async_start` -> `%start = (args, results, sync-flag) async_start`
> * `stablehlo.async_done %ret0, %ret1, ...` -> `async-done (get-tuple-element)`
> * `stablehlo.async_done %ret0` (awaiting just one of the returned values) -> `async-update (get-tuple-element)`, which effectively peels N result buffers and M argument buffers from a tuple and allows buffer assignment to reuse them (sorry, only an internal link for Frederik: http://goto.google.com/async-update-peeling). This is underspecified in HLO, and we don't need to support it today; instead we require that `async_done` await ALL results of the corresponding start operation.