Fix markdown lint issues (#1852)
GleasonK authored Nov 20, 2023
1 parent 76e25a5 commit 95fb1d4
Showing 1 changed file with 51 additions and 11 deletions: rfcs/20231017-collective-broadcast.md
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/1809)

## Motivation

StableHLO currently has [five collective communication primitives](https://github.com/openxla/stablehlo/blob/main/docs/spec.md#collective-ops):
`collective_permute`, `all_gather`, `all_to_all`, `all_reduce`, and
`reduce_scatter`. However, one of the major collective communication
primitives, `broadcast`, is missing from this list. This primitive efficiently
replicates a tensor from one device to many others. `broadcast` is a primitive
in [MPI](https://www.open-mpi.org/doc/v4.1/man3/MPI_Bcast.3.php),
[NCCL](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#c.ncclBroadcast),
and [PyTorch](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast).
From here on, we will refer to this operation as `collective_broadcast`, for
reasons discussed below.

While it would technically be possible to emulate a broadcast with a
conditional mask and a `psum`, that reduces to an `all_reduce` communication
primitive, which is significantly more expensive than a simple
`collective_broadcast`. Additionally, in network-switch environments, the
explicit use of `collective_broadcast` allows the switch to greatly optimize
its throughput when replicating to many targets simultaneously. However, XLA
currently has no ability to lower directly to a mesh's `collective_broadcast`
primitive, so much of that optimization is left on the table.
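
For concreteness, here is a minimal JAX sketch of the mask-and-`psum` emulation
described above. The helper name `broadcast_via_psum` and the choice of source
device are illustrative, not part of any existing API.

```python
import jax
import jax.numpy as jnp

def broadcast_via_psum(x, source, axis_name):
    """Emulate a broadcast from `source` by zeroing non-source values and summing."""
    idx = jax.lax.axis_index(axis_name)
    masked = jnp.where(idx == source, x, jnp.zeros_like(x))  # keep only the source's value
    return jax.lax.psum(masked, axis_name)  # an all_reduce: every device receives x@source

# Example over 4 devices (e.g. with
# XLA_FLAGS=--xla_force_host_platform_device_count=4).
xs = jnp.arange(4.0)
out = jax.pmap(lambda x: broadcast_via_psum(x, source=2, axis_name="i"),
               axis_name="i")(xs)
# out == [2., 2., 2., 2.]
```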

Additionally, a new compiler pass that detects uses of the old `psum` hack and
replaces them with a `collective_broadcast` could be implemented once and then
be supported by all hardware, current and future. This could have positive
knock-on effects for users who don't even realize they're using it!

`collective_broadcast` can be used to quickly replicate a tensor across an
entire mesh, and would use fewer communication resources than `all_gather` or
`psum`. `collective_broadcast` is also the base primitive used in the
[SUMMA](https://www.netlib.org/lapack/lawnspdf/lawn96.pdf) distributed GEMM
algorithm. As AI workloads grow larger, the need for such 2D distributed GEMM
algorithms will likely grow as well, and supporting one of the primitives they
require could help advance research in these areas.
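
To illustrate that claim, below is a rough SUMMA inner loop written with
mpi4py and NumPy rather than StableHLO; the square process grid, block sizes,
and variable names are assumptions made for the sketch. The only collective it
needs is a broadcast along each grid row and column.

```python
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
p = int(np.sqrt(comm.Get_size()))      # grid dimension; assumes a perfect square
i, j = divmod(comm.Get_rank(), p)      # this process's (row, column) coordinates

n = 128                                # local block size (arbitrary for the sketch)
rng = np.random.default_rng(comm.Get_rank())
A_local = rng.standard_normal((n, n))
B_local = rng.standard_normal((n, n))
C_local = np.zeros((n, n))

row_comm = comm.Split(color=i, key=j)  # processes sharing grid row i
col_comm = comm.Split(color=j, key=i)  # processes sharing grid column j

for k in range(p):
    # Process (i, k) broadcasts its A block along row i ...
    a_panel = A_local.copy() if j == k else np.empty_like(A_local)
    row_comm.Bcast(a_panel, root=k)
    # ... and process (k, j) broadcasts its B block along column j.
    b_panel = B_local.copy() if i == k else np.empty_like(B_local)
    col_comm.Bcast(b_panel, root=k)
    C_local += a_panel @ b_panel       # local GEMM accumulate
```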

## Alternatives considered

Instead of adding `collective_broadcast` as a primitive, we considered
loosening the restriction on `collective_permute` to allow a one-to-many
communication schedule instead of the current one-to-one schedule. Downstream
compilers would then be responsible for detecting this pattern and calling
their own `collective_broadcast` primitive. However, loosening this restriction
makes defining the transposition rule for `collective_permute` significantly
more complicated: it became difficult to see how to compute that transpose
efficiently for an arbitrary communication configuration, and how to do so in
SPMD. By contrast, the transposition rule for `collective_broadcast` is just a
`psum` followed by a source-device one-hot masking, as sketched below. This
simplicity, plus the broad usage of `collective_broadcast` in the wider
ecosystem, ultimately made us choose to add the new primitive instead.
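
As a sanity check on that simplicity, here is the transposition rule written as
JAX-style pseudocode. The function and parameter names (`source`, `axis_name`)
are placeholders for the sketch, not an existing API.

```python
import jax

def collective_broadcast_transpose(cotangent, *, source, axis_name):
    """Adjoint of broadcasting from `source`: sum receivers' cotangents onto the source."""
    summed = jax.lax.psum(cotangent, axis_name)        # contributions from every receiver
    idx = jax.lax.axis_index(axis_name)
    one_hot = (idx == source).astype(cotangent.dtype)  # 1 on the source device, 0 elsewhere
    return summed * one_hot
```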

## Why call it collective_broadcast and not just broadcast?

Unfortunately, the op name `broadcast` is already taken by [an op in XLA proper](https://www.tensorflow.org/xla/operation_semantics#broadcast),
so we can't have the two names clash. `collective_broadcast` was the preferred
alternative.

## Proposed Specification

Afterwards, `result@process` is given by:

* `operand@process_groups[i, 0]` if there exists an `i` such that
  the process is in `process_groups[i]`.
* `broadcast_in_dim(constant(0, element_type(result)), [], type(result))`
  otherwise.
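
As a worked example of these semantics (not an implementation), the following
NumPy snippet simulates the per-process results; the helper name and the group
`[[2, 1]]` are chosen just for illustration.

```python
import numpy as np

def reference_collective_broadcast(operands, process_groups, num_processes):
    """Simulate the result on each process; operands[p] is process p's tensor."""
    results = []
    for proc in range(num_processes):
        group = next((g for g in process_groups if proc in g), None)
        if group is not None:
            # Every member of the group receives the first member's operand.
            results.append(operands[group[0]])
        else:
            # Processes outside every group receive zeros.
            results.append(np.zeros_like(operands[proc]))
    return results

# One group [2, 1] over 4 processes: processes 1 and 2 receive operand@2,
# processes 0 and 3 receive zeros.
operands = [np.array([10 * p, 10 * p + 1]) for p in range(4)]
print(reference_collective_broadcast(operands, [[2, 1]], 4))
```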

#### Inputs

