Skip to content

Commit

Permalink
Collectives Ops : Match example from the StableHLO op description com…
Browse files Browse the repository at this point in the history
  • Loading branch information
abhigunj authored Aug 8, 2024
1 parent 3ec5546 commit d4405ad
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
3 changes: 3 additions & 0 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,9 @@ Afterwards, within each `process_group`:
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
Expand Down Expand Up @@ -918,6 +920,7 @@ Afterwards, within each `process_group`:
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// channel_id = 0
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
// %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]]
// %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]]
Expand Down
19 changes: 9 additions & 10 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1356,11 +1356,11 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather",

Example:
```mlir
%result = "stablehlo.all_gather"(%operand) {
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
```
}];

Expand Down Expand Up @@ -1395,15 +1395,14 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce",

Example:
```mlir
%result = "stablehlo.all_reduce"(%operand) ({
%result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg1, %arg2 : tensor<i64>
stablehlo.return %0 : tensor<i64>
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<4xi64>) -> tensor<4xi64>
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
```
}];

Expand Down Expand Up @@ -1483,12 +1482,12 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all",

Example:
```mlir
%result = "stablehlo.all_to_all"(%operand) {
%result:2 = "stablehlo.all_to_all"(%operand1, %operand2) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
```
}];

Expand Down

0 comments on commit d4405ad

Please sign in to comment.