Skip to content

Commit

Permalink
Disallow unranked dynamism in ops that take output shape as an operand
Browse files Browse the repository at this point in the history
This includes:
- dynamic_broadcast_in_dim
- dynamic_iota
- dynamic_reshape

#1881
  • Loading branch information
Michael Levesque-Dion committed Feb 5, 2024
1 parent d50b6fc commit 628bad9
Show file tree
Hide file tree
Showing 25 changed files with 99 additions and 83 deletions.
14 changes: 7 additions & 7 deletions stablehlo/conversions/linalg/tests/miscellaneous.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,9 @@ func.func @iota_complexf32() -> tensor<7x10xcomplex<f32>> {

// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @dynamic_iota_f32
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
func.func @dynamic_iota_f32(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
%result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
func.func @dynamic_iota_f32(%shape: tensor<3xi32>) -> tensor<?x?x8xf32> {
%result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor<?x?x8xf32>)
func.return %result : tensor<?x?x8xf32>
}
// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0]
Expand All @@ -622,10 +622,10 @@ func.func @dynamic_iota_f32(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
// -----

// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @dyanmic_iota_ui32
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
func.func @dyanmic_iota_ui32(%shape: tensor<?xi32>) -> tensor<?x?x8xui32> {
%result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xui32>)
// CHECK: func @dynamic_iota_ui32
// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
func.func @dynamic_iota_ui32(%shape: tensor<3xi32>) -> tensor<?x?x8xui32> {
%result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor<?x?x8xui32>)
func.return %result : tensor<?x?x8xui32>
}
// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0]
Expand Down
5 changes: 5 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def HLO_PredTensor : TensorOf<[HLO_Pred]>;

def HLO_Tensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

def HLO_RankedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;

def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
Expand Down Expand Up @@ -159,6 +161,9 @@ def HLO_PredIntFpOrQuantizedTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float, HLO
// HLO static shape type definitions.
//===----------------------------------------------------------------------===//

// Static representation of a shape vector as a tensor.
def HLO_StaticDimensionTensor : RankedTensorOf<[HLO_DimensionValue], [HasStaticShapePred, HasAnyRankOfPred<[1]>], "statically shaped 1-dimensional tensor">;

// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
Expand Down
12 changes: 6 additions & 6 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Pure]>
```
}];

let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension);
let results = (outs HLO_Tensor:$result);
let arguments = (ins HLO_StaticDimensionTensor:$output_shape, I64Attr:$iota_dimension);
let results = (outs HLO_RankedTensor:$result);
let hasVerifier = 1;

let assemblyFormat = [{
Expand Down Expand Up @@ -1910,13 +1910,13 @@ def StableHLO_DynamicBroadcastInDimOp : StableHLO_ShapedInterfaceOp<
}];
let arguments = (ins
HLO_Tensor:$operand,
HLO_DimensionTensor:$output_dimensions,
HLO_StaticDimensionTensor:$output_dimensions,
DenseI64ArrayAttr:$broadcast_dimensions,
OptionalAttr<DenseI64ArrayAttr>:$known_expanding_dimensions,
OptionalAttr<DenseI64ArrayAttr>:$known_nonexpanding_dimensions
);

let results = (outs HLO_Tensor);
let results = (outs HLO_RankedTensor);

let builders = [
OpBuilder<(ins
Expand Down Expand Up @@ -2508,8 +2508,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [
```
}];

let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape);
let results = (outs HLO_Tensor:$result);
let arguments = (ins HLO_Tensor:$operand, HLO_StaticDimensionTensor:$output_shape);
let results = (outs HLO_RankedTensor:$result);

let hasVerifier = 1;

