Skip to content

Commit

Permalink
Disallow unranked dynamism in dynamic_iota
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Levesque-Dion committed Jan 30, 2024
1 parent 00ae7d4 commit 03d23a0
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 18 deletions.
14 changes: 7 additions & 7 deletions stablehlo/conversions/linalg/tests/miscellaneous.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,9 @@ func.func @iota_complexf32() -> tensor<7x10xcomplex<f32>> {

// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @dynamic_iota_f32
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
func.func @dynamic_iota_f32(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
%result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xf32>)
// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
func.func @dynamic_iota_f32(%shape: tensor<3xi32>) -> tensor<?x?x8xf32> {
%result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor<?x?x8xf32>)
func.return %result : tensor<?x?x8xf32>
}
// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0]
Expand All @@ -622,10 +622,10 @@ func.func @dynamic_iota_f32(%shape: tensor<?xi32>) -> tensor<?x?x8xf32> {
// -----

// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK: func @dyanmic_iota_ui32
// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xi32>
func.func @dyanmic_iota_ui32(%shape: tensor<?xi32>) -> tensor<?x?x8xui32> {
%result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<?xi32>) -> (tensor<?x?x8xui32>)
// CHECK: func @dynamic_iota_ui32
// CHECK-SAME: %[[SHAPE:.*]]: tensor<3xi32>
func.func @dynamic_iota_ui32(%shape: tensor<3xi32>) -> tensor<?x?x8xui32> {
%result = "stablehlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64} : (tensor<3xi32>) -> (tensor<?x?x8xui32>)
func.return %result : tensor<?x?x8xui32>
}
// CHECK: %[[V1:.*]] = tensor.extract %[[SHAPE]][%c0]
Expand Down
5 changes: 5 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def HLO_PredTensor : TensorOf<[HLO_Pred]>;

def HLO_Tensor : TensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

def HLO_RankedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;

def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;

def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
Expand All @@ -130,6 +132,9 @@ def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Int]>;
// Dynamic representation of a shape vector as a tensor.
def HLO_DimensionTensor : 1DTensorOf<[HLO_DimensionValue]>;

// Static representation of a shape vector as a tensor.
def HLO_StaticDimensionTensor : RankedTensorOf<[HLO_DimensionValue], [HasStaticShapePred, HasAnyRankOfPred<[1]>], "statically shaped 1-dimensional tensor">;

//===----------------------------------------------------------------------===//
// HLO combined type definitions.
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Pure]>
```
}];

let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension);
let results = (outs HLO_Tensor:$result);
let arguments = (ins HLO_StaticDimensionTensor:$output_shape, I64Attr:$iota_dimension);
let results = (outs HLO_RankedTensor:$result);
let hasVerifier = 1;

let assemblyFormat = [{
Expand Down
3 changes: 1 addition & 2 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3555,13 +3555,12 @@ LogicalResult verifyDynamicIotaOp(std::optional<Location> location,
Value outputShape, int64_t iotaDimension,
Value result) {
auto shape = result.getType().cast<ShapedType>();

if (!isCompatibleForHloTypeInference(outputShape, shape))
return emitOptionalError(
location, "output_shape is incompatible with return type of operation ",
result.getType());

if (!shape.hasRank()) return success();

if (iotaDimension >= shape.getRank() || iotaDimension < 0)
return emitOptionalError(
location,
Expand Down
9 changes: 5 additions & 4 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5700,17 +5700,18 @@ func.func @dynamic_iota_dynamic() -> tensor<?xf32> {
// -----

func.func @dynamic_iota_unranked() -> tensor<*xf32> {
// expected-error@+2 {{op result #0 must be ranked tensor}}
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
}

// -----

func.func @dynamic_iota_unranked_large() -> tensor<*xf32> {
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 3 : (tensor<1xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
func.func @dynamic_iota_dynamic_output_shape(%arg: tensor<?xi64>) -> tensor<?x?xf32> {
// expected-error@+1 {{op operand #0 must be statically shaped}}
%0 = stablehlo.dynamic_iota %arg, dim = 0 : (tensor<?xi64>) -> tensor<?x?xf32>
func.return %0 : tensor<?x?xf32>
}

// -----
Expand Down
6 changes: 3 additions & 3 deletions stablehlo/tests/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -598,11 +598,11 @@ func.func @refine_dynamic_gather(%arg0 : tensor<2x4x9xi32>, %arg1 : tensor<1x5x2
// -----

// CHECK-LABEL: @refine_dynamic_iota
func.func @refine_dynamic_iota() -> tensor<*xf32> {
func.func @refine_dynamic_iota() -> tensor<?xf32> {
// CHECK: stablehlo.dynamic_iota{{.*}} -> tensor<4xf32>
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
%1 = stablehlo.dynamic_iota %0, dim = 0 : (tensor<1xi64>) -> tensor<?xf32>
func.return %1 : tensor<?xf32>
}

// -----
Expand Down

0 comments on commit 03d23a0

Please sign in to comment.