Add interpreter for CollectiveBroadcastOp (#1983)
closes #1982
ghpvnist authored Feb 6, 2024
1 parent aa9a196 commit 36189e2
Showing 5 changed files with 270 additions and 3 deletions.
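For reference, the per-process behavior implemented below is: every process listed in a replica group receives the operand held by the first process of that group, and a process that appears in no group receives a zero-filled tensor of the operand's type (whether the ids name replicas or partitions depends on the channel handle). The following is a minimal standalone C++ sketch of that result, using hypothetical std-container stand-ins rather than the interpreter's Tensor/Process API and ignoring the replica/partition grid mapping:

#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the interpreter's types, for illustration only.
using Tensor = std::vector<int64_t>;        // flattened tensor data
using ProcessId = uint32_t;                 // index of a process
using ReplicaGroups = std::vector<std::vector<ProcessId>>;

// Result seen by process `self`: the operand of the first process in the
// group containing `self`, or a zero-filled tensor if `self` is in no group.
Tensor collectiveBroadcastResult(const ReplicaGroups &groups, ProcessId self,
                                 const std::vector<Tensor> &operands) {
  for (const auto &group : groups)
    for (ProcessId id : group)
      if (id == self) return operands[group.front()];
  return Tensor(operands[self].size(), 0);
}

With groups = {{2, 1}} and per-process operands [1,2], [3,4], [5,6], [7,8], processes 1 and 2 both return [5, 6] while processes 0 and 3 return [0, 0], matching the @cross_replica test added below.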
2 changes: 1 addition & 1 deletion docs/status.md
@@ -61,7 +61,7 @@ one of the following tracking labels.
| ceil | yes | yes | yes | yes | yes |
| cholesky | yes | yes | yes | yes | revisit |
| clamp | yes | revisit | yes | yes | yes |
-| collective_broadcast | yes | revisit | yes | no | no |
+| collective_broadcast | yes | revisit | yes | no | yes |
| collective_permute | yes | revisit | yes | no | yes |
| compare | yes | yes | yes | yes | yes |
| complex | yes | yes | yes | yes | yes |
41 changes: 41 additions & 0 deletions stablehlo/reference/Ops.cpp
@@ -328,6 +328,25 @@ SmallVector<InterpreterValue> eval(Region &region,
      auto operand = scope.findTensor(clzOp.getOperand());
      auto result = evalClzOp(operand, clzOp.getType());
      scope.add(clzOp.getResult(), result);
    } else if (auto collectiveBroadcastOp =
                   dyn_cast<CollectiveBroadcastOp>(op)) {
      auto operand = scope.findTensor(collectiveBroadcastOp.getOperand());

      auto replicaGroupsAttr = collectiveBroadcastOp.getReplicaGroups();
      auto replicaGroupsShape = replicaGroupsAttr.getShapedType().getShape();
      SmallVector<SmallVector<uint32_t>> replicaGroups(replicaGroupsShape[0]);
      auto replicaGroupsIt = replicaGroupsAttr.getValues<int64_t>().begin();
      for (auto &replicaGroup : replicaGroups)
        for (auto i = 0; i < replicaGroupsShape[1]; ++i, ++replicaGroupsIt)
          replicaGroup.push_back(*replicaGroupsIt);

      ChannelId channelId = 0;
      if (auto channelHandle = collectiveBroadcastOp.getChannelHandle())
        channelId = channelHandle->getHandle();

      auto result =
          evalCollectiveBroadcastOp(operand, replicaGroups, channelId, process);
      scope.add(collectiveBroadcastOp.getResult(), result);
    } else if (auto collectivePermuteOp = dyn_cast<CollectivePermuteOp>(op)) {
      auto operand = scope.findTensor(collectivePermuteOp.getOperand());

@@ -1074,6 +1093,28 @@ Tensor evalClzOp(const Tensor &operand, ShapedType resultType) {
  return result;
}

Tensor evalCollectiveBroadcastOp(
    const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
    ChannelId channelId, Process *process) {
  if (!process)
    llvm::report_fatal_error(
        "collective_broadcast is only supported when run via "
        "interpreter.run_parallel");

  ProcessGroups processGroups;
  if (channelId <= 0) processGroups = process->crossReplica(replicaGroups);
  if (channelId > 0) processGroups = process->crossPartition(replicaGroups);

  auto processGroup = processGroups.findGroup(process->getId());
  if (processGroup)
    return process->rendezvous(*processGroup, channelId, operand)
        .lookup((*processGroup)[0]);

  return evalBroadcastInDimOp(
      makeScalar(convert(operand.getElementType(), 0.0)), {},
      operand.getType());
}

Tensor evalCollectivePermuteOp(
    const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
    ChannelId channelId, Process *process) {
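Two details of the Ops.cpp change above: getReplicaGroups() yields a dense [num_groups, group_size] integer attribute, which the eval dispatch unrolls row by row into nested vectors; and in evalCollectiveBroadcastOp, channelId <= 0 selects cross-replica grouping while channelId > 0 selects cross-partition grouping, members of a group read the tensor of the group's first process via rendezvous(...).lookup((*processGroup)[0]), and non-members return a zero tensor of the operand's type. A small sketch of the unrolling step on plain containers (hypothetical names, not the MLIR attribute API):

#include <cstdint>
#include <vector>

// Unroll a row-major [numGroups, groupSize] buffer, as produced by a dense
// integer attribute, into one vector of ids per group.
std::vector<std::vector<uint32_t>> unrollReplicaGroups(
    const std::vector<int64_t> &flat, int64_t numGroups, int64_t groupSize) {
  std::vector<std::vector<uint32_t>> groups(numGroups);
  auto it = flat.begin();
  for (auto &group : groups)
    for (int64_t i = 0; i < groupSize; ++i, ++it)
      group.push_back(static_cast<uint32_t>(*it));
  return groups;
}

For example, replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64> flattens to {2, 1, 0} with shape [1, 3] and unrolls to the single group {2, 1, 0}.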
3 changes: 3 additions & 0 deletions stablehlo/reference/Ops.h
@@ -62,6 +62,9 @@ Tensor evalCeilOp(const Tensor &operand, ShapedType resultType);
Tensor evalClampOp(const Tensor &min, const Tensor &operand, const Tensor &max,
                   ShapedType resultType);
Tensor evalClzOp(const Tensor &operand, ShapedType resultType);
Tensor evalCollectiveBroadcastOp(
    const Tensor &operand, SmallVector<SmallVector<uint32_t>> replicaGroups,
    ChannelId channelId, Process *process);
Tensor evalCollectivePermuteOp(
    const Tensor &operand, SmallVector<SmallVector<uint32_t>> sourceTargetPairs,
    ChannelId channelId, Process *process);
4 changes: 2 additions & 2 deletions stablehlo/reference/ProcessGrid.cpp
@@ -49,8 +49,8 @@ bool ProcessId::operator==(const ProcessId &other) const {

std::optional<ProcessGroup> ProcessGroups::findGroup(ProcessId processId) {
  for (auto processGroup : *this)
-    for (auto id : processGroup)
-      if (id == processId) return processGroup;
+    if (llvm::find(processGroup, processId) != processGroup.end())
+      return processGroup;

  return std::nullopt;
}
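The ProcessGrid.cpp change is a behavior-preserving cleanup: findGroup still returns the first group containing processId, now via llvm::find instead of a hand-written inner loop. A minimal standard-library equivalent of the same lookup, with hypothetical stand-in types:

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

using ProcessId = uint32_t;                   // hypothetical stand-in
using ProcessGroup = std::vector<ProcessId>;  // hypothetical stand-in

// Return the first group that contains processId, if any.
std::optional<ProcessGroup> findGroup(const std::vector<ProcessGroup> &groups,
                                      ProcessId processId) {
  for (const auto &group : groups)
    if (std::find(group.begin(), group.end(), processId) != group.end())
      return group;
  return std::nullopt;
}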
223 changes: 223 additions & 0 deletions stablehlo/tests/interpret/collective_broadcast.mlir
@@ -0,0 +1,223 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
    %result = "stablehlo.collective_broadcast"(%operand) {
      replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
    return %result : tensor<1x2xi64>
  }
  func.func @main() {
    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
      programs=[[@collective_broadcast], [@collective_broadcast],
                [@collective_broadcast], [@collective_broadcast]]
    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
    check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>
    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
    func.return
  }
}

// -----

module @cross_replica_multiple_output {
  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
    %result = "stablehlo.collective_broadcast"(%operand) {
      replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
    return %result : tensor<1x2xi64>
  }
  func.func @main() {
    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
      programs=[[@collective_broadcast], [@collective_broadcast],
                [@collective_broadcast], [@collective_broadcast]]
    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
    func.return
  }
}

// -----

module @cross_replica_single_replica {
  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
    %result = "stablehlo.collective_broadcast"(%operand) {
      replica_groups = dense<[[0]]> : tensor<1x1xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
    return %result : tensor<1x2xi64>
  }
  func.func @main() {
    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
      programs=[[@collective_broadcast, @collective_broadcast,
                 @collective_broadcast, @collective_broadcast]]
    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
    check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>
    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
    func.return
  }
}

// -----

module @cross_replica_multiple_partitions {
  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
    %result = "stablehlo.collective_broadcast"(%operand) {
      replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
    return %result : tensor<1x2xi64>
  }
  func.func @main() {
    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
      programs=[[@collective_broadcast, @collective_broadcast],
                [@collective_broadcast, @collective_broadcast]]
    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#1, dense<[[7, 8]]> : tensor<1x2xi64>
    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
    func.return
  }
}

// -----

module @cross_partition {
  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
    %result = "stablehlo.collective_broadcast"(%operand) {
      replica_groups = dense<[[2, 1]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
    return %result : tensor<1x2xi64>
  }
  func.func @main() {
    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
      programs=[[@collective_broadcast, @collective_broadcast,
                 @collective_broadcast, @collective_broadcast]]
    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
    check.expect_eq_const %results#0, dense<[[0, 0]]> : tensor<1x2xi64>
    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
    func.return
  }
}

// -----

module @cross_partition_multiple_output {
  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
    %result = "stablehlo.collective_broadcast"(%operand) {
      replica_groups = dense<[[2, 1, 0]]> : tensor<1x3xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
    return %result : tensor<1x2xi64>
  }
  func.func @main() {
    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
      programs=[[@collective_broadcast, @collective_broadcast,
                 @collective_broadcast, @collective_broadcast]]
    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
    check.expect_eq_const %results#0, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#1, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#3, dense<[[0, 0]]> : tensor<1x2xi64>
    func.return
  }
}

// -----

module @cross_partition_single_partition {
  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
    %result = "stablehlo.collective_broadcast"(%operand) {
      replica_groups = dense<[[0]]> : tensor<1x1xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
    return %result : tensor<1x2xi64>
  }
  func.func @main() {
    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
      programs=[[@collective_broadcast], [@collective_broadcast],
                [@collective_broadcast], [@collective_broadcast]]
    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
    check.expect_eq_const %results#0, dense<[[1, 2]]> : tensor<1x2xi64>
    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
    check.expect_eq_const %results#2, dense<[[5, 6]]> : tensor<1x2xi64>
    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
    func.return
  }
}

// -----

module @cross_partition_multiple_replicas {
  func.func @collective_broadcast(%operand : tensor<1x2xi64>) -> tensor<1x2xi64> {
    %result = "stablehlo.collective_broadcast"(%operand) {
      replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>,
      channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
    } : (tensor<1x2xi64>) -> tensor<1x2xi64>
    return %result : tensor<1x2xi64>
  }
  func.func @main() {
    %operand0 = stablehlo.constant dense<[[1, 2]]> : tensor<1x2xi64>
    %operand1 = stablehlo.constant dense<[[3, 4]]> : tensor<1x2xi64>
    %operand2 = stablehlo.constant dense<[[5, 6]]> : tensor<1x2xi64>
    %operand3 = stablehlo.constant dense<[[7, 8]]> : tensor<1x2xi64>
    %results:4 = "interpreter.run_parallel"(%operand0, %operand1, %operand2, %operand3) {
      programs=[[@collective_broadcast, @collective_broadcast],
                [@collective_broadcast, @collective_broadcast]]
    } : (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>) ->
        (tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>, tensor<1x2xi64>)
    check.expect_eq_const %results#0, dense<[[3, 4]]> : tensor<1x2xi64>
    check.expect_eq_const %results#1, dense<[[3, 4]]> : tensor<1x2xi64>
    check.expect_eq_const %results#2, dense<[[7, 8]]> : tensor<1x2xi64>
    check.expect_eq_const %results#3, dense<[[7, 8]]> : tensor<1x2xi64>
    func.return
  }
}
