Skip to content

Commit

Permalink
Disallow unranked dynamism in rng_op
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Levesque-Dion committed Feb 7, 2024
1 parent 034ef4b commit 7cf989a
Show file tree
Hide file tree
Showing 24 changed files with 101 additions and 103 deletions.
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def HLO_IntFpOrComplexOrQuantizedIntTensor : TensorOf<[HLO_Int, HLO_Float, HLO_C
// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float]>;

// Any pred, int or floating-point ranked tensor types
def HLO_PredIntOrFpRankedTensor : RankedTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>;

// Any pred, int, floating-point or quantized tensor types
def HLO_PredIntFpOrQuantizedTensor : TensorOf<[HLO_Pred, HLO_Int, HLO_Float, HLO_QuantizedInt]>;

Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3036,11 +3036,11 @@ def StableHLO_RngOp : StableHLO_Op<"rng", [InferTensorTypeWithReify, AllElementT
let arguments = (ins
0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$a,
0DTensorOf<[HLO_Pred, HLO_Int, HLO_Float]>:$b,
HLO_DimensionTensor:$shape,
HLO_StaticDimensionTensor:$shape,
StableHLO_RngDistributionAttr:$rng_distribution
);

let results = (outs HLO_PredIntOrFpTensor:$result);
let results = (outs HLO_PredIntOrFpRankedTensor:$result);

let assemblyFormat = [{
$a `,` $b `,` $shape `,` `distribution` `=` $rng_distribution
Expand Down
15 changes: 5 additions & 10 deletions stablehlo/dialect/TypeInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2685,28 +2685,23 @@ LogicalResult inferRngOp(
auto shapeOperandType = shape.getType().cast<ShapedType>();
Type elementType = getElementTypeOrSelf(b);

// Operand `shape` (1D by ODS) may be a constant or not, if `shape` is:
// 1, not constant and have dynamic dim (tensor<?x>): infer tensor<*x>.
// 2. not constant nor dynamic (e.g. tensor<3xi64>): infer tensor<?x?x?x>.
// 3. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>.
// Operand `shape` (static 1D by ODS) may be a constant or not, if `shape` is:
// 1. not constant (e.g. tensor<3xi64>): infer tensor<?x?x?x>.
// 2. constant (e.g. dense<[2, 3, 5]>): infer tensor<2x3x5x>.

// Match to check whether the `shape` operand is a constant.
DenseIntElementsAttr shapeAttr;
if (!matchPattern(shape, m_Constant(&shapeAttr))) {
int size = shapeOperandType.getDimSize(0);
if (isDynamicDimSize(size)) {
inferredReturnShapes.emplace_back(elementType);
return success();
}
shapeVector.resize(size, ShapedType::kDynamic);
inferredReturnShapes.emplace_back(shapeVector, elementType);
return success();
}

// `shape` operand is a constant.
shapeVector.reserve(shapeAttr.size());
for (const APInt& fp : shapeAttr.getValues<APInt>())
shapeVector.push_back(fp.getSExtValue());
for (const APInt& dimSize : shapeAttr.getValues<APInt>())
shapeVector.push_back(dimSize.getSExtValue());
inferredReturnShapes.emplace_back(shapeVector, elementType);
return success();
}
Expand Down
36 changes: 18 additions & 18 deletions stablehlo/tests/ops_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2125,6 +2125,22 @@ func.func @rng_bit_generator_dynamic(%arg0: tensor<?xui64>) -> (tensor<?xui64>,

// -----

func.func @rng_dynamic_dim(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<?xi64>) -> tensor<*xf32> {
// expected-error@+1 {{op operand #2 must be statically shaped}}
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<?xi64>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @rng_unranked_output(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<3xi64>) -> tensor<*xf32> {
// expected-error@+1 {{op result #0 must be ranked}}
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<3xi64>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func @rng_normal
func.func @rng_normal(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<2x3x5xf32> {
%cst = "stablehlo.constant"() {value = dense<[2, 3, 5]> : tensor<3xi64>} : () -> tensor<3xi64>
Expand All @@ -2142,14 +2158,6 @@ func.func @rng_normal_no_constant(%a: tensor<f32>, %b: tensor<f32>, %shape: tens

// -----

// CHECK-LABEL: func @rng_normal_dynamic_dim
func.func @rng_normal_dynamic_dim(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<?xi64>) -> tensor<*xf32> {
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<?xi64>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @rng_normal_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>) {
%cst = "stablehlo.constant"() {value = dense<7> : tensor<1xi64>} : () -> tensor<1xi64>
// expected-error@+2 {{failed to infer returned types}}
Expand Down Expand Up @@ -2180,7 +2188,7 @@ func.func @rng_normal_invalid_sigma_rank(%mu: tensor<f32>, %sigma: tensor<1xf32>

func.func @rng_normal_invalid_shape_rank(%mu: tensor<f32>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64>
// expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}}
// expected-error@+1 {{operand #2 must be statically shaped 1-dimensional tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}}
%0 = "stablehlo.rng"(%mu, %sigma, %shape) {rng_distribution = #stablehlo<rng_distribution NORMAL>}: (tensor<f32>, tensor<f32>, tensor<1x3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}
Expand Down Expand Up @@ -2213,14 +2221,6 @@ func.func @rng_uniform_no_constant(%a: tensor<f32>, %b: tensor<f32>, %shape: ten

// -----

// CHECK-LABEL: func @rng_uniform_dynamic_dim
func.func @rng_uniform_dynamic_dim(%a: tensor<f32>, %b: tensor<f32>, %shape: tensor<?xi64>) -> tensor<*xf32> {
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<?xi64>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// -----

func.func @rng_uniform_invalid_shape(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<7xi64>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error @+1 {{inferred type(s) 'tensor<?x?x?x?x?x?x?xf32>' are incompatible with return type(s) of operation 'tensor<?xf32>'}}
Expand Down Expand Up @@ -2251,7 +2251,7 @@ func.func @rng_uniform_invalid_b_rank(%a: tensor<f32>, %b: tensor<1xf32>) -> ten

func.func @rng_uniform_invalid_shape_rank(%a: tensor<f32>, %b: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = stablehlo.constant dense<[[2, 3, 5]]> : tensor<1x3xi64>
// expected-error@+1 {{operand #2 must be 1D tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}}
// expected-error@+1 {{operand #2 must be statically shaped 1-dimensional tensor of index or 4/8/16/32/64-bit signless integer or 4/8/16/32/64-bit unsigned integer values, but got 'tensor<1x3xi64>'}}
%0 = "stablehlo.rng"(%a, %b, %shape) {rng_distribution = #stablehlo<rng_distribution UNIFORM>}: (tensor<f32>, tensor<f32>, tensor<1x3xi64>) -> tensor<2x3x5xf32>
func.return %0 : tensor<2x3x5xf32>
}
Expand Down
6 changes: 3 additions & 3 deletions stablehlo/tests/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -743,11 +743,11 @@ func.func @refine_reduce_scatter_flattened_ids(%data: tensor<4x16xf32>) -> tenso
// -----

// CHECK-LABEL: @refine_rng
func.func @refine_rng(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf32> {
func.func @refine_rng(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<?xf32> {
// CHECK: stablehlo.rng{{.*}} -> tensor<4xf32>
%0 = stablehlo.constant dense<[4]> : tensor<1xi64>
%1 = stablehlo.rng %arg0, %arg1, %0, distribution = NORMAL : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<*xf32>
func.return %1 : tensor<*xf32>
%1 = stablehlo.rng %arg0, %arg1, %0, distribution = NORMAL : (tensor<f32>, tensor<f32>, tensor<1xi64>) -> tensor<?xf32>
func.return %1 : tensor<?xf32>
}

// -----
Expand Down
14 changes: 7 additions & 7 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor<f32>) -> (tensor<f32>, tensor
}

// CHECK-LABEL: "attr_rng_distribution_uniform"
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 UNIFORM>
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

// CHECK-LABEL: "attr_rng_distribution_normal"
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down Expand Up @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>
}

// CHECK-LABEL: "op_rng"
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
// CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{
// CHECK-SAME: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_10_0.mlir.bc
Binary file not shown.
14 changes: 7 additions & 7 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor<f32>) -> (tensor<f32>, tensor
}

// CHECK-LABEL: "attr_rng_distribution_uniform"
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 UNIFORM>
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

// CHECK-LABEL: "attr_rng_distribution_normal"
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down Expand Up @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>
}

// CHECK-LABEL: "op_rng"
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
// CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{
// CHECK-SAME: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_11_0.mlir.bc
Binary file not shown.
14 changes: 7 additions & 7 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor<f32>) -> (tensor<f32>, tensor
}

// CHECK-LABEL: "attr_rng_distribution_uniform"
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 UNIFORM>
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

// CHECK-LABEL: "attr_rng_distribution_normal"
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down Expand Up @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>
}

// CHECK-LABEL: "op_rng"
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
// CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{
// CHECK-SAME: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_12_0.mlir.bc
Binary file not shown.
14 changes: 7 additions & 7 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor<f32>) -> (tensor<f32>, tensor
}

// CHECK-LABEL: "attr_rng_distribution_uniform"
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 UNIFORM>
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

// CHECK-LABEL: "attr_rng_distribution_normal"
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down Expand Up @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>
}

// CHECK-LABEL: "op_rng"
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
// CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{
// CHECK-SAME: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_13_0.mlir.bc
Binary file not shown.
14 changes: 7 additions & 7 deletions stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,20 @@ func.func @attr_rng_algorithm_philox(%arg0: tensor<f32>) -> (tensor<f32>, tensor
}

// CHECK-LABEL: "attr_rng_distribution_uniform"
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_uniform(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 UNIFORM>
rng_distribution = #stablehlo<rng_distribution UNIFORM>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

// CHECK-LABEL: "attr_rng_distribution_normal"
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @attr_rng_distribution_normal(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
// CHECK: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down Expand Up @@ -1742,13 +1742,13 @@ func.func @op_rng_bit_generator(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>
}

// CHECK-LABEL: "op_rng"
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<?xindex>) -> tensor<f32> {
func.func @op_rng(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<0xindex>) -> tensor<f32> {
// CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{
// CHECK-SAME: rng_distribution = #vhlo<rng_distribution_v1 NORMAL>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<?x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
// CHECK-SAME: }> : (!vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<!vhlo.f32_v1>, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1<!vhlo.f32_v1>
%0 = "stablehlo.rng"(%arg0, %arg1, %arg2) {
rng_distribution = #stablehlo<rng_distribution NORMAL>
} : (tensor<f32>, tensor<f32>, tensor<?xindex>) -> tensor<f32>
} : (tensor<f32>, tensor<f32>, tensor<0xindex>) -> tensor<f32>
func.return %0 : tensor<f32>
}

Expand Down
Binary file modified stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_14_0.mlir.bc
Binary file not shown.
Loading

0 comments on commit 7cf989a

Please sign in to comment.