Async RFC #2551

Closed
wants to merge 2 commits into from

@chaserileyroberts chaserileyroberts commented Sep 17, 2024

This RFC proposes adding several async features to StableHLO, mainly async_start, async_done, and the async<...> value type.

The end goal of this RFC is to let JAX users define their own collective matmul schedules, giving them more control so they can potentially make better use of their hardware.

stablehlo.return %new_sum : tensor<i64>
} : (tensor<i64>, tensor<i64>) -> async<tensor<i64>>

%result = "stablehlo.async_done"(%future): async<tensor<i64>> -> tensor<i64>
Member

Let's make async_done variadic, so that:

  1. Awaiting all async results maps cleanly to the HLO async-done op.
  2. Awaiting a single async result can be represented with async-update (needs fact checking), but we don't need to implement that initially.

Member

I like it when the type is defined on the same line as the op name; I think this should be parseable in MLIR. Also, for done, the return type can be inferred from the arguments, so there is no need to spell it out. But these details can be refined later.

// HLO async-start
%f0, %f1 = stablehlo.async_start(...) -> async<tensor<f32>>, async<tensor<f32>> {
  %0 = ... : tensor<f32>
  %1 = ... : tensor<f32>
  stablehlo.return %0, %1 : tensor<f32>, tensor<f32>
}

// HLO async-done
%t0, %t1 = stablehlo.async_done %f0, %f1 : async<tensor<f32>>, async<tensor<f32>>

Contributor Author

So should the return type of async_start be (async<R0>, ..., async<RM>) instead of async<(R0, ..., RM)>?

Member

Yes, because conceptually it should be possible to await just one result. On GPU that should have a straightforward lowering to streams and events.
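To illustrate why one handle per result matters, here is a minimal Python sketch using threads and `concurrent.futures.Future`. It is only an analogy, not StableHLO or XLA code: the helpers `async_start` and `async_done` below are hypothetical stand-ins that borrow the op names, and the `gate` event is an assumption used to keep the second result pending.

```python
import threading
from concurrent.futures import Future

def async_start(num_results, body):
    """Toy model: run `body` on a worker thread and return one Future per result."""
    futures = tuple(Future() for _ in range(num_results))
    # `body` gets one setter per result, so results can become ready independently.
    worker = threading.Thread(
        target=body, args=tuple(f.set_result for f in futures), daemon=True
    )
    worker.start()
    return futures

def async_done(*futures):
    """Toy model: block until every given future is ready and return the values."""
    return tuple(f.result(timeout=5) for f in futures)

gate = threading.Event()

def body(set_r0, set_r1):
    set_r0(1.0)   # first result is ready right away
    gate.wait()   # second result is held back until the gate opens
    set_r1(2.0)

f0, f1 = async_start(2, body)
(t0,) = async_done(f0)          # awaiting only f0 does not wait for f1
second_pending = not f1.done()  # f1 cannot be ready: the gate is still closed
gate.set()
(t1,) = async_done(f1)
```

With a single future wrapping the whole tuple, as in async<(R0, ..., RM)>, the first value could not be consumed until the slowest result had been produced.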

Contributor Author

That makes sense. I'll update the spec then.


### Async Execution

Stable HLO programs are usually defined as simple sequential operations performed one after another,
Member

Actually, StableHLO (and HLO) has dataflow semantics, and the backend is free to reorder execution as long as it respects data dependencies (in HLO we also have control dependencies, but they are an internal implementation detail).
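A small Python sketch of what dataflow semantics means in practice: any schedule that respects data dependencies yields the same result. The graph, the op names `a` to `d`, and the `run` helper are all made up for illustration and are not part of StableHLO.

```python
import operator

# Each entry maps an op name to (function, names of the ops it reads from).
GRAPH = {
    "a": (lambda: 2, ()),
    "b": (lambda: 3, ()),
    "c": (operator.mul, ("a", "a")),
    "d": (operator.add, ("b", "c")),
}

def run(order):
    """Evaluate the ops in `order`, rejecting schedules that break a data dependency."""
    values = {}
    for name in order:
        fn, inputs = GRAPH[name]
        missing = [i for i in inputs if i not in values]
        if missing:
            raise ValueError(f"{name} scheduled before its inputs {missing}")
        values[name] = fn(*(values[i] for i in inputs))
    return values["d"]

# Two different but valid schedules produce the same value.
r1 = run(["a", "b", "c", "d"])
r2 = run(["a", "c", "b", "d"])
```

Here `b` and `c` do not depend on each other, so either may run first; only an order that places an op before its inputs is rejected.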

Contributor Author

I'll reword this.


Produces the output from executing the `body` function, but runs all operations on a stream separate from the main compute stream.

The output of an `async_start` computation must first be processed by an `async_done` operation.
Member

Do you expect to be able to pipe this through control flow? This is significantly more complicated if you want to pipe it through whiles and conds.

Member

I think in MLIR-land that should be straightforward, as StableHLO has simple structured control flow, and I think the builtin MLIR dataflow analysis will work (needs fact checking!).

Member

Yes, for SHLO this should be simple. I was rather wondering about the lowerings through XLA.

Contributor Author

@chaserileyroberts chaserileyroberts Sep 25, 2024

For now I think we would be OK with this not needing to cross control flow boundaries, but I could see the need eventually.



```ebnf
AsyncType ::= 'async' '<' ValueType '>'
```
Member

Why do we need an async data type here? Within XLA, we encode this through tuples that forward operands and interim results from async start to async done ops. Would the same work at the StableHLO level?

Where I think this will be a bit special is when you want to guarantee that the values are not copied on loop boundaries, or generally around control flow. An alternative would be to introduce non-copyable values (buffers (?)).

Member

@ezhulenev ezhulenev Sep 24, 2024

In HLO we use tuples because it's very hard to add new types to HLO. In MLIR adding types is very easy and natural. If extending HLO were not that hard, I'd vote for adding an async type to it as well.

Upd: in HLO we use tuples for async ops and rely on a bunch of implicit assumptions about how scheduling and buffer assignment work, and I'm not a big fan of it, because if you don't know the implementation details, it's very hard to tell from HLO alone what's going on. HLO starts with value semantics, but at some point switches to buffer semantics, and nothing in the printed HLO tells you which semantics applies. Keeping sHLO value-based with types is, imho, a lot easier for a human to parse and makes it easier to tell what's going on from reading the IR.

Member

Makes sense. We will have to lower it to tuples anyway, no? So this is really some form of syntactic sugar.

Member

Yes, I think the async type <-> async bundle (tuples) representations are isomorphic and can always be converted back and forth:

  1. stablehlo.async_start -> %start = (args, results, sync-flag) async_start
  2. stablehlo.async_done %ret0, %ret1, ... -> async-done (get-tuple-element)
  3. Tricky case: stablehlo.async_done %ret0 on just one of the returned values -> async-update (get-tuple-element), which effectively peels N result buffers and M argument buffers from a tuple and allows buffer assignment to reuse them (sorry, only an internal link for Frederik: http://goto.google.com/async-update-peeling). This is underspecified in HLO; we don't need to support it today, and can instead require that async_done await ALL results of the corresponding start operation.
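One way to picture the mapping in the list above is the toy Python model below. It is a sketch under stated assumptions, not the XLA lowering: `AsyncValue`, `AsyncBundle`, `to_async_values`, and this `async_done` are hypothetical names, and the bundle simply stores its results instead of computing them asynchronously.

```python
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True)
class AsyncValue:
    """Toy stand-in for async<T>: a handle to result `index` of one start op."""
    start_id: int
    index: int

@dataclass(frozen=True)
class AsyncBundle:
    """Toy stand-in for the HLO-style tuple (args, results, sync-flag)."""
    start_id: int
    args: Tuple[Any, ...]
    results: Tuple[Any, ...]
    sync_flag: bool

def to_async_values(bundle):
    """Bundle -> typed handles: one AsyncValue per slot in the results tuple."""
    return tuple(AsyncValue(bundle.start_id, i) for i in range(len(bundle.results)))

def async_done(bundle, handles):
    """Typed handles -> values, requiring that ALL results are awaited together."""
    if tuple(handles) != to_async_values(bundle):
        raise ValueError("async_done must await all results of the matching start")
    # Plays the role of get-tuple-element on the results slot of the bundle.
    return bundle.results

bundle = AsyncBundle(start_id=0, args=("x",), results=(1.0, 2.0), sync_flag=True)
handles = to_async_values(bundle)
values = async_done(bundle, handles)
```

Passing only `handles[:1]` raises `ValueError` in this model, which mirrors the restriction that the single-result case (async-update) is not supported initially.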

@chaserileyroberts
Contributor Author

Closing this as we move to extending jax.compute_on instead.
