From b7f938d97f4731026612f138d63f53b0541a9c62 Mon Sep 17 00:00:00 2001
From: Chase Roberts
Date: Tue, 17 Sep 2024 14:12:44 -0700
Subject: [PATCH 1/2] Added async RFC

---
 rfcs/20240917-async-support.md | 133 +++++++++++++++++++++++++++++++++
 1 file changed, 133 insertions(+)
 create mode 100644 rfcs/20240917-async-support.md

diff --git a/rfcs/20240917-async-support.md b/rfcs/20240917-async-support.md
new file mode 100644
index 00000000000..ba7cfdbb077
--- /dev/null
+++ b/rfcs/20240917-async-support.md
@@ -0,0 +1,133 @@
+# [RFC] Add async support to the StableHLO specification
+
+Status: Under Review
+Initial version: 09/17/2024
+Last updated: 09/17/2024
+Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/2551)
+
+
+## Motivation
+
+Today, StableHLO ops are designed to be executed sequentially, and any async dispatch or scheduling is left to the compiler to define.
+
+However, getting XLA to generate optimized schedules has proven to be very challenging.
+Users have found that this leaves a lot of performance on the table, and have voiced a desire for more control over the scheduling.
+
+There is an [excellent write up](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)
+from Yifu Wang that goes into detail on the performance benefits of async tensor parallelism.
+
+XLA already has async infrastructure that we use to create collective matmuls, so the main goal
+is to expose this in StableHLO and make it accessible in JAX's `shard_map`.
+
+
+## Proposed Specification changes
+
+### Types
+
+
+```ebnf
+AsyncType ::= 'async' '<' ValueType '>'
+```
+*Async Types* represent tensor values that must be awaited on before using the underlying values. Async operations
+allow multiple operations to be running at once as described in the Async Execution section.
+
+Add `AsyncType` to `NonValueType`
+
+```ebnf
+NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType | AsyncType
+```
+
+### Ops
+
+### async_start
+
+#### Semantics
+
+Produces the output from executing the `body` function, but runs all operations on a stream separate from the main compute stream.
+
+The output of an `async_start` computation must first be processed by an `async_done` operation.
+
+#### Inputs
+
+| Label | Name      | Type                                                    | Constraints |
+|-------|-----------|---------------------------------------------------------|-------------|
+| (I1)  | `operand` | variadic number of tensors, quantized tensors or tokens | (C1)        |
+| (I2)  | `body`    | function                                                | (C1)        |
+
+#### Outputs
+
+| Name      | Type                               | Constraints |
+|-----------|------------------------------------|-------------|
+| `results` | variadic number of async values    | (C1)        |
+
+#### Constraints
+
+* (C1) `body` has type `(T0, ..., TN-1) -> (R0, ..., RM-1)`, where
+  `type(operand[i]) == Ti` and `type(results[i]) == async<Ri>`
+
+#### Examples
+
+```mlir
+// %init_i: 2
+// %init_sum: 3
+%future = "stablehlo.async_start"(
+  %init_i as %arg0: tensor,
+  %init_sum as %arg1: tensor)
+{
+  %new_sum = stablehlo.add %arg1, %arg0 : tensor
+  stablehlo.return %new_sum : tensor
+} : (tensor, tensor) -> async<tensor>
+
+%result = "stablehlo.async_done"(%future): async<tensor> -> tensor
+// %result: 5
+```
+
+### async_done
+
+#### Semantics
+
+Waits for the values created by an `async_start` operation to be finalized. All operands given to `async_done` must have an `async` type.
+
+#### Inputs
+
+| Label | Name      | Type                               | Constraints |
+|-------|-----------|------------------------------------|-------------|
+| (I1)  | `operand` | variadic number of async values    | (C1)        |
+
+#### Outputs
+
+| Name      | Type                                                    | Constraints |
+|-----------|---------------------------------------------------------|-------------|
+| `results` | variadic number of tensors, quantized tensors or tokens | (C1)        |
+
+#### Constraints
+
+* (C1) `type(operand) == (async<T0>, ..., async<TN>)` and `type(results) == (T0, ..., TN)`
+
+#### Examples
+
+```mlir
+// %init_i: 2
+// %init_sum: 3
+%future = "stablehlo.async_start"(
+  %init_i as %arg0: tensor,
+  %init_sum as %arg1: tensor)
+{
+  %new_sum = stablehlo.add %arg1, %arg0 : tensor
+  stablehlo.return %new_sum : tensor
+} : (tensor, tensor) -> async<tensor>
+
+%result = "stablehlo.async_done"(%future): async<tensor> -> tensor
+// %result: 5
+```
+
+
+## Execution
+
+### Async Execution
+
+StableHLO programs have dataflow semantics, and each backend is free to dispatch execution in any
+pattern as long as it respects data dependencies. In practice, this usually means that only one operation is run at a time.
+However, the ops `async_start` and `async_done`, along with barriers or `token` management, would allow you to define data dependencies that can force one
+operation to start before another finishes. This could allow you to better utilize your hardware or to define your own communication schedule.
+Async operations are an advanced tool that should only be used when you know what you are doing.
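The start/done pairing described in the RFC can be loosely modeled in plain Python, purely as an analogy. In this sketch a single-worker thread pool stands in for the separate stream, and the helper names `async_start`, `async_done`, `body`, and `_side_stream` are hypothetical illustrations, not StableHLO, XLA, or JAX APIs. It mirrors the RFC's example values (2 and 3 summing to 5).

```python
from concurrent.futures import Future, ThreadPoolExecutor

# Hypothetical stand-in for a stream separate from the main compute stream.
_side_stream = ThreadPoolExecutor(max_workers=1)


def async_start(body, *operands) -> Future:
    """Launch body(*operands) off the main thread and return a future at once."""
    return _side_stream.submit(body, *operands)


def async_done(future: Future):
    """Block until the value produced by async_start is final, then unwrap it."""
    return future.result()


def body(i, total):
    # Plays the role of the `body` region in the RFC example (an add).
    return total + i


future = async_start(body, 2, 3)  # returns immediately with a future
overlapped = 10 * 4               # independent work that may overlap with body
result = async_done(future)       # the only point where the main thread waits
```

The future cannot be used as a number until `async_done` unwraps it, which is the role the `async` type plays in the proposal: the wait is made explicit as a data dependency.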
From 53765b53a84dab27858553ac5713a20c27c1518e Mon Sep 17 00:00:00 2001
From: Chase Roberts
Date: Thu, 19 Sep 2024 12:13:26 -0700
Subject: [PATCH 2/2] Async is now a ValueType

---
 rfcs/20240917-async-support.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/rfcs/20240917-async-support.md b/rfcs/20240917-async-support.md
index ba7cfdbb077..face036a418 100644
--- a/rfcs/20240917-async-support.md
+++ b/rfcs/20240917-async-support.md
@@ -31,10 +31,10 @@ AsyncType ::= 'async' '<' ValueType '>'
 *Async Types* represent tensor values that must be awaited on before using the underlying values. Async operations
 allow multiple operations to be running at once as described in the Async Execution section.
 
-Add `AsyncType` to `NonValueType`
+Add `AsyncType` to `ValueType`
 
 ```ebnf
-NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType | AsyncType
+ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | AsyncType
 ```
 
 ### Ops