diff --git a/stablehlo/conversions/linalg/tests/miscellaneous.mlir b/stablehlo/conversions/linalg/tests/miscellaneous.mlir index 5ec4206adcb..643c109e87f 100644 --- a/stablehlo/conversions/linalg/tests/miscellaneous.mlir +++ b/stablehlo/conversions/linalg/tests/miscellaneous.mlir @@ -601,9 +601,9 @@ func.func @iota_complexf32() -> tensor<7x10xcomplex> { // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @dynamic_iota_f32 -// CHECK-SAME: %[[SHAPE:.*]]: tensor -func.func @dynamic_iota_f32(%shape: tensor) -> tensor { - %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor) -> (tensor) +// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32> +func.func @dynamic_iota_f32(%shape: tensor<3xi32>) -> tensor { + %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor) func.return %result : tensor } // CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0] @@ -622,10 +622,10 @@ func.func @dynamic_iota_f32(%shape: tensor) -> tensor { // ----- // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @dyanmic_iota_ui32 -// CHECK-SAME: %[[SHAPE:.*]]: tensor -func.func @dyanmic_iota_ui32(%shape: tensor) -> tensor { - %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor) -> (tensor) +// CHECK: func @dynamic_iota_ui32 +// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32> +func.func @dynamic_iota_ui32(%shape: tensor<3xi32>) -> tensor { + %result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor) func.return %result : tensor } // CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0] diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index efcede4ab33..22a4ded9910 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -117,6 +117,8 @@ def HLO_PredTensor : TensorOf<[HLO_Pred]>; def HLO_Tensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>; +def HLO_RankedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>; + def HLO_ComplexTensor : TensorOf<[HLO_Complex]>; def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>; @@ -159,6 +161,9 @@ def HLO_PredIntFpOrQuantizedTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float, HLO // HLO static shape type definitions. //===----------------------------------------------------------------------===// +// Static representation of a shape vector as a tensor. +def HLO_StaticDimensionTensor : RankedTensorOf<[HLO_DimensionValue], [HasStaticShapePred, HasAnyRankOfPred<[1]>], "statically shaped 1-dimensional tensor">; + // In general, static shaped tensor constraints should be avoided unless // it is for a legacy op which is only correct with static shapes. def HLO_StaticShapeTensor : StaticShapeTensorOf<[ diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 2ca6afeb9eb..8c5c7dc7369 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -141,8 +141,8 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Pure]> ``` }]; - let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension); - let results = (outs HLO_Tensor:$result); + let arguments = (ins HLO_StaticDimensionTensor:$output_shape, I64Attr:$iota_dimension); + let results = (outs HLO_RankedTensor:$result); let hasVerifier = 1; let assemblyFormat = [{ @@ -1910,13 +1910,13 @@ def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp< }]; let arguments = (ins HLO_Tensor:$operand, - HLO_DimensionTensor:$output_dimensions, + HLO_StaticDimensionTensor:$output_dimensions, DenseI64ArrayAttr:$broadcast_dimensions, OptionalAttr:$known_expanding_dimensions, OptionalAttr:$known_nonexpanding_dimensions ); - let results = (outs HLO_Tensor); + let results = (outs HLO_RankedTensor); let builders = [ OpBuilder<(ins @@ -2508,8 +2508,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [ ``` }]; - let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape); - let results = (outs HLO_Tensor:$result); + let arguments = (ins HLO_Tensor:$operand, HLO_StaticDimensionTensor:$output_shape); + let results = (outs HLO_RankedTensor:$result); let hasVerifier = 1; diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 159cb5e275c..811980be302 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3470,11 +3470,7 @@ LogicalResult verifyDynamicBroadcastInDimOp( std::optional> knownNonexpandingDimensions, Value result) { auto operandType = operand.getType().dyn_cast(); - auto resultType = result.getType().dyn_cast(); - - // If either the operand or result are unranked, there is very little - // to verify statically. - if (!operandType || !resultType) return success(); + auto resultType = result.getType().cast(); auto outputDimensionsType = outputDimensions.getType().cast(); @@ -3555,13 +3551,12 @@ LogicalResult verifyDynamicIotaOp(std::optional location, Value outputShape, int64_t iotaDimension, Value result) { auto shape = result.getType().cast(); + if (!isCompatibleForHloTypeInference(outputShape, shape)) return emitOptionalError( location, "output_shape is incompatible with return type of operation ", result.getType()); - if (!shape.hasRank()) return success(); - if (iotaDimension >= shape.getRank() || iotaDimension < 0) return emitOptionalError( location, @@ -3616,8 +3611,7 @@ LogicalResult verifyDynamicReshapeOp(std::optional location, Value outputShape, Value result) { auto resultType = result.getType().cast(); auto outputShapeType = outputShape.getType().cast(); - if (resultType.hasRank() && outputShapeType.hasStaticShape() && - outputShapeType.getDimSize(0) != resultType.getRank()) + if (outputShapeType.getDimSize(0) != resultType.getRank()) return emitOptionalError(location, "output should have a rank equal to the number of " "elements in output_shape"); diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index ad4f6a03265..22060062947 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -1010,15 +1010,6 @@ func.func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64 %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array} : (tensor, tensor<3xi64>) -> tensor func.return %0 : tensor } - -// ----- - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_unranked -func.func @dynamic_broadcast_in_dim_unranked(%arg0: tensor, %shape: tensor<3xi64>) -> tensor<*xi32> { - %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array} : (tensor, tensor<3xi64>) -> tensor<*xi32> - func.return %0 : tensor<*xi32> -} - // ----- // CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim @@ -1079,6 +1070,22 @@ func.func @dynamic_broadcast_in_dim_too_large(%arg0: tensor<1xf32>, %shape: tens // ----- +func.func @dynamic_broadcast_in_dim_unranked_result(%arg0: tensor, %shape: tensor<3xi64>) -> tensor<*xi32> { + // expected-error@+1 {{op result #0 must be ranked tensor}} + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array} : (tensor, tensor<3xi64>) -> tensor<*xi32> + func.return %0 : tensor<*xi32> +} + +// ----- + +func.func @dynamic_broadcast_in_dim_dynamic_output_shape(%arg0: tensor, %shape: tensor) -> tensor<7x8x9xi32> { + // expected-error@+1 {{op operand #1 must be statically shaped}} + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array} : (tensor, tensor) -> tensor<7x8x9xi32> + func.return %0 : tensor<7x8x9xi32> +} + +// ----- + // CHECK-LABEL: func @broadcast_in_dim func.func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> @@ -3391,13 +3398,6 @@ func.func @dynamic_reshape(%arg0: tensor, %shape: tensor<2xindex>) -> ten // ----- -func.func @dynamic_reshape_unranked(%arg0: tensor, %shape: tensor<2xindex>) -> tensor<*xf32> { - %0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { // expected-error @+1 {{output should have a rank equal to the number of elements in output_shape}} %0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor @@ -3407,7 +3407,7 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor, %shape: ten // ----- func.func @dynamic_reshape_output_shape_negative_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}} + // expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}} %0 = stablehlo.constant dense<[-1, 1]> : tensor<2xi64> %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> return %1 : tensor<1x4xf32> @@ -3416,7 +3416,7 @@ func.func @dynamic_reshape_output_shape_negative_size(%arg0: tensor<4xf32>) -> t // ----- func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> { - // @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}} + // expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}} %0 = stablehlo.constant dense<[1, 1]> : tensor<2xi64> %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32> return %1 : tensor<1x4xf32> @@ -3424,6 +3424,22 @@ func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) - // ----- +func.func @dynamic_reshape_unranked_result(%arg0: tensor, %shape: tensor<2xindex>) -> tensor<*xf32> { + // expected-error@+1 {{op result #0 must be ranked tensor}} + %0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +func.func @dynamic_reshape_dynamic_output_shape(%arg0: tensor, %shape: tensor) -> tensor<1x4xf32> { + // expected-error@+1 {{op operand #1 must be statically shaped}} + %0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor) -> tensor<1x4xf32> + func.return %0 : tensor<1x4xf32> +} + +// ----- + func.func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { %0 = "stablehlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32> func.return %0 : tensor<2x4xf32> @@ -5700,6 +5716,7 @@ func.func @dynamic_iota_dynamic() -> tensor { // ----- func.func @dynamic_iota_unranked() -> tensor<*xf32> { + // expected-error@+2 {{op result #0 must be ranked tensor}} %0 = stablehlo.constant dense<[4]> : tensor<1xi64> %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<*xf32> func.return %1 : tensor<*xf32> @@ -5707,10 +5724,10 @@ func.func @dynamic_iota_unranked() -> tensor<*xf32> { // ----- -func.func @dynamic_iota_unranked_large() -> tensor<*xf32> { - %0 = stablehlo.constant dense<[4]> : tensor<1xi64> - %1 = stablehlo.dynamic_iota %0, dim = 3 : (tensor<1xi64>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> +func.func @dynamic_iota_dynamic_output_shape(%arg: tensor) -> tensor<*xf32> { + // expected-error@+1 {{op operand #0 must be statically shaped}} + %0 = stablehlo.dynamic_iota %arg, dim = 0 : (tensor) -> tensor + func.return %0 : tensor } // ----- diff --git a/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/tests/stablehlo_refine_shapes.mlir index c219232a3f5..c3deb36b016 100644 --- a/stablehlo/tests/stablehlo_refine_shapes.mlir +++ b/stablehlo/tests/stablehlo_refine_shapes.mlir @@ -554,11 +554,11 @@ func.func @refine_dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32> // ----- // CHECK-LABEL: @refine_dynamic_broadcast_in_dim -func.func @refine_dynamic_broadcast_in_dim(%arg0: tensor<4xf32>) -> tensor<*xf32> { +func.func @refine_dynamic_broadcast_in_dim(%arg0: tensor<4xf32>) -> tensor { // CHECK: stablehlo.dynamic_broadcast_in_dim{{.*}} -> tensor<3x4xf32> %0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64> - %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> + %1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor + func.return %1 : tensor } // ----- @@ -598,11 +598,11 @@ func.func @refine_dynamic_gather(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2 // ----- // CHECK-LABEL: @refine_dynamic_iota -func.func @refine_dynamic_iota() -> tensor<*xf32> { +func.func @refine_dynamic_iota() -> tensor { // CHECK: stablehlo.dynamic_iota{{.*}} -> tensor<4xf32> %0 = stablehlo.constant dense<[4]> : tensor<1xi64> - %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> + %1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor + func.return %1 : tensor } // ----- @@ -621,11 +621,11 @@ func.func @refine_dynamic_pad(%arg0: tensor<4xf32>, %arg1: tensor) -> tenso // ----- // CHECK-LABEL: @refine_dynamic_reshape -func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor<*xf32> { +func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor { // CHECK: stablehlo.dynamic_reshape{{.*}} -> tensor<1x4xf32> %0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64> - %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<*xf32> - func.return %1 : tensor<*xf32> + %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor + func.return %1 : tensor } // ----- diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir index 2f562448e8b..c09e86c1a6f 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc index 216cd266ae4..db151f2c95e 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir index 98607566fe8..95e51eff537 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc index 60f0f821299..c3bd3005d1c 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir index aabb800d9df..04227a70c55 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc index ea1add250b6..3197c5c29b3 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir index 8732686f622..0b43304dca0 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc index 16296841579..faf8586da5e 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir index 4b6ebc574ab..3f61d319ff5 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc index 6b9a004a808..ea07a2c9baf 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir index bb0d581b63e..86f49ef22b6 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir @@ -1268,9 +1268,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc index 4ec2df8d1a9..4134975b8b1 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir index 8c1f48bbd8f..91038b90978 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir @@ -1279,9 +1279,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc index 81b4be150f5..7d52c8e3cc4 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir index 152f30bf6ba..e6bdd26b7a6 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir @@ -1296,9 +1296,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc index c9ca8fc3335..f6b993b17c6 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir index ca3ed171335..05864477016 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc index 8b33745e37e..b461ab54a56 100644 Binary files a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_9_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir index 7840b5013f7..c19d4810fa5 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir @@ -1301,9 +1301,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor }