diff --git a/docs/spec.md b/docs/spec.md index f4ad8199f09..57c06195506 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -829,7 +829,9 @@ Afterwards, within each `process_group`: "stablehlo.return"(%0) : (tensor) -> () }) { replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + // channel_id = 0 channel_handle = #stablehlo.channel_handle + // use_global_device_ids = false } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>) // %result0@(0, 0): [6, 8, 10, 12] // %result0@(1, 0): [6, 8, 10, 12] @@ -918,6 +920,7 @@ Afterwards, within each `process_group`: concat_dimension = 0 : i64, split_count = 2 : i64, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + // channel_id = 0 } : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>) // %result#0@(0, 0): [[1, 2], [5, 6], [9, 10], [13, 14]] // %result#0@(1, 0): [[3, 4], [7, 8], [11, 12], [15, 16]] diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 136e2476612..568cb3c67f6 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -1356,11 +1356,11 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather", Example: ```mlir - %result = "stablehlo.all_gather"(%operand) { + %result:2 = "stablehlo.all_gather"(%operand0, %operand1) { all_gather_dim = 1 : i64, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, channel_handle = #stablehlo.channel_handle - } : (tensor<2x2xi64>) -> tensor<2x4xi64> + } : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>) ``` }]; @@ -1395,15 +1395,14 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce", Example: ```mlir - %result = "stablehlo.all_reduce"(%operand) ({ + %result:2 = "stablehlo.all_reduce"(%operand0, %operand0) ({ ^bb0(%arg0: tensor, %arg1: tensor): - %0 = stablehlo.add %arg1, %arg2 : tensor - stablehlo.return %0 : tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () }) { - replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, channel_handle = #stablehlo.channel_handle - // use_global_device_ids = false - } : (tensor<4xi64>) -> tensor<4xi64> + } : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>) ``` }]; @@ -1483,12 +1482,12 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all", Example: ```mlir - %result = "stablehlo.all_to_all"(%operand) { + %result:2 = "stablehlo.all_to_all"(%operand1, %operand2) { split_dimension = 1 : i64, concat_dimension = 0 : i64, split_count = 2 : i64, replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - } : (tensor<2x4xi64>) -> tensor<4x2xi64> + } : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>) ``` }];