diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index 22a4ded9910..ff925fcf8b9 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -154,6 +154,9 @@ def HLO_IntFpOrComplexOrQuantizedIntTensor : TensorOf<[HLO_Int, HLO_Float, HLO_C // Any pred, int or floating-point tensor types def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float]>; +// Any pred, int or floating-point ranked tensor types +def HLO_PredIntOrFpRankedTensor : RankedTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>; + // Any pred, int, floating-point or quantized tensor types def HLO_PredIntFpOrQuantizedTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float, HLO_QuantizedInt]>; diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 30d561da32c..b2ab6d63899 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -3036,11 +3036,11 @@ def StableHLO_RngOp : StableHLO_Op<"rng", [InferTensorTypeWithReify, AllElementT let arguments = (ins 0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$a, 0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$b, - HLO_DimensionTensor:$shape, + HLO_StaticDimensionTensor:$shape, StableHLO_RngDistributionAttr:$rng_distribution ); - let results = (outs HLO_PredIntOrFpTensor:$result); + let results = (outs HLO_PredIntOrFpRankedTensor:$result); let assemblyFormat = [{ $a `,` $b `,` $shape `,` `distribution` `=` $rng_distribution diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 811980be302..6c4356e70a8 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -2685,19 +2685,14 @@ LogicalResult inferRngOp( auto shapeOperandType = shape.getType().cast(); Type elementType = getElementTypeOrSelf(b); - // Operand `shape` (1D by ODS) may be a constant or not, if `shape` is: - // 1, not constant and have dynamic dim (tensor): infer tensor<*x>. - // 2. not constant nor dynamic (e.g. tensor<3xi64>): infer tensor. - // 3. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>. + // Operand `shape` (static 1D by ODS) may be a constant or not, if `shape` is: + // 1. not constant (e.g. tensor<3xi64>): infer tensor. + // 2. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>. // Match to check whether the `shape` operand is a constant. DenseIntElementsAttr shapeAttr; if (!matchPattern(shape, m_Constant(&shapeAttr))) { int size = shapeOperandType.getDimSize(0); - if (isDynamicDimSize(size)) { - inferredReturnShapes.emplace_back(elementType); - return success(); - } shapeVector.resize(size, ShapedType::kDynamic); inferredReturnShapes.emplace_back(shapeVector, elementType); return success(); @@ -2705,8 +2700,8 @@ LogicalResult inferRngOp( // `shape` operand is a constant. shapeVector.reserve(shapeAttr.size()); - for (const APInt& fp : shapeAttr.getValues()) - shapeVector.push_back(fp.getSExtValue()); + for (const APInt& dimSize : shapeAttr.getValues()) + shapeVector.push_back(dimSize.getSExtValue()); inferredReturnShapes.emplace_back(shapeVector, elementType); return success(); } diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 22060062947..4af356ce3c4 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -2125,6 +2125,22 @@ func.func @rng_bit_generator_dynamic(%arg0: tensor) -> (tensor, // ----- +func.func @rng_dynamic_dim(%a: tensor, %b: tensor, %shape: tensor) -> tensor<*xf32> { + // expected-error@+1 {{op operand #2 must be statically shaped}} + %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor, tensor, tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @rng_unranked_output(%a: tensor, %b: tensor, %shape: tensor<3xi64>) -> tensor<*xf32> { + // expected-error@+1 {{op result #0 must be ranked}} + %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor, tensor, tensor<3xi64>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: func @rng_normal func.func @rng_normal(%arg0: tensor, %arg1: tensor) -> tensor<2x3x5xf32> { %cst = "stablehlo.constant"() {value = dense<[2, 3, 5]> : tensor<3xi64>} : () -> tensor<3xi64> @@ -2142,14 +2158,6 @@ func.func @rng_normal_no_constant(%a: tensor, %b: tensor, %shape: tens // ----- -// CHECK-LABEL: func @rng_normal_dynamic_dim -func.func @rng_normal_dynamic_dim(%a: tensor, %b: tensor, %shape: tensor) -> tensor<*xf32> { - %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor, tensor, tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @rng_normal_invalid_shape(%arg0: tensor, %arg1: tensor) { %cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64> // expected-error@+2 {{failed to infer returned types}} @@ -2180,7 +2188,7 @@ func.func @rng_normal_invalid_sigma_rank(%mu: tensor, %sigma: tensor<1xf32> func.func @rng_normal_invalid_shape_rank(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64> - // expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}} + // expected-error@+1 {{operand #2 must be statically shaped 1-dimensional tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}} %0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo}: (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } @@ -2213,14 +2221,6 @@ func.func @rng_uniform_no_constant(%a: tensor, %b: tensor, %shape: ten // ----- -// CHECK-LABEL: func @rng_uniform_dynamic_dim -func.func @rng_uniform_dynamic_dim(%a: tensor, %b: tensor, %shape: tensor) -> tensor<*xf32> { - %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor, tensor, tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @rng_uniform_invalid_shape(%arg0: tensor, %arg1: tensor, %arg2: tensor<7xi64>) { // expected-error@+2 {{failed to infer returned types}} // expected-error @+1 {{inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor'}} @@ -2251,7 +2251,7 @@ func.func @rng_uniform_invalid_b_rank(%a: tensor, %b: tensor<1xf32>) -> ten func.func @rng_uniform_invalid_shape_rank(%a: tensor, %b: tensor) -> tensor<2x3x5xf32> { %shape = stablehlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64> - // expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}} + // expected-error@+1 {{operand #2 must be statically shaped 1-dimensional tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}} %0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo}: (tensor, tensor, tensor<1x3xi64>) -> tensor<2x3x5xf32> func.return %0 : tensor<2x3x5xf32> } diff --git a/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/tests/stablehlo_refine_shapes.mlir index c3deb36b016..df4ced44f66 100644 --- a/stablehlo/tests/stablehlo_refine_shapes.mlir +++ b/stablehlo/tests/stablehlo_refine_shapes.mlir @@ -743,11 +743,11 @@ func.func @refine_reduce_scatter_flattened_ids(%data: tensor<4x16xf32>) -> tenso // ----- // CHECK-LABEL: @refine_rng -func.func @refine_rng(%arg0: tensor, %arg1: tensor) -> tensor<*xf32> { +func.func @refine_rng(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: stablehlo.rng{{.*}} -> tensor<4xf32> %0 = stablehlo.constant dense<[4]> : tensor<1xi64> - %1 = stablehlo.rng %arg0, %arg1, %0, distribution = NORMAL : (tensor, tensor, tensor<1xi64>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> + %1 = stablehlo.rng %arg0, %arg1, %0, distribution = NORMAL : (tensor, tensor, tensor<1xi64>) -> tensor + func.return %1 : tensor } // ----- diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir index c09e86c1a6f..d2eef1d28e5 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc index db151f2c95e..c040c899c31 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir index 95e51eff537..59657615db9 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc index c3bd3005d1c..2e3fc977ed9 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir index 04227a70c55..8d3f1a14053 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc index 3197c5c29b3..684a242efbb 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir index 0b43304dca0..db050230c2a 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc index faf8586da5e..847748925eb 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir index 3f61d319ff5..9e25d0177f5 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc index ea07a2c9baf..4b48cf0681a 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir index 86f49ef22b6..cf0001f2ea9 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir @@ -259,20 +259,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1750,13 +1750,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc index 4134975b8b1..6f6da967ff4 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir index 91038b90978..16b2527a3c1 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir @@ -258,20 +258,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1761,13 +1761,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc index 7d52c8e3cc4..b9b4c99c6a7 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir index e6bdd26b7a6..5b06671a35c 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir @@ -256,20 +256,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1836,13 +1836,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc index f6b993b17c6..8f459cf76b6 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir index 05864477016..fa9f862417b 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc index b461ab54a56..80545295112 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir index c19d4810fa5..4333f71a283 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir @@ -261,20 +261,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "attr_rng_distribution_uniform" -func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } // CHECK-LABEL: "attr_rng_distribution_normal" -func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { // CHECK: rng_distribution = #vhlo rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor } @@ -1841,13 +1841,13 @@ func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor } // CHECK-LABEL: "op_rng" -func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ // CHECK-SAME: rng_distribution = #vhlo - // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { rng_distribution = #stablehlo - } : (tensor, tensor, tensor) -> tensor + } : (tensor, tensor, tensor<0xindex>) -> tensor func.return %0 : tensor }