Async RFC #2551 (closed)

rfcs/20240917-async-support.md (133 additions)
# [RFC] Add async support to the StableHLO specification

Status: Under Review<br/>
Initial version: 09/17/2024<br/>
Last updated: 09/17/2024<br/>
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/2551)


## Motivation

Today, StableHLO ops are designed to be executed sequentially, and any async dispatch or scheduling is left to the compiler to define.

However, getting XLA to generate optimized schedules has proven very challenging.
Users have found that this leaves significant performance on the table, and have voiced a desire for more control over scheduling.

There is an [excellent write-up](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)
from Yifu Wang that details the performance benefits of async tensor parallelism.

XLA already has async infrastructure that we use to create collective matmuls, so the main goal
is to expose this in StableHLO and make it accessible from JAX's `shard_map`.


## Proposed Specification changes

### Types


```ebnf
AsyncType ::= 'async' '<' ValueType '>'
```

> **Member:** Why do we need an async data type here? Within XLA, we encode this through tuples that forward operands and interim results from async-start to async-done ops. Would the same work at the StableHLO level?
>
> Where I think this will be a bit special is when you want to guarantee that the values are not copied on loop boundaries, or generally around control flow. An alternative would be to introduce non-copyable values (buffers?).

> **@ezhulenev (Member), Sep 24, 2024:** In HLO we use tuples because it is very hard to add new types to HLO. In MLIR, adding types is easy and natural. If extending HLO were not that hard, I'd vote for adding an async type there as well.
>
> Update: in HLO we use tuples for async ops and rely on a bunch of implicit assumptions about how scheduling and buffer assignment work, and I'm not a big fan of it: if you don't know the implementation details, it is very hard to tell from the HLO alone what is going on. HLO starts with value semantics but at some point switches to buffer semantics, and nothing in the printed HLO tells you which applies. Keeping StableHLO value-based with explicit types is, imho, much easier for a human to parse and to reason about from reading the IR.

> **Member:** Makes sense. We will have to lower it to the tuples anyway, no? So this is really some form of syntactic sugar.

> **Member:** Yes, I think the async type <-> async bundle (tuple) representations are isomorphic and can always be converted back and forth:
>
> 1. `stablehlo.async_start` -> `%start = (args, results, sync-flag) async_start`
> 2. `stablehlo.async_done %ret0, %ret1, ...` -> `async-done` (`get-tuple-element`)
> 3. Tricky case: `stablehlo.async_done %ret0` on just one of the returned values -> `async-update` (`get-tuple-element`); this effectively peels N result buffers and M argument buffers from a tuple and lets buffer assignment reuse them (sorry, internal link only: http://goto.google.com/async-update-peeling). This is underspecified in HLO; we don't need to support it today and can require that `async_done` await ALL results of the corresponding start operation.
*Async types* represent tensor values that must be awaited before the underlying values can be used. Async operations
allow multiple operations to be in flight at once, as described in the Async Execution section.

Add `AsyncType` to `ValueType`:

```ebnf
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | AsyncType
```
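Outside the spec itself, one way to build intuition for `async<T>` is to view it as a typed future: a handle that must be awaited (via `async_done`) before the wrapped value is usable. The sketch below models this in Python; the names `AsyncValue` and `async_start` here are illustrative stand-ins, not part of the proposal or any real API.

```python
from concurrent.futures import ThreadPoolExecutor, Future

class AsyncValue:
    """Illustrative stand-in for the proposed async<T> type: a handle
    that must be awaited before the underlying value can be used."""
    def __init__(self, future: Future):
        self._future = future

    def done(self):
        # Analogous to async_done: block until the producing
        # computation finishes, then yield the underlying value.
        return self._future.result()

def async_start(body, *operands) -> AsyncValue:
    # Analogous to async_start: run `body` off the main "stream".
    executor = ThreadPoolExecutor(max_workers=1)
    return AsyncValue(executor.submit(body, *operands))

# Mirrors the RFC example: add 2 and 3 asynchronously.
future = async_start(lambda a, b: a + b, 2, 3)
result = future.done()  # result == 5
```

The key point the analogy captures is that the raw value is never accessible through the handle without an explicit await, which is exactly what the `async<T>` type enforces at the IR level.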

### Ops

### async_start

#### Semantics

Produces the output from executing the `body` function, but runs all operations on a stream separate from the main compute stream.

The output of an `async_start` computation must first be processed by an `async_done` operation.

> **Member:** Do you expect to be able to pipe this through control flow? This is significantly more complicated if you want to pipe it through `while`s and `cond`s.

> **Member:** I think in MLIR-land that should be straightforward, since StableHLO has structured control flow, and I think the builtin MLIR dataflow analysis will work (needs fact-checking!).

> **Member:** Yes, for StableHLO this should be simple. I was rather wondering about the lowerings through XLA.

> **@chaserileyroberts (Contributor, Author), Sep 25, 2024:** For now I think we would be OK with this not needing to cross control-flow boundaries, but I could see the need eventually.



#### Inputs

| Label | Name | Type | Constraints |
|-------|-----------|---------------------------------------------------------|-------------|
| (I1) | `operand` | variadic number of tensors, quantized tensors or tokens | (C1) |
| (I2) | `body` | function | (C1) |

#### Outputs

| Name | Type | Constraints |
|-----------|------------------------------------|-------------|
| `results` | variadic number of async values | (C1) |

#### Constraints

* (C1) `body` has type `(T0, ..., TN-1) -> (R0, ..., RM-1)`, where
`type(operand[i]) == Ti` and `type(results[i]) == async<Ri>`

#### Examples

```mlir
// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
%init_i as %arg0: tensor<i64>,
%init_sum as %arg1: tensor<i64>)
{
%new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>

%result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64>
// %result: 5
```

> **Member:** Let's make `async_done` variadic and:
>
> 1. Awaiting all async results maps nicely to the `async-done` HLO.
> 2. Awaiting one async result can be represented with `async-update` (needs fact-checking), but initially we don't need to implement it.

> **Member:** I like it when the type is defined on the same line as the op name; I think this should be parseable in MLIR. Also, for `async_done` the return type can be inferred from the arguments, so there is no need to spell it out. But these details can be refined later.
>
> ```mlir
> // HLO async-start
> %f0, %f1 = stablehlo.async_start(...) -> async<tensor<f32>>, async<tensor<f32>> {
>   %0 = ... : tensor<f32>
>   %1 = ... : tensor<f32>
>   stablehlo.return %0, %1 : tensor<f32>, tensor<f32>
> }
>
> // HLO async-done
> %t0, %t1 = stablehlo.async_done %f0, %f1 : async<tensor<f32>>, async<tensor<f32>>
> ```

> **@chaserileyroberts (Contributor, Author):** So should the return type of `async_start` be `(async<R0>, ..., async<RM>)` instead of `async<(R0, ..., RM)>`?

> **Member:** Yes, because conceptually it should be possible to await just one result. On GPU that should have a straightforward lowering to streams and events.

> **Contributor (Author):** That makes sense. I'll update the spec then.

### async_done

#### Semantics

Waits for the values produced by an `async_start` operation to be finalized. Every operand of `async_done` must have type `async<T>`.

#### Inputs

| Label | Name | Type | Constraints |
|-------|-----------|------------------------------------|-------------|
| (I1) | `operand` | variadic number of async values | (C1) |

#### Outputs

| Name | Type | Constraints |
|-----------|---------------------------------------------------------|-------------|
| `results` | variadic number of tensors, quantized tensors or tokens | (C1) |

#### Constraints

* (C1) `type(operand) == (async<T0>, ..., async<TN-1>)` and `type(results) == (T0, ..., TN-1)`

#### Examples

```mlir
// %init_i: 2
// %init_sum: 3
%future = "stablehlo.async_start"(
%init_i as %arg0: tensor<i64>,
%init_sum as %arg1: tensor<i64>)
{
%new_sum = stablehlo.add %arg1, %arg0 : tensor<i64>
stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>

%result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64>
// %result: 5
```


## Execution

### Async Execution

StableHLO programs have dataflow semantics, and each backend is free to dispatch execution in any
pattern as long as it respects data dependencies. In practice, this usually means only one operation runs at a time.
The `async_start` and `async_done` ops, together with barriers or `token` management, let users express data dependencies that force one
operation to start before another finishes. This makes it possible to better utilize the hardware or to define a custom communication schedule.
Async operations are an advanced tool and should only be used when the scheduling implications are well understood.
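
As an illustration of the intended execution model (a Python analogy, not the actual lowering), two independent computations launched asynchronously can overlap, and each handle must be awaited before its result is used:

```python
import time
from concurrent.futures import ThreadPoolExecutor

def slow_square(x):
    time.sleep(0.1)  # stand-in for a long-running op (e.g. a collective)
    return x * x

with ThreadPoolExecutor(max_workers=2) as pool:
    start = time.monotonic()
    # Two "async_start"s: both run off the main stream, concurrently.
    f0 = pool.submit(slow_square, 3)
    f1 = pool.submit(slow_square, 4)
    # "async_done": await both handles before using the values.
    r0, r1 = f0.result(), f1.result()
    elapsed = time.monotonic() - start

# With overlap, total wall time is close to one op's latency, not the
# sum of both; this is the scheduling control the RFC aims to expose.
assert (r0, r1) == (9, 16)
```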