diff --git a/WORKSPACE.bazel b/WORKSPACE.bazel index 4b2c3c5a933..0a9e32b4c5d 100644 --- a/WORKSPACE.bazel +++ b/WORKSPACE.bazel @@ -17,9 +17,9 @@ workspace(name = "stablehlo") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -LLVM_COMMIT = "b270525f730be6e7196667925f5a9bfa153262e9" +LLVM_COMMIT = "e2402615a5a76d46a433dfcc1de10b38a1263c9d" -LLVM_SHA256 = "fcf77da395cd5097eac5951471b04aa887f565c3447545239421a0eb7089da7c" +LLVM_SHA256 = "9c22349e1d38555b2f223e49951655f60c04c0c3467e0150aaf6c9f50484cc9f" http_archive( name = "llvm-raw", diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 456bfbc70a2..e793a5b6997 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -b270525f730be6e7196667925f5a9bfa153262e9 +e2402615a5a76d46a433dfcc1de10b38a1263c9d diff --git a/examples/c++/ExampleAdd.cpp b/examples/c++/ExampleAdd.cpp index 1161723acd8..a2d479fd1c4 100644 --- a/examples/c++/ExampleAdd.cpp +++ b/examples/c++/ExampleAdd.cpp @@ -49,7 +49,7 @@ int main() { /** create function **/ // create function argument and result types. auto tensorType = - mlir::RankedTensorType::get({3, 4}, mlir::FloatType::getF32(&context)); + mlir::RankedTensorType::get({3, 4}, mlir::Float32Type::get(&context)); auto func_type = mlir::FunctionType::get(&context, {tensorType, tensorType}, {tensorType}); diff --git a/stablehlo/conversions/tosa/tests/nullary.mlir b/stablehlo/conversions/tosa/tests/nullary.mlir index b77a01d574f..d4f9c6d4bff 100644 --- a/stablehlo/conversions/tosa/tests/nullary.mlir +++ b/stablehlo/conversions/tosa/tests/nullary.mlir @@ -17,8 +17,9 @@ func.func @constant_f64() -> tensor<10xf64> { // CHECK-LABEL: @iota_dimension_0 func.func @iota_dimension_0() -> tensor<4x8xf32> { // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() - // CHECK-SAME{LITERAL}: <{value = dense<[[0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}> - // CHECK-DAG: %[[VAR1:.*]] = tosa.tile %[[VAR0]] {multiples = array} + // CHECK-SAME{LITERAL}: <{value = dense<[[0.000000e+00], [1.000000e+00], [2.000000e+00], [3.000000e+00]]> : tensor<4x1xf32>}> : () -> tensor<4x1xf32> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[1, 8]> : vector<2xindex>} : () -> !tosa.shape<2> + // CHECK-DAG: %[[VAR2:.*]] = tosa.tile %[[VAR0]], %[[VAR1]] %0 = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> (tensor<4x8xf32>) return %0 : tensor<4x8xf32> } @@ -26,8 +27,9 @@ func.func @iota_dimension_0() -> tensor<4x8xf32> { // CHECK-LABEL: @iota_dimension_1 func.func @iota_dimension_1() -> tensor<4x8xi32> { // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() - // CHECK-SAME{LITERAL}: <{value = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}> - // CHECK-DAG: %[[VAR1:.*]] = tosa.tile %[[VAR0]] {multiples = array} + // CHECK-SAME{LITERAL}: <{value = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi32>}> : () -> tensor<1x8xi32> + // CHECK-DAG: %[[VAR1:.*]] = tosa.const_shape {value = dense<[4, 1]> : vector<2xindex>} : () -> !tosa.shape<2> + // CHECK-DAG: %[[VAR2:.*]] = tosa.tile %[[VAR0]], %[[VAR1]] %0 = "stablehlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<4x8xi32>) return %0 : tensor<4x8xi32> } diff --git a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp index d190077edc5..b4430e7c65b 100644 --- a/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp +++ b/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp @@ -258,9 +258,13 @@ 
struct ConvertStablehloIotaOp : public OpRewritePattern<stablehlo::IotaOp> { } } + auto shapeType = rewriter.getType<tosa::shapeType>(tileMultiples.size()); + auto shapedMultiples = rewriter.create<tosa::ConstShapeOp>( + op.getLoc(), shapeType, rewriter.getIndexVectorAttr(tileMultiples)); + + // Tile the const array to the result shape of the iota op. - rewriter.replaceOpWithNewOp<tosa::TileOp>( - op, resultType, constOp, rewriter.getDenseI64ArrayAttr(tileMultiples)); + rewriter.replaceOpWithNewOp<tosa::TileOp>(op, resultType, constOp, + shapedMultiples); return success(); } }; diff --git a/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir index dd16fee54f7..e86bbc5a529 100644 --- a/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir +++ b/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir @@ -3927,3 +3927,68 @@ func.func @square_f32(%arg : tensor<f32>) -> tensor<f32> { %result = "chlo.square"(%arg) : (tensor<f32>) -> tensor<f32> func.return %result : tensor<f32> } + +// ----- + +// CHECK-LABEL: @ragged_dot_mode_1 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<2x11x5xf32>, %[[ARG_1:.*]]: tensor<3x2x5x7xf32>, %[[ARG_2:.*]]: tensor<3xi64>) -> tensor<2x11x7xf32> { +// CHECK: %[[VAL_0:.*]] = stablehlo.iota dim = 1 : tensor<1x11x1xi64> +// CHECK: %[[VAL_C:.*]] = stablehlo.constant dense<0> : tensor<1xi64> +// CHECK: %[[VAL_CST:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x11x7xf32> +// CHECK: %[[VAL_CST_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x11x7xf32> +// CHECK: %[[VAL_1:.*]] = stablehlo.slice %[[ARG_2]] [0:1] : (tensor<3xi64>) -> tensor<1xi64> +// CHECK: %[[VAL_2:.*]] = stablehlo.broadcast_in_dim %[[VAL_C]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64> +// CHECK: %[[VAL_3:.*]] = stablehlo.compare LE, %[[VAL_2]], %[[VAL_0]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1> +// CHECK: %[[VAL_4:.*]] = stablehlo.add %[[VAL_C]], %[[VAL_1]] : tensor<1xi64> +// CHECK: %[[VAL_5:.*]] = stablehlo.broadcast_in_dim %[[VAL_4]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64> +// CHECK: %[[VAL_6:.*]] = stablehlo.compare LT, %[[VAL_0]], %[[VAL_5]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1> +// CHECK: %[[VAL_7:.*]] = stablehlo.and %[[VAL_3]], %[[VAL_6]] : tensor<1x11x1xi1> +// CHECK: %[[VAL_8:.*]] = stablehlo.broadcast_in_dim %[[VAL_7]], dims = [0, 1, 2] : (tensor<1x11x1xi1>) -> tensor<2x11x7xi1> +// CHECK: %[[VAL_9:.*]] = stablehlo.slice %[[ARG_1]] [0:1, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32> +// CHECK: %[[VAL_10:.*]] = stablehlo.reshape %[[VAL_9]] : (tensor<1x2x5x7xf32>) -> tensor<2x5x7xf32> +// CHECK: %[[VAL_11:.*]] = stablehlo.dot_general %[[ARG_0]], %[[VAL_10]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x11x5xf32>, tensor<2x5x7xf32>) -> tensor<2x11x7xf32> +// CHECK: %[[VAL_12:.*]] = stablehlo.select %[[VAL_8]], %[[VAL_11]], %[[VAL_CST_0]] : tensor<2x11x7xi1>, tensor<2x11x7xf32> +// CHECK: %[[VAL_13:.*]] = stablehlo.add %[[VAL_CST]], %[[VAL_12]] : tensor<2x11x7xf32> +// CHECK: %[[VAL_14:.*]] = stablehlo.add %[[VAL_C]], %[[VAL_1]] : tensor<1xi64> +// CHECK: %[[VAL_15:.*]] = stablehlo.slice %[[ARG_2]] [1:2] : (tensor<3xi64>) -> tensor<1xi64> +// CHECK: %[[VAL_16:.*]] = stablehlo.broadcast_in_dim %[[VAL_14]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64> +// CHECK: %[[VAL_17:.*]] = stablehlo.compare LE, %[[VAL_16]], %[[VAL_0]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1> +// CHECK: %[[VAL_18:.*]] = stablehlo.add %[[VAL_14]], %[[VAL_15]] : tensor<1xi64> +// CHECK: %[[VAL_19:.*]] = 
stablehlo.broadcast_in_dim %[[VAL_18]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64> +// CHECK: %[[VAL_20:.*]] = stablehlo.compare LT, %[[VAL_0]], %[[VAL_19]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1> +// CHECK: %[[VAL_21:.*]] = stablehlo.and %[[VAL_17]], %[[VAL_20]] : tensor<1x11x1xi1> +// CHECK: %[[VAL_22:.*]] = stablehlo.broadcast_in_dim %[[VAL_21]], dims = [0, 1, 2] : (tensor<1x11x1xi1>) -> tensor<2x11x7xi1> +// CHECK: %[[VAL_23:.*]] = stablehlo.slice %[[ARG_1]] [1:2, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32> +// CHECK: %[[VAL_24:.*]] = stablehlo.reshape %[[VAL_23]] : (tensor<1x2x5x7xf32>) -> tensor<2x5x7xf32> +// CHECK: %[[VAL_25:.*]] = stablehlo.dot_general %[[ARG_0]], %[[VAL_24]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x11x5xf32>, tensor<2x5x7xf32>) -> tensor<2x11x7xf32> +// CHECK: %[[VAL_26:.*]] = stablehlo.select %[[VAL_22]], %[[VAL_25]], %[[VAL_CST_0]] : tensor<2x11x7xi1>, tensor<2x11x7xf32> +// CHECK: %[[VAL_27:.*]] = stablehlo.add %[[VAL_13]], %[[VAL_26]] : tensor<2x11x7xf32> +// CHECK: %[[VAL_28:.*]] = stablehlo.add %[[VAL_14]], %[[VAL_15]] : tensor<1xi64> +// CHECK: %[[VAL_29:.*]] = stablehlo.slice %[[ARG_2]] [2:3] : (tensor<3xi64>) -> tensor<1xi64> +// CHECK: %[[VAL_30:.*]] = stablehlo.broadcast_in_dim %[[VAL_28]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64> +// CHECK: %[[VAL_31:.*]] = stablehlo.compare LE, %[[VAL_30]], %[[VAL_0]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1> +// CHECK: %[[VAL_32:.*]] = stablehlo.add %[[VAL_28]], %[[VAL_29]] : tensor<1xi64> +// CHECK: %[[VAL_33:.*]] = stablehlo.broadcast_in_dim %[[VAL_32]], dims = [0] : (tensor<1xi64>) -> tensor<1x11x1xi64> +// CHECK: %[[VAL_34:.*]] = stablehlo.compare LT, %[[VAL_0]], %[[VAL_33]] : (tensor<1x11x1xi64>, tensor<1x11x1xi64>) -> tensor<1x11x1xi1> +// CHECK: %[[VAL_35:.*]] = stablehlo.and %[[VAL_31]], %[[VAL_34]] : tensor<1x11x1xi1> +// CHECK: %[[VAL_36:.*]] = stablehlo.broadcast_in_dim %[[VAL_35]], dims = [0, 1, 2] : (tensor<1x11x1xi1>) -> tensor<2x11x7xi1> +// CHECK: %[[VAL_37:.*]] = stablehlo.slice %[[ARG_1]] [2:3, 0:2, 0:5, 0:7] : (tensor<3x2x5x7xf32>) -> tensor<1x2x5x7xf32> +// CHECK: %[[VAL_38:.*]] = stablehlo.reshape %[[VAL_37]] : (tensor<1x2x5x7xf32>) -> tensor<2x5x7xf32> +// CHECK: %[[VAL_39:.*]] = stablehlo.dot_general %[[ARG_0]], %[[VAL_38]], batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x11x5xf32>, tensor<2x5x7xf32>) -> tensor<2x11x7xf32> +// CHECK: %[[VAL_40:.*]] = stablehlo.select %[[VAL_36]], %[[VAL_39]], %[[VAL_CST_0]] : tensor<2x11x7xi1>, tensor<2x11x7xf32> +// CHECK: %[[VAL_41:.*]] = stablehlo.add %[[VAL_27]], %[[VAL_40]] : tensor<2x11x7xf32> +// CHECK: %[[VAL_42:.*]] = stablehlo.add %[[VAL_28]], %[[VAL_29]] : tensor<1xi64> +func.func @ragged_dot_mode_1(%lhs : tensor<2x11x5xf32>, %rhs : tensor<3x2x5x7xf32>, %group_sizes : tensor<3xi64>) -> tensor<2x11x7xf32> { + %0 = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [1], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [1], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> + func.return %0 : tensor<2x11x7xf32> +} diff --git a/stablehlo/tests/interpret/chlo/ragged_dot.mlir 
b/stablehlo/tests/interpret/chlo/ragged_dot.mlir new file mode 100644 index 00000000000..7f2fc07aa81 --- /dev/null +++ b/stablehlo/tests/interpret/chlo/ragged_dot.mlir @@ -0,0 +1,189 @@ +// RUN: stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify-diagnostics %s > %t.mlir +// RUN: stablehlo-translate --interpret --split-input-file %t.mlir + +func.func @ragged_dot_mode_1() { + %lhs = stablehlo.constant dense< + [ + [ -0.0999976546, -0.0605386607, 0.126681596, 0.0375950411, 0.0598301813 ], + [ -0.0343122408, -0.0858866125, 0.103659429, 0.103788935, 0.180407882 ], + [ 0.0150506198, 0.055824928, 0.149289608, -0.0896283686, -0.0839615092 ], + [ 0.0589100644, 0.101344816, -0.097690545, 0.0150246918, -0.0799473301 ], + [ 0.0252457932, 0.106031813, 0.076692991, 0.179130971, 0.153850079 ], + [ 0.0580786392, -0.0724105313, 0.0961757079, 0.0247998089, 0.110357188 ], + [ 0.173096269, 0.128659427, -0.0212640986, -0.0857606456, 0.120824583 ], + [ -0.00152973086, 0.0897915736, 0.126923144, 0.197311223, 0.00960160792 ], + [ -0.0258883312, 0.194765091, 0.11679814, 0.126006752, 0.0954555795 ], + [ -0.0781942382, 0.0894904211, 0.165412158, -0.0181870088, 0.0309234336 ], + [ 0.129948437, 0.0433195308, -0.028667666, -0.0175279453, 0.00777949393 ] + ]> : tensor<11x5xf32> + %rhs = stablehlo.constant dense<[ + [ + [ 0.186608255, 0.124487795, 0.0663751587, 0.167221248, 0.0874548, 0.152611881, -0.0520697422 ], + [ -0.0361745432, 0.114412986, -0.0608718246, -0.0727029, -0.0176235586, -0.0991001204, 0.0242879838 ], + [ -0.0919371173, 0.112945892, 0.181369215, -0.0280267522, -0.0457312278, -0.00473813713, 0.166097224 ], + [ 0.0956176, -0.0548994839, 0.104403876, 0.0157444105, 0.0163175985, 0.0499223098, -0.0557401 ], + [ 0.076156, 0.153672695, 0.0770325884, 0.186622649, 0.066843845, -0.0555545315, 0.194991559 ] + ], + [ + [ 0.00485724211, 0.0356900468, 0.142683387, 0.179502338, 0.0954938307, -0.0354254842, 0.103877716 ], + [ 0.172676593, -0.0249623209, 0.158257961, 0.0413787, 0.0517867729, 0.0801181123, 0.14526847 ], + [ 0.126753062, 0.0386734977, 0.185410261, 0.0898216143, 0.0317991, 0.14740923, 0.106694289 ], + [ 0.110662006, 0.196143657, 0.186324477, 0.155380905, -0.0132051334, 0.0612277314, 0.054330416 ], + [ -0.0689698234, 0.0242085531, 0.073015, 0.162969738, 0.0320116058, 0.118924297, 0.160779119 ] + ], + [ + [ -0.0998214856, -0.0997363, 0.132005602, 0.118200503, -0.00424671918, 0.025317125, 0.104748271 ], + [ 0.104168601, -0.0384214334, 0.150926, 0.112676181, 0.14861238, -0.071635358, -0.0754787177 ], + [ 0.129201442, 0.088871561, -0.0358443409, -0.0359359607, -0.0756817609, 0.0166469738, 0.185647905 ], + [ 0.184263527, 0.0169560835, -0.0192355737, 0.10765069, -0.0147894919, 0.13305977, 0.135159582 ], + [ 0.0267379507, -0.0153532401, -0.0418097563, -0.096605137, -0.0424528457, 0.194970757, -0.0267837271 ] + ]]> : tensor<3x5x7xf32> + %group_sizes = stablehlo.constant dense<[4, 4, 3]> : tensor<3xi64> + %result = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [1], + lhs_ragged_dimensions = [0], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<11x5xf32>, tensor<3x5x7xf32>, tensor<3xi64>) -> tensor<11x7xf32> + check.expect_almost_eq_const %result, dense<[ + [-0.0199659951, 0.00206358638, 0.0285578221, -0.00411329232, -0.00885893404, -0.0113086831, 0.0343487822], + [0.0108370036, 
0.0196357146, 0.0464844741, 0.032903526, 0.00752512738, -0.00205732603, 0.0463109687], + [-0.0279003512, 0.0171403233, 0.00885203853, -0.022806216, -0.0135696121, -0.00375272054, 0.0139928926], + [0.0116565451, -0.00521556707, -0.0245668497, -0.00946252606, 2.734600e-03, 0.00460146647, -0.0332586318], + [0.0373648889, 0.040080104, 0.0792120546, 0.0687142611, 0.0129001699, 0.048170276, 6.067640e-02], + [-0.00489785476, 0.0151357278, 0.0273378156, 0.0379059538, 0.0080597708, 0.0209609158, 0.0248660222], + [0.00253825542, -1.175260e-02, 0.0339594558, 0.0408501513, 0.0275165718, 0.0101594552, 0.0491689071], + [5.275800e-02, 0.0415463448, 0.0749897882, 0.0470644757, 0.00624182029, 0.0391805507, 0.03869069], + [0.0637338459, 0.00614991458, 0.0153763723, 0.0190313365, 0.0142990183, 0.0227143262, 0.0187453162], + [0.0359746702, 0.0182777364, -0.00368779944, -0.0100486111, 6.89582666E-5, -0.00202751439, 0.0124766938], + [-0.0151847685, -0.0175893605, 0.0247314386, 0.018632818, 0.00798455066, -0.00110600982, 0.00244264561] + ]> : tensor<11x7xf32> + func.return +} + +// ----- + +func.func @ragged_dot_mode_1_batching() { + %lhs = stablehlo.constant dense<[ + [ + [ -0.0999976546, -0.0605386607, 0.126681596, 0.0375950411, 0.0598301813 ], + [ -0.0343122408, -0.0858866125, 0.103659429, 0.103788935, 0.180407882 ], + [ 0.0150506198, 0.055824928, 0.149289608, -0.0896283686, -0.0839615092 ], + [ 0.0589100644, 0.101344816, -0.097690545, 0.0150246918, -0.0799473301 ], + [ 0.0252457932, 0.106031813, 0.076692991, 0.179130971, 0.153850079 ], + [ 0.0580786392, -0.0724105313, 0.0961757079, 0.0247998089, 0.110357188 ], + [ 0.173096269, 0.128659427, -0.0212640986, -0.0857606456, 0.120824583 ], + [ -0.00152973086, 0.0897915736, 0.126923144, 0.197311223, 0.00960160792 ], + [ -0.0258883312, 0.194765091, 0.11679814, 0.126006752, 0.0954555795 ], + [ -0.0781942382, 0.0894904211, 0.165412158, -0.0181870088, 0.0309234336 ], + [ 0.129948437, 0.0433195308, -0.028667666, -0.0175279453, 0.00777949393 ] + ], + [ + [ -0.0500478409, 0.0459552184, 0.16929689, 0.172762454, -0.0818307 ], + [ 0.171395928, 0.0513568744, 0.0548876, -0.00429011881, 0.195992649 ], + [ 0.0481930152, -0.0201566443, -0.0727801323, 0.184329301, -0.0778752789 ], + [ 0.0502121374, 0.0152426511, -0.0168754607, 0.174145252, 0.0589242205 ], + [ 0.0393337533, 0.182294011, -0.0849748, 0.128454268, 0.131061375 ], + [ 0.148345202, -0.0623903871, -0.0952396914, 0.10653659, 0.160474151 ], + [ 0.0888630375, 0.120867364, 0.117623605, 0.199837387, 0.166571677 ], + [ -0.0300415382, -0.00810345262, 0.00530457497, 0.0539821163, 0.0773340687 ], + [ 0.153794467, 0.0236242339, 0.152453214, -0.0192048177, 0.0246183872 ], + [ 0.0611911938, 0.0403752252, -0.013836287, -0.0465016849, -0.053884007 ], + [ 0.0714964494, 0.140721709, -0.0900838748, 0.0603349432, 0.0495440438 ] + ]]> : tensor<2x11x5xf32> + %rhs = stablehlo.constant dense<[ + [ + [ + [ 0.186608255, 0.124487795, 0.0663751587, 0.167221248, 0.0874548, 0.152611881, -0.0520697422 ], + [ -0.0361745432, 0.114412986, -0.0608718246, -0.0727029, -0.0176235586, -0.0991001204, 0.0242879838 ], + [ -0.0919371173, 0.112945892, 0.181369215, -0.0280267522, -0.0457312278, -0.00473813713, 0.166097224 ], + [ 0.0956176, -0.0548994839, 0.104403876, 0.0157444105, 0.0163175985, 0.0499223098, -0.0557401 ], + [ 0.076156, 0.153672695, 0.0770325884, 0.186622649, 0.066843845, -0.0555545315, 0.194991559 ] + ], + [ + [ 0.0226300061, -0.0574540682, 0.0694696084, -0.0243620798, 0.0465543643, 0.0392091647, 0.188328564 ], + [ -0.0621907599, -0.0400728397, 
-0.0042250976, 0.0887807682, -0.0619863532, 0.0953761414, 0.0864902064 ], + [ 0.140921891, -0.0256474689, 0.0429295525, 0.0167942569, -0.0390249, -0.0914874449, 0.170502067 ], + [ 0.0279492214, -0.0573936924, 0.184246033, 0.0230939165, -0.060643442, 0.165694535, -0.0723479092 ], + [ -0.051340431, -0.0786809325, 0.00960171223, -0.0240827873, -0.059467189, 0.134945959, 0.0365921929 ] + ] + ], + [ + [ + [ 0.00485724211, 0.0356900468, 0.142683387, 0.179502338, 0.0954938307, -0.0354254842, 0.103877716 ], + [ 0.172676593, -0.0249623209, 0.158257961, 0.0413787, 0.0517867729, 0.0801181123, 0.14526847 ], + [ 0.126753062, 0.0386734977, 0.185410261, 0.0898216143, 0.0317991, 0.14740923, 0.106694289 ], + [ 0.110662006, 0.196143657, 0.186324477, 0.155380905, -0.0132051334, 0.0612277314, 0.054330416 ], + [ -0.0689698234, 0.0242085531, 0.073015, 0.162969738, 0.0320116058, 0.118924297, 0.160779119 ] + ], + [ + [ 0.11469271, 0.140216112, 0.111960642, 0.122514777, -0.0942722782, 0.165809333, 0.0574962273 ], + [ 0.0389968231, -0.08044184, 0.114026703, 0.0466829464, 0.100303732, 0.104614742, -0.0401335768 ], + [ 0.174990177, 0.159764826, 0.167005628, 0.0631844923, -0.0582415, 0.0351042375, 0.196808755 ], + [ -0.035340406, 0.0338070318, -0.00528027117, 0.0543978438, 0.164451241, 0.0319176689, 0.0402595326 ], + [ 0.141994983, 0.00954742, -0.0365443081, 0.199735016, -0.053918656, 0.0891464874, 0.0849051103 ] + ] + ], + [ + [ + [ -0.0998214856, -0.0997363, 0.132005602, 0.118200503, -0.00424671918, 0.025317125, 0.104748271 ], + [ 0.104168601, -0.0384214334, 0.150926, 0.112676181, 0.14861238, -0.071635358, -0.0754787177 ], + [ 0.129201442, 0.088871561, -0.0358443409, -0.0359359607, -0.0756817609, 0.0166469738, 0.185647905 ], + [ 0.184263527, 0.0169560835, -0.0192355737, 0.10765069, -0.0147894919, 0.13305977, 0.135159582 ], + [ 0.0267379507, -0.0153532401, -0.0418097563, -0.096605137, -0.0424528457, 0.194970757, -0.0267837271 ] + ], + [ + [ 0.145917833, -0.0590635166, 0.0194431096, 0.0803030357, -0.0469358861, 0.148506433, -0.0526806451 ], + [ 0.196381122, -0.0228494033, -0.0299202427, -0.069508791, -0.0341768041, 0.0904152468, 0.108802207 ], + [ 0.138430953, 0.108872853, 0.125882119, 0.100856192, 0.0900289789, -0.0830678046, 0.0794649944 ], + [ -0.0318976864, -0.00436662883, 0.109950341, -0.0647689179, 0.128771216, 0.0578369871, 0.0661734 ], + [ 0.0763966814, -0.00110008568, 0.110896833, -0.057086423, -0.0514936894, 0.0455975607, 0.158067733 ] + ] + ]]> : tensor<3x2x5x7xf32> + %group_sizes = stablehlo.constant dense<[4, 4, 3]> : tensor<3xi64> + %result = "chlo.ragged_dot"(%lhs, %rhs, %group_sizes) { + ragged_dot_dimension_numbers = #chlo.ragged_dot< + lhs_batching_dimensions = [0], + rhs_batching_dimensions = [1], + lhs_contracting_dimensions = [2], + rhs_contracting_dimensions = [2], + lhs_ragged_dimensions = [1], + rhs_group_dimensions = [0] + >, + precision_config = [#chlo, #chlo] + } : (tensor<2x11x5xf32>, tensor<3x2x5x7xf32>, tensor<3xi64>) -> tensor<2x11x7xf32> + check.expect_almost_eq_const %result, dense<[ + [ + [-0.0199659951, 0.00206358638, 0.0285578221, -0.00411329232, -0.00885893404, -0.0113086831, 0.0343487822], + [0.0108370036, 0.0196357146, 0.0464844741, 0.032903526, 0.00752512738, -0.00205732603, 0.0463109687], + [-0.0279003512, 0.0171403233, 0.00885203853, -0.022806216, -0.0135696121, -0.00375272054, 0.0139928926], + [0.0116565451, -0.00521556707, -0.0245668497, -0.00946252606, 2.734600e-03, 0.00460146647, -0.0332586318], + [0.0373648889, 0.040080104, 0.0792120546, 0.0687142611, 0.0129001699, 
0.048170276, 6.067640e-02], + [-0.00489785476, 0.0151357278, 0.0273378156, 0.0379059538, 0.0080597708, 0.0209609158, 0.0248660222], + [0.00253825542, -1.175260e-02, 0.0339594558, 0.0408501513, 0.0275165718, 0.0101594552, 0.0491689071], + [5.275800e-02, 0.0415463448, 0.0749897882, 0.0470644757, 0.00624182029, 0.0391805507, 0.03869069], + [0.0637338459, 0.00614991458, 0.0153763723, 0.0190313365, 0.0142990183, 0.0227143262, 0.0187453162], + [0.0359746702, 0.0182777364, -0.00368779944, -0.0100486111, 6.89582666E-5, -0.00202751439, 0.0124766938], + [-0.0151847685, -0.0175893605, 0.0247314386, 0.018632818, 0.00798455066, -0.00110600982, 0.00244264561] + ], + [ + [0.0288968664, -0.00678509939, 0.0346419513, 0.0141028976, -0.017396003, 0.00451522879, 0.00792134088], + [-0.0017626211, -0.0284877941, 0.0151375476, -0.00351338694, -0.00874114502, 0.0323345512, 0.0535612516], + [0.00123786228, -0.00454656407, 0.0335229039, 0.0019464466, -2.14070082E-4, 0.0266590156, -0.0212618597], + [-3.47743975E-4, -0.017693948, 0.0353507064, 0.00244920771, -0.0120135043, 0.0417729542, -0.0025454592], + [0.0108208582, -0.0171308704, 0.00553112756, 0.0411250815, 0.0335835591, 0.038393192, -0.00547906291], + [0.0169365555, 0.0157370344, -0.0128378682, 0.0470919088, -0.00582840201, 0.0324328542, 0.010203423], + [0.0520783663, 0.0298755895, 0.0362326317, 0.0681023895, 0.0207777359, 0.052735541, 0.0455959477], + [0.00623999349, -1.49650674E-4, -0.00651274621, 0.0146591738, 0.00641800836, 0.00297434814, 0.00838128477], + [0.0506783053, 0.00703135319, 0.0220930576, 0.0259224195, 0.001958607, 0.0123232938, 0.00920359604], + [0.0123091843, -5.780780e-03, -0.0128484722, 0.00679983944, -0.00871101767, 0.0087406747, -0.0115246754], + [0.0274577513, -0.0175638888, -0.00203213934, -0.0198616516, -0.0110571291, 0.0365728177, 0.0162097216] + ] + ]> : tensor<2x11x7xf32> + func.return +} diff --git a/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir b/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir index 541f2d79990..87406febf07 100644 --- a/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir +++ b/stablehlo/tests/transforms/stablehlo_refine_parameters.mlir @@ -36,7 +36,7 @@ func.func @refine_params(%arg0: tensor, %arg1: tensor<1xf32>, %arg2: tensor // ----- -// expected-error @+1 {{number of refinements must match number of function operands 6 vs 1}} +// expected-error @+1 {{number of refinements must match number of op operands 6 vs 1}} func.func @refine_arguments_invalid_arg_num_mismatch(%arg0: tensor) { return } diff --git a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp index 5c0cb155ea9..939f327699b 100644 --- a/stablehlo/transforms/ChloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/ChloLegalizeToStablehlo.cpp @@ -14,15 +14,16 @@ #include #include #include +#include #include #include +#include #include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -1962,6 +1963,232 @@ static Value materializeSinhApproximation(OpBuilder &rewriter, Location loc, loc, absXLtOne, smallSinhResult, largeSinhResult); } +namespace { + +ArrayAttr convertPrecisionConfig(mlir::ArrayAttr precisionConfig, + ConversionPatternRewriter &rewriter) { + std::vector precisions; + for (Attribute precision : 
precisionConfig.getValue()) { + switch (dyn_cast<mlir::chlo::PrecisionAttr>(precision).getValue()) { + case mlir::chlo::Precision::HIGHEST: + precisions.push_back(rewriter.getAttr<mlir::stablehlo::PrecisionAttr>( + mlir::stablehlo::Precision::HIGHEST)); + break; + case mlir::chlo::Precision::HIGH: + precisions.push_back(rewriter.getAttr<mlir::stablehlo::PrecisionAttr>( + mlir::stablehlo::Precision::HIGH)); + break; + default: + precisions.push_back(rewriter.getAttr<mlir::stablehlo::PrecisionAttr>( + mlir::stablehlo::Precision::DEFAULT)); + break; + } + } + return ArrayAttr::get(rewriter.getContext(), precisions); +} + +// Mode 1, where the ragged dimension is an lhs non-contracting dim (m). +// lhs : [b, m, k] +// rhs : [g, b, k, n] +// group_sizes : [g] +// result : [b, m, n] +// This pass does g iterations of [b, m, k] x [b, k, n] dot_general +// operations, applies a partial mask of size group_sizes[i] to each, and then +// adds them together. This is a slow implementation that's simple enough to +// understand, with the hope that there's already an efficient hardware kernel. +// Note: +// In this implementation, the IR size increases by a factor of g. If this +// becomes a problem, we can try adding stablehlo.while to reduce the IR size. +LogicalResult handleRaggedDotMode1(mlir::chlo::RaggedDotOp op, + ConversionPatternRewriter &rewriter) { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + chlo::RaggedDotDimensionNumbersAttr raggedDotDimensionNumbers = + op.getRaggedDotDimensionNumbers(); + ArrayRef<int64_t> lhsBatchingDimensions = + raggedDotDimensionNumbers.getLhsBatchingDimensions(); + ArrayRef<int64_t> lhsContractingDimensions = + raggedDotDimensionNumbers.getLhsContractingDimensions(); + int64_t rhsGroupDimension = + raggedDotDimensionNumbers.getRhsGroupDimensions()[0]; + + auto groupSizes = op.getGroupSizes(); + auto precisionConfig = op.getPrecisionConfig(); + if (precisionConfig.has_value()) { + precisionConfig = convertPrecisionConfig(precisionConfig.value(), rewriter); + } + RankedTensorType lhsTy = cast<RankedTensorType>(lhs.getType()); + RankedTensorType rhsTy = cast<RankedTensorType>(rhs.getType()); + int64_t lhsRank = lhsTy.getRank(); + int64_t rhsRank = rhsTy.getRank(); + auto outDType = op.getResult().getType().getElementType(); + + int64_t m = lhsTy.getShape()[lhsTy.getRank() - 2]; + int64_t k = lhsTy.getShape()[lhsTy.getRank() - 1]; + int64_t g = rhsTy.getShape()[0]; + int64_t n = rhsTy.getShape()[rhsTy.getRank() - 1]; + + std::vector<int64_t> outDims = {m, n}; + std::vector<int64_t> iotaShape = {m, 1}; + auto iotaDim = 0; + std::vector<int64_t> rhsBatchingDims = {}; + std::vector<int64_t> rhsContractingDims = {0}; + std::vector<int64_t> rhsReshapedSliceShape = {k, n}; + + // If LHS has batching dimension, then decompose ragged dot based on shape + // [b, m, k], otherwise assume shape with no batch [m, k].
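+ // For example, with the shapes used in the ragged_dot tests in this change + // (lhs = tensor<2x11x5xf32>, rhs = tensor<3x2x5x7xf32>, group_sizes = + // [4, 4, 3]): b = 2, m = 11, k = 5, g = 3, n = 7, and rows [0, 4), [4, 8), + // and [8, 11) of the ragged dimension belong to groups 0, 1, and 2.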
+ if (lhsRank == 3) { + int64_t b = lhsTy.getShape()[0]; + outDims = {b, m, n}; + iotaShape = {1, m, 1}; + iotaDim = 1; + rhsBatchingDims = {0}; + rhsContractingDims = {1}; + rhsReshapedSliceShape = {b, k, n}; + } + + // result_iota = iota of shape [m, 1] or [1, m, 1] + Value resultIota = rewriter.create<mlir::stablehlo::IotaOp>( + op.getLoc(), RankedTensorType::get(iotaShape, rewriter.getI64Type()), + /*dimension=*/iotaDim); + Value start = rewriter.create<mlir::stablehlo::ConstantOp>( + op.getLoc(), + rewriter.getZeroAttr(RankedTensorType::get({1}, rewriter.getI64Type()))); + + std::vector<int64_t> broadcastDimensions(lhsRank); + std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), 0); + + Value out = rewriter.create<mlir::stablehlo::ConstantOp>( + op.getLoc(), + rewriter.getZeroAttr(RankedTensorType::get(outDims, outDType))); + + Value outZeros = rewriter.create<mlir::stablehlo::ConstantOp>( + op.getLoc(), + rewriter.getZeroAttr(RankedTensorType::get(outDims, outDType))); + for (auto i = 0; i < g; ++i) { + // groupSize = group_sizes[i] + Value groupSize = rewriter.create<mlir::stablehlo::SliceOp>( + op.getLoc(), RankedTensorType::get({1}, rewriter.getI64Type()), + groupSizes, + /*startIndices=*/rewriter.getDenseI64ArrayAttr({i}), + /*limitIndices=*/rewriter.getDenseI64ArrayAttr({i + 1}), + /*strides=*/rewriter.getDenseI64ArrayAttr({1})); + + Value startBroadcasted = rewriter.create<mlir::stablehlo::BroadcastInDimOp>( + op.getLoc(), resultIota.getType(), start, + /*broadcast_dimensions=*/ + rewriter.getDenseI64ArrayAttr(0)); + + // start <= result_iota + Value startLEResultIota = rewriter.create<mlir::stablehlo::CompareOp>( + op.getLoc(), startBroadcasted, resultIota, ComparisonDirection::LE); + + // result_iota < (start + size) + Value resultIotaLTStartPlusGroupSize = + rewriter.create<mlir::stablehlo::CompareOp>( + op.getLoc(), resultIota, + rewriter.create<mlir::stablehlo::BroadcastInDimOp>( + op.getLoc(), resultIota.getType(), + rewriter.create<mlir::stablehlo::AddOp>(op.getLoc(), start, + groupSize), + /*broadcast_dimensions=*/rewriter.getDenseI64ArrayAttr(0)), + ComparisonDirection::LT); + + // (start <= result_iota) & (result_iota < (start + size)) + Value logicalAnd = rewriter.create<mlir::stablehlo::AndOp>( + op.getLoc(), startLEResultIota, resultIotaLTStartPlusGroupSize); + Value logicalAndBroadcasted = + rewriter.create<mlir::stablehlo::BroadcastInDimOp>( + op.getLoc(), + RankedTensorType::get(op.getResult().getType().getShape(), + rewriter.getI1Type()), + logicalAnd, + /*broadcast_dimensions=*/ + rewriter.getDenseI64ArrayAttr(broadcastDimensions)); + + // rhs_reshaped_slice = rhs[i, :, :, :] + std::vector<int64_t> rhs_start_indices(rhsTy.getRank(), 0); + rhs_start_indices[rhsGroupDimension] = i; + std::vector<int64_t> rhs_limit_indices = rhsTy.getShape(); + rhs_limit_indices[rhsGroupDimension] = i + 1; + Value rhsSliced = rewriter.create<mlir::stablehlo::SliceOp>( + op.getLoc(), rhs, + /*startIndices=*/rewriter.getDenseI64ArrayAttr(rhs_start_indices), + /*limitIndices=*/rewriter.getDenseI64ArrayAttr(rhs_limit_indices), + /*strides=*/ + rewriter.getDenseI64ArrayAttr(std::vector<int64_t>(rhsRank, 1))); + Value rhsReshapedSlice = rewriter.create<mlir::stablehlo::ReshapeOp>( + op.getLoc(), + RankedTensorType::get(rhsReshapedSliceShape, rhsTy.getElementType()), + rhsSliced); + + // Einsum of (b)mk,(b)kn->(b)mn + Value dotGeneral = rewriter.create<mlir::stablehlo::DotGeneralOp>( + op.getLoc(), TypeRange{out.getType()}, + ValueRange{lhs, rhsReshapedSlice}, + ArrayRef<NamedAttribute>{ + rewriter.getNamedAttr( + "dot_dimension_numbers", + rewriter.getAttr<mlir::stablehlo::DotDimensionNumbersAttr>( + /*lhs_batching_dimensions=*/lhsBatchingDimensions, + /*rhs_batching_dimensions=*/rhsBatchingDims, + /*lhs_contracting_dimensions=*/lhsContractingDimensions, + /*rhs_contracting_dimensions=*/rhsContractingDims)), + rewriter.getNamedAttr("precision_config", + precisionConfig.value())}); + + // Place the i'th dot_general into the corresponding position in the result.
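+ // Rows outside group i were masked off above, so the masked dot_general + // contributes zeros there; accumulating with add only updates the rows owned + // by group i, and start then advances by group_sizes[i] for the next group.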
+ Value select = rewriter.create<mlir::stablehlo::SelectOp>( + op.getLoc(), logicalAndBroadcasted, dotGeneral, outZeros); + out = rewriter.create<mlir::stablehlo::AddOp>(op.getLoc(), out, select); + start = + rewriter.create<mlir::stablehlo::AddOp>(op.getLoc(), start, groupSize); + } + rewriter.replaceOp(op, {out}); + return success(); +} + +// Mode 2, where the ragged dimension is an lhs/rhs contracting dim (k). +// lhs : [b, m, k] +// rhs : [b, k, n] +// group_sizes : [g] +// result : [g, b, m, n] +LogicalResult handleRaggedDotMode2(mlir::chlo::RaggedDotOp op, + ConversionPatternRewriter &rewriter) { + return failure(); +} + +// Mode 3, where the ragged dimension is an lhs/rhs batch dim (b). +// lhs : [b, m, k] +// rhs : [b, k, n] +// group_sizes : [g] +// result : [b, m, n] +LogicalResult handleRaggedDotMode3(mlir::chlo::RaggedDotOp op, + ConversionPatternRewriter &rewriter) { + return failure(); +} + +} // namespace + +struct ConvertRaggedDotOp final : OpConversionPattern<mlir::chlo::RaggedDotOp> { + using OpConversionPattern::OpConversionPattern; + + // RaggedDot has three general modes, based on the kind of the ragged + // dimension. + LogicalResult matchAndRewrite( + mlir::chlo::RaggedDotOp op, OpAdaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getLhs().getType().getRank() < op.getRhs().getType().getRank()) { + return handleRaggedDotMode1(op, rewriter); + } else if (op.getLhs().getType().getRank() < + op.getResult().getType().getRank()) { + return handleRaggedDotMode2(op, rewriter); + } else { + return handleRaggedDotMode3(op, rewriter); + } + } +}; + struct ConvertSinhOp final : OpConversionPattern<mlir::chlo::SinhOp> { using OpConversionPattern::OpConversionPattern; @@ -2193,10 +2420,12 @@ static void populateChloBroadcastingPatterns(MLIRContext *context, static void populateChloDecompositionPatterns(MLIRContext *context, RewritePatternSet *patterns) { populateWithGenerated(*patterns); - patterns->add(context); + patterns + ->add( + context); } } // namespace diff --git a/stablehlo/transforms/StablehloRefineArguments.cpp b/stablehlo/transforms/StablehloRefineArguments.cpp index a903b7fea04..eaf007786ae 100644 --- a/stablehlo/transforms/StablehloRefineArguments.cpp +++ b/stablehlo/transforms/StablehloRefineArguments.cpp @@ -76,78 +76,6 @@ ParseResult parseRefinedTypes(ModuleOp module, return success(); } -LogicalResult refinementError(func::FuncOp func, int64_t idx, Type argType, - Type refinedType, StringRef msg) { - return func.emitOpError() - << "invalid refinement for argument " << idx << ", refinement " << msg - << " in " << mlir::debugString(argType) << " -> " - << mlir::debugString(refinedType); -} - -// Validates refinement types: -// - A type refinement must be specified for each operand -// - Refinement types that match operand types are skipped -// - Refinement types that do not match operands must be refining tensors -// - Refined tensor types must be ranked, operand type can be unranked -// - Refined tensor types must match operand type for all static dimensions -// -LogicalResult validateRefinedTypes(func::FuncOp func, TypeRange refinedTypes) { - // Validate refined shapes - if (func.getNumArguments() != refinedTypes.size()) { - return func.emitOpError( - "number of refinements must match number of function operands ") - << refinedTypes.size() << " vs " << func.getNumArguments(); - } - - // Validate that refinements are valid - auto argTypes = func.getArgumentTypes(); - for (int64_t i = 0; i < func.getNumArguments(); ++i) { - Type type = argTypes[i]; - Type refinedType = refinedTypes[i]; - - // Always allow skipping refinement - if (type == refinedType)
continue; - - // If mismatched, must be tensor types - auto tensorType = dyn_cast<TensorType>(type); - auto refinedTensorType = dyn_cast<TensorType>(refinedType); - if (!tensorType || !refinedTensorType) { - return refinementError(func, i, type, refinedType, "must be a tensor"); - } - - // Check that element types match - if (tensorType.getElementType() != refinedTensorType.getElementType()) { - return refinementError(func, i, type, refinedType, - "element types must match"); - } - - // Refined rank cannot be unranked if mismatch - if (isa<UnrankedTensorType>(refinedType)) { - return refinementError(func, i, type, refinedType, "must be ranked"); - } - - // Unranked operands can be refined to anything - if (!tensorType.hasRank()) continue; - - // Validate ranks match if ranked (must allow unranked tensorType) - if (tensorType.getRank() != refinedTensorType.getRank()) { - return refinementError(func, i, type, refinedType, - "rank must match operand rank"); - } - - // Validate static dimension sizes match - for (auto [dimSize, refinedDimSize] : - llvm::zip(tensorType.getShape(), refinedTensorType.getShape())) { - if (!ShapedType::isDynamic(dimSize) && dimSize != refinedDimSize) { - return refinementError( - func, i, type, refinedType, - "dimension sizes must match for static dimensions"); - } - } - } - return success(); -} - // Wrap operands in "type barriers" so the rest of the program remains valid // after the signature update and before shape refinement. // @@ -219,9 +147,74 @@ struct StablehloRefineArgumentsPass } // namespace +LogicalResult refinementError(Operation* op, int64_t idx, Type argType, + Type refinedType, StringRef msg) { + return op->emitOpError() + << "invalid refinement for argument " << idx << ", refinement " << msg + << " in " << mlir::debugString(argType) << " -> " + << mlir::debugString(refinedType); +} + +LogicalResult validateRefinedTypes(Operation* op, TypeRange argTypes, TypeRange refinedTypes) { + // Validate refined shapes + if (argTypes.size() != refinedTypes.size()) { + return op->emitOpError( + "number of refinements must match number of op operands ") + << refinedTypes.size() << " vs " << argTypes.size(); + } + + // Validate that refinements are valid + for (size_t i = 0; i < argTypes.size(); ++i) { + Type type = argTypes[i]; + Type refinedType = refinedTypes[i]; + + // Always allow skipping refinement + if (type == refinedType) continue; + + // If mismatched, must be tensor types + auto tensorType = dyn_cast<TensorType>(type); + auto refinedTensorType = dyn_cast<TensorType>(refinedType); + if (!tensorType || !refinedTensorType) { + return refinementError(op, i, type, refinedType, "must be a tensor"); + } + + // Check that element types match + if (tensorType.getElementType() != refinedTensorType.getElementType()) { + return refinementError(op, i, type, refinedType, + "element types must match"); + } + + // Refined rank cannot be unranked if mismatch + if (isa<UnrankedTensorType>(refinedType)) { + return refinementError(op, i, type, refinedType, "must be ranked"); + } + + // Unranked operands can be refined to anything + if (!tensorType.hasRank()) continue; + + // Validate ranks match if ranked (must allow unranked tensorType) + if (tensorType.getRank() != refinedTensorType.getRank()) { + return refinementError(op, i, type, refinedType, + "rank must match operand rank"); + } + + // Validate static dimension sizes match + for (auto [dimSize, refinedDimSize] : + llvm::zip(tensorType.getShape(), refinedTensorType.getShape())) { + if (!ShapedType::isDynamic(dimSize) && dimSize != refinedDimSize) { + return refinementError( + op, i, type, 
refinedType, + "dimension sizes must match for static dimensions"); + } + } + } + return success(); +} + LogicalResult refineArguments(func::FuncOp func, TypeRange refinedTypes) { // Verify that refinements are valid - if (failed(validateRefinedTypes(func, refinedTypes))) return failure(); + if (failed(validateRefinedTypes(func, func.getArgumentTypes(), refinedTypes))) + return failure(); // Wrap refined operands in operand wrapper to keep IR valid for refinement wrapRefinedOperands(func, refinedTypes); diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index a1d11013c78..4274c094134 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -111,6 +111,10 @@ LogicalResult refineValues(PatternRewriter& rewriter, Operation* op, // verification of the op. if (isa(user->getDialect())) continue; + // TODO(bartchr): Consider if the dialect allow-listing approach is too + // strict. In the meantime, allow some shape interop with the shardy + // dialect. + if (user->getDialect()->getNamespace() == "sdy") continue; // Simply changing operand type of `func.return` won't work because // that won't update the FunctionType of the enclosing `func.func`. diff --git a/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/transforms/StablehloRefineShapes.h index 10e043dd66c..7f35092d060 100644 --- a/stablehlo/transforms/StablehloRefineShapes.h +++ b/stablehlo/transforms/StablehloRefineShapes.h @@ -16,19 +16,38 @@ limitations under the License. #ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H #define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H +#include + +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "stablehlo/dialect/Base.h" namespace mlir { namespace stablehlo { +// Emits an error message for invalid refinement. +LogicalResult refinementError(Operation* op, int64_t idx, Type argType, + Type refinedType, StringRef msg); + +// Validates refinement types: +// - A type refinement must be specified for each operand +// - Refinement types that match operand types are skipped +// - Refinement types that do not match operands must be refining tensors +// - Refined tensor types must be ranked, operand type can be unranked +// - Refined tensor types must match operand type for all static dimensions +LogicalResult validateRefinedTypes(Operation* op, TypeRange argTypes, + TypeRange refinedTypes); + // Gets a FuncOp that --stablehlo-refine-shapes will run on. // Returns a nullptr and emits appropriate errors if such a function cannot // be obtained from the module.