diff --git a/docs/status.md b/docs/status.md index 99759bcf60c..674ef76d471 100644 --- a/docs/status.md +++ b/docs/status.md @@ -61,7 +61,7 @@ one of the following tracking labels. | ceil | yes | yes | yes | yes | yes | | cholesky | yes | yes | yes | yes | revisit | | clamp | yes | revisit | yes | yes | yes | -| collective_broadcast | yes | revisit | yes | no | no | +| collective_broadcast | yes | revisit | yes | no | yes | | collective_permute | yes | revisit | yes | no | yes | | compare | yes | yes | yes | yes | yes | | complex | yes | yes | yes | yes | yes | diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 0d8ead15768..ee3ce800c08 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -328,6 +328,25 @@ SmallVector<InterpreterValue> eval(Region &region, auto operand = scope.findTensor(clzOp.getOperand()); auto result = evalClzOp(operand, clzOp.getType()); scope.add(clzOp.getResult(), result); + } else if (auto collectiveBroadcastOp = + dyn_cast<CollectiveBroadcastOp>(op)) { + auto operand = scope.findTensor(collectiveBroadcastOp.getOperand()); + + auto replicaGroupsAttr = collectiveBroadcastOp.getReplicaGroups(); + auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape(); + SmallVector<SmallVector<uint32_t>> replicaGroups(replicaGroupsShape[0]); + auto replicaGroupsIt = replicaGroupsAttr.getValues<int64_t>().begin(); + for (auto &replicaGroup : replicaGroups) + for (auto i = 0; i < replicaGroupsShape[1]; ++i, ++replicaGroupsIt) + replicaGroup.push_back(*replicaGroupsIt); + + ChannelId channelId = 0; + if (auto channelHandle = collectiveBroadcastOp.getChannelHandle()) + channelId = channelHandle->getHandle(); + + auto result = + evalCollectiveBroadcastOp(operand, replicaGroups, channelId, process); + scope.add(collectiveBroadcastOp.getResult(), result); } else if (auto collectivePermuteOp = dyn_cast<CollectivePermuteOp>(op)) { auto operand = scope.findTensor(collectivePermuteOp.getOperand()); @@ -1074,6 +1093,28 @@ Tensor evalClzOp(const Tensor &operand, ShapedType resultType) { return result;
} +Tensor evalCollectiveBroadcastOp( + const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups, + ChannelId channelId, Process *process) { + if (!process) + llvm::report_fatal_error( + "collective_broadcast is only supported when run via " + "interpreter.run_parallel"); + + ProcessGroups processGroups; + if (channelId <= 0) processGroups = process->crossReplica(replicaGroups); + if (channelId > 0) processGroups = process->crossPartition(replicaGroups); + + auto processGroup = processGroups.findGroup(process->getId()); + if (processGroup) + return process->rendezvous(*processGroup, channelId, operand) + .lookup((*processGroup)[0]); + + return evalBroadcastInDimOp( + makeScalar(convert(operand.getElementType(), 0.0)), {}, + operand.getType()); +} + Tensor evalCollectivePermuteOp( const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs, ChannelId channelId, Process *process) { diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h index b85c57d1de0..6b58066caae 100644 --- a/stablehlo/reference/Ops.h +++ b/stablehlo/reference/Ops.h @@ -62,6 +62,9 @@ Tensor evalCeilOp(const Tensor &operand, ShapedType resultType); Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max, ShapedType resultType); Tensor evalClzOp(const Tensor &operand, ShapedType resultType); +Tensor evalCollectiveBroadcastOp( + const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups, + ChannelId channelId, Process *process); Tensor evalCollectivePermuteOp( const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs, ChannelId channelId, Process *process); diff --git a/stablehlo/reference/ProcessGrid.cpp b/stablehlo/reference/ProcessGrid.cpp index b6fd1f4cca4..72cc4b3ec89 100644 --- a/stablehlo/reference/ProcessGrid.cpp +++ b/stablehlo/reference/ProcessGrid.cpp @@ -49,8 +49,8 @@ bool ProcessId::operator==(const ProcessId &other) const { std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) { for (auto processGroup : *this) - for (auto id : processGroup) - if (id == processId) return
processGroup; + if (llvm::find(processGroup, processId) != processGroup.end()) + return processGroup; return std::nullopt; } diff --git a/stablehlo/tests/interpret/collective_broadcast.mlir b/stablehlo/tests/interpret/collective_broadcast.mlir new file mode 100644 index 00000000000..bb2e5f4f658 --- /dev/null +++ b/stablehlo/tests/interpret/collective_broadcast.mlir @@ -0,0 +1,223 @@ +// RUN: stablehlo-translate --interpret -split-input-file %s + +module @cross_replica { + func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { + %result = "stablehlo.collective_broadcast"(%operand) { + replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle<handle = 0, type = 0> + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + return %result : tensor<1x2xi64> + } + func.func @main() { + %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> + %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> + %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> + %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> + %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { + programs=[[@collective_broadcast], [@collective_broadcast], + [@collective_broadcast], [@collective_broadcast]] + } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> + (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) + check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64> + check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64> + func.return + } +} + +// ----- + +module @cross_replica_multiple_output { + func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { + %result = "stablehlo.collective_broadcast"(%operand) { + replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>, +
channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + return %result : tensor<1x2xi64> + } + func.func @main() { + %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> + %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> + %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> + %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> + %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { + programs=[[@collective_broadcast], [@collective_broadcast], + [@collective_broadcast], [@collective_broadcast]] + } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> + (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) + check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64> + func.return + } +} + +// ----- + +module @cross_replica_single_replica { + func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { + %result = "stablehlo.collective_broadcast"(%operand) { + replica_groups = dense<[[0]]> : tensor<1x1xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + return %result : tensor<1x2xi64> + } + func.func @main() { + %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> + %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> + %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> + %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> + %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { + programs=[[@collective_broadcast, @collective_broadcast, + @collective_broadcast, @collective_broadcast]] + } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) 
-> + (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) + check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64> + check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64> + check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64> + func.return + } +} + +// ----- + +module @cross_replica_multiple_partitions { + func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { + %result = "stablehlo.collective_broadcast"(%operand) { + replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + return %result : tensor<1x2xi64> + } + func.func @main() { + %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> + %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> + %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> + %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> + %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { + programs=[[@collective_broadcast, @collective_broadcast], + [@collective_broadcast, @collective_broadcast]] + } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> + (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) + check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#1, dense<[[7, 8]]> : tensor<1x2xi64> + check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64> + func.return + } +} + +// ----- + +module @cross_partition { + func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { + %result = "stablehlo.collective_broadcast"(%operand) { + replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> 
tensor<1x2xi64> + return %result : tensor<1x2xi64> + } + func.func @main() { + %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> + %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> + %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> + %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> + %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { + programs=[[@collective_broadcast, @collective_broadcast, + @collective_broadcast, @collective_broadcast]] + } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> + (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) + check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64> + check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64> + func.return + } +} + +// ----- + +module @cross_partition_multiple_output { + func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { + %result = "stablehlo.collective_broadcast"(%operand) { + replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + return %result : tensor<1x2xi64> + } + func.func @main() { + %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> + %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> + %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> + %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> + %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { + programs=[[@collective_broadcast, @collective_broadcast, + @collective_broadcast, @collective_broadcast]] + } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> + (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, 
tensor<1x2xi64>) + check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64> + func.return + } +} + +// ----- + +module @cross_partition_single_partition { + func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { + %result = "stablehlo.collective_broadcast"(%operand) { + replica_groups = dense<[[0]]> : tensor<1x1xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + return %result : tensor<1x2xi64> + } + func.func @main() { + %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> + %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> + %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> + %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> + %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { + programs=[[@collective_broadcast], [@collective_broadcast], + [@collective_broadcast], [@collective_broadcast]] + } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> + (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) + check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64> + check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64> + check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64> + check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64> + func.return + } +} + +// ----- + +module @cross_partition_multiple_replicas { + func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> { + %result = "stablehlo.collective_broadcast"(%operand) { + replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<1x2xi64>) -> tensor<1x2xi64> + return %result : 
tensor<1x2xi64> + } + func.func @main() { + %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64> + %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64> + %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64> + %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64> + %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) { + programs=[[@collective_broadcast, @collective_broadcast], + [@collective_broadcast, @collective_broadcast]] + } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) -> + (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) + check.expect_eq_const %results#0, dense<[[3, 4]]> : tensor<1x2xi64> + check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64> + check.expect_eq_const %results#2, dense<[[7, 8]]> : tensor<1x2xi64> + check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64> + func.return + } +}