Skip to content

Commit

Permalink
Use default visibility for symbols (#1993)
Browse files Browse the repository at this point in the history
As pointed out by @mlevesquedion in
#1983 (comment),
the symbol visibility is public by default, and making it implicit
follows the precedent of other tests in this directory.
  • Loading branch information
ghpvnist authored Feb 6, 2024
1 parent 36189e2 commit 7d7977a
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 48 deletions.
16 changes: 8 additions & 8 deletions stablehlo/tests/interpret/all_gather.mlir
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
func.func @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
%result = "stablehlo.all_gather"(%arg0) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
return %result : tensor<2x4xi64>
}
func.func public @main() {
func.func @main() {
%0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%results:2 = "interpreter.run_parallel"(%0, %1) {
Expand All @@ -25,15 +25,15 @@ module @cross_replica {
// -----

module @cross_replica_and_partition {
func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
func.func @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
%result = "stablehlo.all_gather"(%arg0) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
return %result : tensor<2x4xi64>
}
func.func public @main() {
func.func @main() {
%0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%results:2 = "interpreter.run_parallel"(%0, %1) {
Expand All @@ -50,15 +50,15 @@ module @cross_replica_and_partition {
// -----

module @cross_replica_and_partition_issue_1933 {
func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x8xi64> {
func.func @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x8xi64> {
%result = "stablehlo.all_gather"(%arg0) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle=1, type=0>
} : (tensor<2x2xi64>) -> tensor<2x8xi64>
return %result : tensor<2x8xi64>
}
func.func public @main() {
func.func @main() {
%0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%results:4 = "interpreter.run_parallel"(%1, %1, %0, %1) {
Expand All @@ -81,7 +81,7 @@ module @cross_replica_and_partition_issue_1933 {
// -----

module @flattened_ids {
func.func public @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
func.func @all_gather(%arg0 : tensor<2x2xi64>) -> tensor<2x4xi64> {
%result = "stablehlo.all_gather"(%arg0) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
Expand All @@ -90,7 +90,7 @@ module @flattened_ids {
} : (tensor<2x2xi64>) -> tensor<2x4xi64>
return %result : tensor<2x4xi64>
}
func.func public @main() {
func.func @main() {
%0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%results:2 = "interpreter.run_parallel"(%0, %1) {
Expand Down
16 changes: 8 additions & 8 deletions stablehlo/tests/interpret/all_reduce.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
func.func @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
Expand All @@ -12,7 +12,7 @@ module @cross_replica {
} : (tensor<4xi64>) -> tensor<4xi64>
return %result : tensor<4xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
%results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
Expand All @@ -27,7 +27,7 @@ module @cross_replica {
// -----

module @cross_replica_and_partition {
func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
func.func @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
Expand All @@ -38,7 +38,7 @@ module @cross_replica_and_partition {
} : (tensor<4xi64>) -> tensor<4xi64>
return %result : tensor<4xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
%results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
Expand All @@ -53,7 +53,7 @@ module @cross_replica_and_partition {
// -----

module @flattened_ids {
func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
func.func @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
Expand All @@ -65,7 +65,7 @@ module @flattened_ids {
} : (tensor<4xi64>) -> tensor<4xi64>
return %result : tensor<4xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
%results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
Expand All @@ -80,7 +80,7 @@ module @flattened_ids {
// -----

module @ragged_replica_groups {
func.func public @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
func.func @all_reduce(%operand : tensor<4xi64>) -> tensor<4xi64> {
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
Expand All @@ -91,7 +91,7 @@ module @ragged_replica_groups {
} : (tensor<4xi64>) -> tensor<4xi64>
return %result : tensor<4xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%inputs1 = stablehlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
%inputs2 = stablehlo.constant dense<[6, 8, 10, 12]> : tensor<4xi64>
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/tests/interpret/all_to_all.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
func.func @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
Expand All @@ -10,7 +10,7 @@ module @cross_replica {
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
return %result : tensor<4x2xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[[1, 2, 3, 4],
[5, 6, 7, 8]]> : tensor<2x4xi64>
%inputs1 = stablehlo.constant dense<[[9, 10, 11, 12],
Expand All @@ -33,7 +33,7 @@ module @cross_replica {
// -----

module @cross_partition {
func.func public @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
func.func @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
Expand All @@ -43,7 +43,7 @@ module @cross_partition {
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
return %result : tensor<4x2xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[[1, 2, 3, 4],
[5, 6, 7, 8]]> : tensor<2x4xi64>
%inputs1 = stablehlo.constant dense<[[9, 10, 11, 12],
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/tests/interpret/collective_permute.mlir
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @collective_permute(%operand : tensor<2x2xi64>) -> tensor<2x2xi64> {
func.func @collective_permute(%operand : tensor<2x2xi64>) -> tensor<2x2xi64> {
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
return %result : tensor<2x2xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%inputs1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%inputs2 = stablehlo.constant dense<[[9, 10], [11, 12]]> : tensor<2x2xi64>
Expand All @@ -26,14 +26,14 @@ module @cross_replica {
// -----

module @cross_partition {
func.func public @collective_permute(%operand : tensor<2x2xi64>) -> tensor<2x2xi64> {
func.func @collective_permute(%operand : tensor<2x2xi64>) -> tensor<2x2xi64> {
%result = "stablehlo.collective_permute"(%operand) {
source_target_pairs = dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 1, type = 0>
} : (tensor<2x2xi64>) -> tensor<2x2xi64>
return %result : tensor<2x2xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
%inputs1 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
%inputs2 = stablehlo.constant dense<[[9, 10], [11, 12]]> : tensor<2x2xi64>
Expand Down
8 changes: 4 additions & 4 deletions stablehlo/tests/interpret/infeed.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @distribution_ops {
func.func public @infeed(%token : !stablehlo.token) ->
func.func @infeed(%token : !stablehlo.token) ->
(tensor<2x2xi64>, !stablehlo.token,
tensor<2x2xi64>, !stablehlo.token) {
%results0:2 = "stablehlo.infeed"(%token) :
Expand All @@ -11,15 +11,15 @@ module @distribution_ops {
func.return %results0#0, %results0#1, %results1#0, %results1#1 :
tensor<2x2xi64>, !stablehlo.token, tensor<2x2xi64>, !stablehlo.token
}
func.func public @infeed_queue0() -> (tensor<2x2xi64>) {
func.func @infeed_queue0() -> (tensor<2x2xi64>) {
%queue0 = stablehlo.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi64>
func.return %queue0 : tensor<2x2xi64>
}
func.func public @infeed_queue1() -> (tensor<2x2xi64>) {
func.func @infeed_queue1() -> (tensor<2x2xi64>) {
%queue0 = stablehlo.constant dense<[[5, 6], [7, 8]]> : tensor<2x2xi64>
func.return %queue0 : tensor<2x2xi64>
}
func.func public @main() {
func.func @main() {
%token = stablehlo.after_all : !stablehlo.token
%results:4 = "interpreter.run_parallel"(%token) {
infeed=[@infeed_queue0, @infeed_queue1],
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/tests/interpret/outfeed.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @distribution_ops {
func.func public @outfeed(%inputs0 : tensor<2x2x2xi64>, %token : !stablehlo.token) -> !stablehlo.token {
func.func @outfeed(%inputs0 : tensor<2x2x2xi64>, %token : !stablehlo.token) -> !stablehlo.token {
%result = "stablehlo.outfeed"(%inputs0, %token) :
(tensor<2x2x2xi64>, !stablehlo.token) -> !stablehlo.token
func.return %result : !stablehlo.token
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[[[1, 2], [3, 4]],
[[5, 6], [7, 8]]]> : tensor<2x2x2xi64>
%token = stablehlo.after_all : !stablehlo.token
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/tests/interpret/partition_id.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @distribution_ops {
func.func public @partition_id() -> tensor<ui32> {
func.func @partition_id() -> tensor<ui32> {
%result = stablehlo.partition_id : tensor<ui32>
return %result : tensor<ui32>
}
func.func public @main() {
func.func @main() {
%results:2 = "interpreter.run_parallel"() {
programs=[[@partition_id, @partition_id]]
} : () -> (tensor<ui32>, tensor<ui32>)
Expand Down
12 changes: 6 additions & 6 deletions stablehlo/tests/interpret/reduce_scatter.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @cross_replica {
func.func public @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {
func.func @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
Expand All @@ -12,7 +12,7 @@ module @cross_replica {
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
return %result : tensor<2x2xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[[1, 2, 3, 4],
[5, 6, 7, 8]]> : tensor<2x4xi64>
%inputs1 = stablehlo.constant dense<[[9, 10, 11, 12],
Expand All @@ -31,7 +31,7 @@ module @cross_replica {
// -----

module @cross_replica_and_partition {
func.func public @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {
func.func @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
Expand All @@ -43,7 +43,7 @@ module @cross_replica_and_partition {
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
return %result : tensor<2x2xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[[1, 2, 3, 4],
[5, 6, 7, 8]]> : tensor<2x4xi64>
%inputs1 = stablehlo.constant dense<[[9, 10, 11, 12],
Expand All @@ -62,7 +62,7 @@ module @cross_replica_and_partition {
// -----

module @flattened_ids {
func.func public @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {
func.func @reduce_scatter(%operand : tensor<2x4xi64>) -> tensor<2x2xi64> {
%result = "stablehlo.reduce_scatter"(%operand) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = stablehlo.add %arg0, %arg1 : tensor<i64>
Expand All @@ -75,7 +75,7 @@ module @flattened_ids {
} : (tensor<2x4xi64>) -> tensor<2x2xi64>
return %result : tensor<2x2xi64>
}
func.func public @main() {
func.func @main() {
%inputs0 = stablehlo.constant dense<[[1, 2, 3, 4],
[5, 6, 7, 8]]> : tensor<2x4xi64>
%inputs1 = stablehlo.constant dense<[[9, 10, 11, 12],
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/tests/interpret/replica_id.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: stablehlo-translate --interpret -split-input-file %s

module @distribution_ops {
func.func public @replica_id() -> tensor<ui32> {
func.func @replica_id() -> tensor<ui32> {
%result = stablehlo.replica_id : tensor<ui32>
return %result : tensor<ui32>
}
func.func public @main() {
func.func @main() {
%results:2 = "interpreter.run_parallel"() {
programs=[[@replica_id], [@replica_id]]
} : () -> (tensor<ui32>, tensor<ui32>)
Expand Down
Loading

0 comments on commit 7d7977a

Please sign in to comment.