Expand Down
12 changes: 3 additions & 9 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3470,11 +3470,7 @@ LogicalResult verifyDynamicBroadcastInDimOp(
std::optional<ArrayRef<int64_t>> knownNonexpandingDimensions,
Value result) {
auto operandType = operand.getType().dyn_cast<RankedTensorType>();
auto resultType = result.getType().dyn_cast<RankedTensorType>();

// If either the operand or result are unranked, there is very little
// to verify statically.
if (!operandType || !resultType) return success();
auto resultType = result.getType().cast<RankedTensorType>();

auto outputDimensionsType =
outputDimensions.getType().cast<RankedTensorType>();
Expand Down Expand Up @@ -3555,13 +3551,12 @@ LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
Value outputShape, int64_t iotaDimension,
Value result) {
auto shape = result.getType().cast<ShapedType>();

if (!isCompatibleForHloTypeInference(outputShape, shape))
return emitOptionalError(
location, "output_shape is incompatible with return type of operation ",
result.getType());

if (!shape.hasRank()) return success();

if (iotaDimension >= shape.getRank() || iotaDimension < 0)
return emitOptionalError(
location,
Expand Down Expand Up @@ -3616,8 +3611,7 @@ LogicalResult verifyDynamicReshapeOp(std::optional<Location> location,
Value outputShape, Value result) {
auto resultType = result.getType().cast<ShapedType>();
auto outputShapeType = outputShape.getType().cast<ShapedType>();
if (resultType.hasRank() && outputShapeType.hasStaticShape() &&
outputShapeType.getDimSize(0) != resultType.getRank())
if (outputShapeType.getDimSize(0) != resultType.getRank())
return emitOptionalError(location,
"output should have a rank equal to the number of "
"elements in output_shape");
Expand Down
61 changes: 39 additions & 22 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1010,15 +1010,6 @@ func.func @dynamic_broadcast_in_dim(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
func.return %0 : tensor<?x?x?xi32>
}

// -----

// CHECK-LABEL: func @dynamic_broadcast_in_dim_unranked
func.func @dynamic_broadcast_in_dim_unranked(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<*xi32> {
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<*xi32>
func.return %0 : tensor<*xi32>
}

// -----

// CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim
Expand Down Expand Up @@ -1079,6 +1070,22 @@ func.func @dynamic_broadcast_in_dim_too_large(%arg0: tensor<1xf32>, %shape: tens

// -----

func.func @dynamic_broadcast_in_dim_unranked_result(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<*xi32> {
// expected-error@+1 {{op result #0 must be ranked tensor}}
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<*xi32>
func.return %0 : tensor<*xi32>
}

// -----

func.func @dynamic_broadcast_in_dim_dynamic_output_shape(%arg0: tensor<?x?xi32>, %shape: tensor<?xi64>) -> tensor<7x8x9xi32> {
// expected-error@+1 {{op operand #1 must be statically shaped}}
%0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<?x?xi32>, tensor<?xi64>) -> tensor<7x8x9xi32>
func.return %0 : tensor<7x8x9xi32>
}

// -----

// CHECK-LABEL: func @broadcast_in_dim
func.func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> {
%0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array<i64: 1, 2>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32>
Expand Down Expand Up @@ -3391,13 +3398,6 @@ func.func @dynamic_reshape(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> ten

// -----

func.func @dynamic_reshape_unranked(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<*xf32> {
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?xf32> {
// expected-error @+1 {{output should have a rank equal to the number of elements in output_shape}}
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?xf32>
Expand All @@ -3407,7 +3407,7 @@ func.func @dynamic_reshape_incompatible_shapes(%arg0: tensor<?xf32>, %shape: ten
// -----

func.func @dynamic_reshape_output_shape_negative_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
// @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}}
// expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}}
%0 = stablehlo.constant dense<[-1, 1]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
return %1 : tensor<1x4xf32>
Expand All @@ -3416,14 +3416,30 @@ func.func @dynamic_reshape_output_shape_negative_size(%arg0: tensor<4xf32>) -> t
// -----

func.func @dynamic_reshape_output_shape_mismatching_size(%arg0: tensor<4xf32>) -> tensor<1x4xf32> {
// @expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}}
// expected-error@+2 {{output_shape is incompatible with return type of operation 'tensor<1x4xf32>'}}
%0 = stablehlo.constant dense<[1, 1]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<1x4xf32>
return %1 : tensor<1x4xf32>
}

// -----

func.func @dynamic_reshape_unranked_result(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<*xf32> {
// expected-error@+1 {{op result #0 must be ranked tensor}}
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @dynamic_reshape_dynamic_output_shape(%arg0: tensor<?xf32>, %shape: tensor<?xindex>) -> tensor<1x4xf32> {
// expected-error@+1 {{op operand #1 must be statically shaped}}
%0 = "stablehlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<1x4xf32>
func.return %0 : tensor<1x4xf32>
}

// -----

func.func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "stablehlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
func.return %0 : tensor<2x4xf32>
Expand Down Expand Up @@ -5700,17 +5716,18 @@ func.func @dynamic_iota_dynamic() -> tensor<?xf32> {
// -----

func.func @dynamic_iota_unranked() -> tensor<*xf32> {
// expected-error@+2 {{op result #0 must be ranked tensor}}
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
}

// -----

func.func @dynamic_iota_unranked_large() -> tensor<*xf32> {
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 3 : (tensor<1xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
func.func @dynamic_iota_dynamic_output_shape(%arg: tensor<?xi64>) -> tensor<*xf32> {
// expected-error@+1 {{op operand #0 must be statically shaped}}
%0 = stablehlo.dynamic_iota %arg, dim = 0 : (tensor<?xi64>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

// -----
Expand Down
18 changes: 9 additions & 9 deletions stablehlo/tests/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -554,11 +554,11 @@ func.func @refine_dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>
// -----

// CHECK-LABEL: @refine_dynamic_broadcast_in_dim
func.func @refine_dynamic_broadcast_in_dim(%arg0: tensor<4xf32>) -> tensor<*xf32> {
func.func @refine_dynamic_broadcast_in_dim(%arg0: tensor<4xf32>) -> tensor<?x?xf32> {
// CHECK: stablehlo.dynamic_broadcast_in_dim{{.*}} -> tensor<3x4xf32>
%0 = stablehlo.constant dense<[3, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
%1 = stablehlo.dynamic_broadcast_in_dim %arg0, %0, dims = [1] : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x?xf32>
func.return %1 : tensor<?x?xf32>
}

// -----
Expand Down Expand Up @@ -598,11 +598,11 @@ func.func @refine_dynamic_gather(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2
// -----

// CHECK-LABEL: @refine_dynamic_iota
func.func @refine_dynamic_iota() -> tensor<*xf32> {
func.func @refine_dynamic_iota() -> tensor<?xf32> {
// CHECK: stablehlo.dynamic_iota{{.*}} -> tensor<4xf32>
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<?xf32>
func.return %1 : tensor<?xf32>
}

// -----
Expand All @@ -621,11 +621,11 @@ func.func @refine_dynamic_pad(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tenso
// -----

// CHECK-LABEL: @refine_dynamic_reshape
func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor<*xf32> {
func.func @refine_dynamic_reshape(%arg0: tensor<4xf32>) -> tensor<?x?xf32> {
// CHECK: stablehlo.dynamic_reshape{{.*}} -> tensor<1x4xf32>
%0 = stablehlo.constant dense<[1, 4]> : tensor<2xi64>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
%1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<4xf32>, tensor<2xi64>) -> tensor<?x?xf32>
func.return %1 : tensor<?x?xf32>
}

// -----
Expand Down
6 changes: 3 additions & 3 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tenso
}

// CHECK-LABEL: "op_dynamic_reshape"
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc
Binary file not shown.
6 changes: 3 additions & 3 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tenso
}

// CHECK-LABEL: "op_dynamic_reshape"
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc
Binary file not shown.
6 changes: 3 additions & 3 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tenso
}

// CHECK-LABEL: "op_dynamic_reshape"
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc
Binary file not shown.
6 changes: 3 additions & 3 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tenso
}

// CHECK-LABEL: "op_dynamic_reshape"
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc
Binary file not shown.
6 changes: 3 additions & 3 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1260,9 +1260,9 @@ func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tenso
}

// CHECK-LABEL: "op_dynamic_reshape"
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc
Binary file not shown.
6 changes: 3 additions & 3 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1268,9 +1268,9 @@ func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tenso
}

// CHECK-LABEL: "op_dynamic_reshape"
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_15_0.mlir.bc
Binary file not shown.
6 changes: 3 additions & 3 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1279,9 +1279,9 @@ func.func @op_dynamic_pad(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tenso
}

// CHECK-LABEL: "op_dynamic_reshape"
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<?xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<?xindex>) -> tensor<?x?xf32>
func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
// CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1<?x?x!vhlo.f32_v1>
%0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_16_0.mlir.bc
Binary file not shown.
Loading

0 comments on commit 628bad9

Please sign in to comment.