From a1aa3dfa7643ec09ebef506041d261079ead7b4e Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Tue, 28 Nov 2023 02:45:14 +0000 Subject: [PATCH 1/4] verifier and type inference changes --- stablehlo/dialect/StablehloOps.cpp | 20 ++- stablehlo/dialect/StablehloOps.td | 14 +- stablehlo/dialect/TypeInference.cpp | 152 ++++++++++++++--- stablehlo/dialect/TypeInference.h | 12 +- stablehlo/reference/Ops.cpp | 2 +- stablehlo/tests/infer_stablehlo.mlir | 121 +++++++++++++- stablehlo/tests/ops_stablehlo.mlir | 154 +++++++++++++++++- stablehlo/tests/verify_reduce.mlir | 102 +++++++++++- stablehlo/tests/verify_reduce_window.mlir | 122 +++++++++++++- stablehlo/tests/verify_scatter.mlir | 153 ++++++++++++++++- .../tests/verify_select_and_scatter.mlir | 143 +++++++++++++++- 11 files changed, 932 insertions(+), 63 deletions(-) diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 9b0933a2d4e..b4c11274a28 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -156,7 +156,6 @@ LogicalResult ReduceScatterOp::verify() { } INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp) -INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AllReduceOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op) INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp) @@ -919,6 +918,15 @@ LogicalResult AllReduceOp::verify() { getComputation()); } +LogicalResult AllReduceOp::inferReturnTypeComponents( + MLIRContext*, std::optional location, ValueShapeRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { + AllReduceOp::Adaptor adaptor(operands, attributes, properties, regions); + return hlo::inferAllReduceOp(location, adaptor.getOperand(), + adaptor.getComputation(), inferredReturnShapes); +} + //===----------------------------------------------------------------------===// // BatchNormGradOp //===----------------------------------------------------------------------===// @@ -1379,7 +1387,7 @@ LogicalResult ReduceWindowOp::inferReturnTypeComponents( location, adaptor.getInputs(), adaptor.getInitValues(), adaptor.getWindowDimensions(), adaptor.getWindowStrides(), adaptor.getBaseDilations(), adaptor.getWindowDilations(), - adaptor.getPadding(), inferredReturnShapes); + adaptor.getPadding(), adaptor.getBody(), inferredReturnShapes); } LogicalResult ReduceWindowOp::verify() { @@ -1782,7 +1790,8 @@ LogicalResult ReduceOp::inferReturnTypeComponents( ReduceOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferReduceOp(location, adaptor.getInputs().getTypes(), adaptor.getInitValues().getTypes(), - adaptor.getDimensions(), inferredReturnShapes); + adaptor.getDimensions(), adaptor.getBody(), + inferredReturnShapes); } LogicalResult ReduceOp::verify() { @@ -2295,8 +2304,8 @@ LogicalResult SelectAndScatterOp::inferReturnTypes( SmallVectorImpl& inferredReturnTypes) { SelectAndScatterOp::Adaptor adaptor(operands, attributes, properties, regions); - return hlo::inferSelectAndScatterOp(adaptor.getOperand(), - inferredReturnTypes); + return hlo::inferSelectAndScatterOp( + adaptor.getOperand(), adaptor.getScatter(), inferredReturnTypes); } LogicalResult SelectAndScatterOp::verify() { @@ -2316,6 +2325,7 @@ LogicalResult ScatterOp::inferReturnTypes( SmallVectorImpl& inferredReturnTypes) { ScatterOp::Adaptor adaptor(operands, attributes, properties, regions); return hlo::inferScatterOp(location, adaptor.getInputs(), + adaptor.getUpdateComputation(), inferredReturnTypes); } diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 80b6e1ea86e..1d37c7e3628 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -1327,7 +1327,7 @@ def StableHLO_AllGatherOp : StableHLO_Op<"all_gather", } def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce", - [HLO_CompatibleOperandsAndResultType /*all_reduce_c6*/]> { + [InferTensorType /*all_reduce_c6, all_reduce_c7*/]> { let summary = "AllReduce operation"; let description = [{ Within each process group in the process grid, applies a reduction function @@ -1362,8 +1362,7 @@ def StableHLO_AllReduceOp : StableHLO_Op<"all_reduce", let hasVerifier = 1; } -def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", - [SameOperandsAndResultElementType /*reduce_scatter_c8*/]> { +def StableHLO_ReduceScatterOp : StableHLO_Op<"reduce_scatter", []> { let summary = "ReduceScatter operation"; let description = [{ Within each process group in the process grid, performs reduction, using @@ -1448,7 +1447,7 @@ def StableHLO_AllToAllOp : StableHLO_Op<"all_to_all", def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [ RecursiveMemoryEffects, SameVariadicOperandSize /*reduce_c3*/, - InferTensorTypeWithReify /*reduce_c7*/, + InferTensorTypeWithReify /*reduce_c7, reduce_c8*/, SingleBlockImplicitTerminator<"ReturnOp"> ]> { /*reduce_c7*/ let summary = "Reduce operation"; @@ -2514,7 +2513,8 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [ def StableHLO_ScatterOp: StableHLO_Op<"scatter", [RecursiveMemoryEffects, SameVariadicOperandSize /*scatter_c5*/, - DeclareOpInterfaceMethods /*scatter_c16*/]> { + DeclareOpInterfaceMethods /*scatter_c16, + scater_c17*/]> { let summary = "Scatter operation"; let description = [{ Produces `results` tensors which are equal to `inputs` tensors except that @@ -2587,8 +2587,8 @@ def StableHLO_SelectOp: StableHLO_Op<"select", [Pure, HLO_BroadcastingElementwis } def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter", - [DeclareOpInterfaceMethods /*select_and_scatter_c11*/, - RecursiveMemoryEffects]> { + [DeclareOpInterfaceMethods /*select_and_scatter_c11, + select_and_scatter_c12*/, RecursiveMemoryEffects]> { let summary = "SelectAndScatter operation"; let description = [{ Scatters the values from the `source` tensor using `scatter` based on the diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 651d384753a..85a1ab2f234 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -97,6 +97,51 @@ bool tensorsHaveSameElType(Type type1, Type type2, return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision); } +unsigned potentiallyComplexBitWidth(Type type) { + auto complexTy = type.dyn_cast(); + return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth() + : type.getIntOrFloatBitWidth(); +} + +// Returns true if the element-type of type1 can be promoted to that of type2. +// An element-type 'x' is promotatble to element-type 'y' is they have the same +// base type and bitwidth(x) <= bitwidth(y). When 'x' and 'y' are quantized +// element-types, then promotion is applied only to the 'storage_type' +// component. +bool isPromotableElementType(Type type1, Type type2, + bool ignoreFpPrecision = false) { + auto tensorTy1 = type1.dyn_cast(); + auto tensorTy2 = type2.dyn_cast(); + + if (!tensorTy1 || !tensorTy2) return false; + + Type tensorEl1 = tensorTy1.getElementType(); + Type tensorEl2 = tensorTy2.getElementType(); + + if (ignoreFpPrecision && tensorEl1.isa() && + tensorTy2.getElementType().isa()) + return true; + + bool isSameType = + (tensorEl1.isa() and tensorEl2.isa()) || + (tensorEl1.isa() and tensorEl2.isa()) || + (tensorEl1.isa() and tensorEl2.isa()) || + (tensorEl1.isa() and + tensorEl2.isa()); + + if (!isSameType) return false; + + if (!tensorEl1.isa()) + return potentiallyComplexBitWidth(tensorEl1) <= + potentiallyComplexBitWidth(tensorEl2); + + auto quantType1 = tensorEl1.cast(); + auto quantType2 = tensorEl2.cast(); + return quantType1.getExpressedType() == quantType2.getExpressedType() && + potentiallyComplexBitWidth(quantType1.getStorageType()) <= + potentiallyComplexBitWidth(quantType2.getStorageType()); +} + // Return true if type1 and type2 are shape-compatible and have same element // type. If 'ignoreFpPrecision' is True, then allow floats with different // precisions while checking element-types. @@ -405,12 +450,6 @@ SmallVector inferWindowOutputShape(ArrayRef baseShape, return outputDimensions; } -unsigned potentiallyComplexBitWidth(Type type) { - auto complexTy = type.dyn_cast(); - return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth() - : type.getIntOrFloatBitWidth(); -} - LogicalResult verifyReplicaGroups(std::optional location, DenseIntElementsAttr replicaGroups, bool allGroupsMustHaveSameSize, @@ -530,6 +569,17 @@ LogicalResult verifyReduceOpInputsAndInferShape( return success(); } +// Returns the types of the terminator arguments of the input mlir::Block +// 'block'. +SmallVector getAccumulatorTypes(Block& block) { + SmallVector accumulatorSubShapes; + for (Value retOperand : block.getTerminator()->getOperands()) { + auto shapedTy = retOperand.getType().dyn_cast(); + accumulatorSubShapes.push_back(shapedTy); + } + return accumulatorSubShapes; +} + LogicalResult verifyReducerShape(std::optional loc, Block& block, ArrayRef inputTypes, ArrayRef initValueTypes, @@ -598,24 +648,37 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, // all_reduce_c5, reduce_c6, reduce_scatter_c7, reduce_window_c13, // reduce_window_i2, scatter_c6, scatter_c15, select_and_scatter_c10 - if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx], - initValueTypes[inputIdx], - /*ignoreFpPrecision=*/true)) + if (failed(verifyCompatibleShape(initValueTypes[inputIdx], + accumulatorSubShapes[inputIdx]))) + return emitOptionalError( + loc, "The shape of reduction-region's result type at index ", + inputIdx, " differs from the op's corresponding init-value type: ", + accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]); + + if (!isPromotableElementType(initValueTypes[inputIdx], + accumulatorSubShapes[inputIdx], + /*ignoreFpPrecision=*/true)) return emitOptionalError( - loc, "The type of reduction-region's result type at index ", inputIdx, - " differs from the op's corresponding init-value type: ", + loc, "The element-type of reduction-region's result type at index ", + inputIdx, + " is expected to be promotable from the op's corresponding " + "init-value element-type: ", accumulatorSubShapes[inputIdx], " vs ", initValueTypes[inputIdx]); // reduce_c6, reduce_window_c3, scatter_c6, scatter_c15, // select_and_scatter_c10 - if (!tensorsHaveSameElType( + if (!isPromotableElementType( inputTypes[inputIdx], - block.getArgument(numInputs + inputIdx).getType(), true)) + block.getArgument(numInputs + inputIdx).getType(), + /*ignoreFpPrecision=*/true)) return emitOptionalError( loc, "The element-type of reduction-region's argument at index ", - numInputs + inputIdx, " is expected to be ", + numInputs + inputIdx, " is expected to be promotable from ", inputTypes[inputIdx].getElementType(), ", but got ", - block.getArgument(numInputs + inputIdx).getType(), " as its type."); + block.getArgument(numInputs + inputIdx) + .getType() + .cast() + .getElementType()); Type blockArgType = block.getArgument(numInputs + inputIdx).getType(); auto blockArgTensorTy = blockArgType.cast(); @@ -1453,6 +1516,18 @@ LogicalResult inferAllToAllOp( return success(); } +LogicalResult inferAllReduceOp( + std::optional location, Value operand, Region& computation, + SmallVectorImpl& inferredReturnShapes) { + // all_reduce_c6, all_reduce_c7 + SmallVector accumulatorTypes = + getAccumulatorTypes(computation.front()); + auto operandShapedTy = operand.getType().cast(); + inferredReturnShapes.emplace_back(getSameShapeTensorType( + operandShapedTy, accumulatorTypes[0].getElementType())); + return success(); +} + LogicalResult inferBatchNormGradOp( std::optional location, Value operand, Value scale, Value mean, Value variance, Value gradOutput, int64_t featureIndex, @@ -2532,7 +2607,7 @@ LogicalResult inferRealOp(std::optional, Value operand, LogicalResult inferReduceOp( std::optional location, TypeRange inputTypes, - TypeRange initValueTypes, DenseIntElementsAttr dimensions, + TypeRange initValueTypes, DenseIntElementsAttr dimensions, Region& body, SmallVectorImpl& inferredReturnShapes) { SmallVector inputArgTensorTypes{ llvm::map_range(inputTypes, [](Type t) { return t.cast(); })}; @@ -2546,10 +2621,11 @@ LogicalResult inferReduceOp( initValueTensorTypes, dimensions, newDimensions, encoding))) return failure(); - // reduce_c2, reduce_c3, reduce_c7 + // reduce_c3, reduce_c7, reduce_c8 + SmallVector accumulatorTypes = getAccumulatorTypes(body.front()); for (uint64_t inputIdx = 0; inputIdx < inputTypes.size(); ++inputIdx) { ShapedType inputType = inputArgTensorTypes[inputIdx]; - Type elementType = inputType.getElementType(); + Type elementType = accumulatorTypes[inputIdx].getElementType(); if (inputType.hasRank()) inferredReturnShapes.emplace_back(newDimensions, elementType, encoding); else @@ -2565,7 +2641,7 @@ LogicalResult inferReduceWindowOp( std::optional windowStrides, std::optional baseDilations, std::optional windowDilations, - std::optional padding, + std::optional padding, Region& body, SmallVectorImpl& inferredReturnShapes) { SmallVector inputTypes{llvm::map_range( inputs.getTypes(), [](Type t) { return t.cast(); })}; @@ -2582,21 +2658,22 @@ LogicalResult inferReduceWindowOp( return failure(); // reduce_window_c1, reduce_window_c14...reduce_window_c16 + SmallVector accumulatorTypes = getAccumulatorTypes(body.front()); for (size_t i = 0; i < inputTypes.size(); ++i) { auto inputRankedType = inputs[i].getType().dyn_cast(); if (!inputRankedType) { - inferredReturnShapes.emplace_back(inputTypes[i].getElementType()); + inferredReturnShapes.emplace_back(accumulatorTypes[i].getElementType()); } else { auto resultShape = inferWindowOutputShape(inputTypes[i].getShape(), inferredWindow); auto inputBounds = encodingToBounds(inputRankedType.getEncoding()); if (inputBounds.empty()) { inferredReturnShapes.emplace_back(resultShape, - inputTypes[i].getElementType()); + accumulatorTypes[i].getElementType()); } else { auto resultBounds = inferWindowOutputShape(inputBounds, inferredWindow); inferredReturnShapes.emplace_back( - resultShape, inputTypes[i].getElementType(), + resultShape, accumulatorTypes[i].getElementType(), boundsToEncoding(inputRankedType.getEncoding(), resultBounds)); } } @@ -2661,8 +2738,16 @@ LogicalResult inferRngOp( } LogicalResult inferScatterOp(std::optional, ValueRange inputs, + Region& update_computation, SmallVectorImpl& inferredReturnTypes) { - llvm::append_range(inferredReturnTypes, inputs.getTypes()); + // scatter_c16, scatter_c17 + SmallVector accumulatorTypes = + getAccumulatorTypes(update_computation.front()); + for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) { + auto inputShapedTy = inputs[inputIdx].getType().cast(); + inferredReturnTypes.push_back(getSameShapeTensorType( + inputShapedTy, accumulatorTypes[inputIdx].getElementType())); + } return success(); } @@ -2692,9 +2777,14 @@ LogicalResult inferSelectOp( } LogicalResult inferSelectAndScatterOp( - Value operand, SmallVectorImpl& inferredReturnTypes) { - // select_and_scatter_c11 - inferredReturnTypes.push_back(operand.getType()); + Value operand, Region& scatter, + SmallVectorImpl& inferredReturnTypes) { + // select_and_scatter_c11, select_and_scatter_c12 + SmallVector accumulatorTypes = + getAccumulatorTypes(scatter.front()); + auto operandShapedTy = operand.getType().cast(); + inferredReturnTypes.push_back(getSameShapeTensorType( + operandShapedTy, accumulatorTypes[0].getElementType())); return success(); } @@ -3821,6 +3911,16 @@ LogicalResult verifyReduceScatterOp(std::optional location, operandType.getDimSize(index), ") and result (", resultType.getDimSize(index), ")"); } + + // reduce_scatter_c9 + SmallVector accumulatorTypes = + getAccumulatorTypes(computation.front()); + if (resultType.getElementType() != accumulatorTypes[0].getElementType()) { + return emitOptionalError(location, "result element-type is expected to be ", + accumulatorTypes[0].getElementType(), ", but got ", + resultType.getElementType()); + } + return success(); } diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index 107e98392fe..abc87e22d2a 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -120,6 +120,10 @@ LogicalResult inferAllToAllOp( DenseIntElementsAttr replicaGroups, SmallVectorImpl& inferredReturnShapes); +LogicalResult inferAllReduceOp( + std::optional location, Value operand, Region& computation, + SmallVectorImpl& inferredReturnShapes); + LogicalResult inferBatchNormGradOp( std::optional location, Value operand, Value scale, Value mean, Value variance, Value gradOutput, int64_t featureIndex, @@ -281,7 +285,7 @@ LogicalResult inferRealOp(std::optional location, Value operand, LogicalResult inferReduceOp( std::optional location, TypeRange inputTypes, - TypeRange initValueTypes, DenseIntElementsAttr dimensions, + TypeRange initValueTypes, DenseIntElementsAttr dimensions, Region& body, SmallVectorImpl& inferredReturnShapes); LogicalResult inferReduceWindowOp( @@ -290,7 +294,7 @@ LogicalResult inferReduceWindowOp( std::optional windowStrides, std::optional baseDilations, std::optional windowDilations, - std::optional padding, + std::optional padding, Region& body, SmallVectorImpl& inferredReturnShapes); LogicalResult inferReplicaIdOp(MLIRContext* context, std::optional, @@ -306,7 +310,7 @@ LogicalResult inferRngOp( SmallVectorImpl& inferredReturnShapes); LogicalResult inferScatterOp(std::optional location, - ValueRange inputs, + ValueRange inputs, Region& update_computation, SmallVectorImpl& inferredReturnTypes); LogicalResult inferSelectOp( @@ -314,7 +318,7 @@ LogicalResult inferSelectOp( SmallVectorImpl& inferredReturnShapes); LogicalResult inferSelectAndScatterOp( - Value operand, SmallVectorImpl& inferredReturnTypes); + Value operand, Region& scatter, SmallVectorImpl& inferredReturnTypes); LogicalResult inferSendOp(HloDialectInterface* dialect, std::optional location, diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 58bbac38dac..a24021f1c08 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -82,7 +82,7 @@ SmallVector evalReduceOp(ArrayRef inputs, Builder builder(inputs[0].getType().getContext()); auto reduceStatus = hlo::inferReduceOp( /*location=*/{}, inputTypes, initValueTypes, - builder.getI64TensorAttr(dimensions), inferredReduceTypes); + builder.getI64TensorAttr(dimensions), body, inferredReduceTypes); if (failed(reduceStatus)) report_fatal_error( invalidArgument("Could not infer ReduceOp's return type")); diff --git a/stablehlo/tests/infer_stablehlo.mlir b/stablehlo/tests/infer_stablehlo.mlir index e1cfe718520..727cf770604 100644 --- a/stablehlo/tests/infer_stablehlo.mlir +++ b/stablehlo/tests/infer_stablehlo.mlir @@ -102,6 +102,24 @@ func.func @cholesky(%arg0: tensor<1x2x2xf32>) -> tensor<1x2x2xindex> { // ----- +// CHECK-LABEL: func @all_reduce_c6_c7 +func.func @all_reduce_c6_c7(%operand: tensor<10xf32>) -> tensor<10xindex> { + + %0 = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<10xf32>) -> tensor<10xf64> + // CHECK: types0 = tensor<10xf64> + %1 = "hlo_test_infer.get_return_types"(%0) : (tensor<10xf64>) -> tensor<10xindex> + func.return %1 : tensor<10xindex> +} + +// ----- + // CHECK-LABEL: func @all_to_all_c9 func.func @all_to_all_c9(%data: tensor<4x16xf32>) -> tensor<16x4xindex> { %0 = "stablehlo.all_to_all"(%data) { @@ -584,8 +602,8 @@ func.func @after_all_empty_arg() -> !stablehlo.token { // ----- -// CHECK: func @select_and_scatter_c11 -func.func @select_and_scatter_c11( +// CHECK: func @select_and_scatter_c11_c12 +func.func @select_and_scatter_c11_c12( %arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xindex> { %0 = stablehlo.constant dense<0.000000e+00> : tensor @@ -612,8 +630,35 @@ func.func @select_and_scatter_c11( // ----- -// CHECK-LABEL: func @scatter_c16 -func.func @scatter_c16(%input_tensor: tensor<200x100x300xf32>, +// CHECK: func @select_and_scatter_c11_c12 +func.func @select_and_scatter_c11_c12( + %arg0: tensor<10x24x24x64xf32>, + %arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xindex> { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "stablehlo.compare"(%arg3, %arg4) { + compare_type = #stablehlo, + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf64> + %2 = "hlo_test_infer.get_return_types"(%1) : (tensor<10x24x24x64xf64>) -> tensor<10x24x24x64xindex> + func.return %2 : tensor<10x24x24x64xindex> +} + +// ----- + +// CHECK-LABEL: func @scatter_c16_c17 +func.func @scatter_c16_c17(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xindex> { %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ @@ -638,6 +683,32 @@ func.func @scatter_c16(%input_tensor: tensor<200x100x300xf32>, // ----- +// CHECK-LABEL: func @scatter_c16_c17 +func.func @scatter_c16_c17(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xindex> { + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + // CHECK: types0 = tensor<200x100x300xf64> + %1 = "hlo_test_infer.get_return_types"(%0) : (tensor<200x100x300xf64>) -> tensor<200x100x300xindex> + func.return %1 : tensor<200x100x300xindex> +} + +// ----- + // CHECK-LABEL: func @scatter_bounds func.func @scatter_bounds(%input_tensor: tensor<200x?x?xf32, #stablehlo.bounds>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> @@ -864,7 +935,24 @@ func.func @reduce_c7(%arg0: tensor<7x5xf32>, %arg1 : tensor<5xf32>) -> tensor<6x // ----- -func.func @reduce_c7(%arg0: tensor, %arg1: tensor, +func.func @reduce_c8(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf32>) { + // expected-error@+2 {{failed to infer returned types}} + // expected-error@+1{{'stablehlo.reduce' op inferred type(s) 'tensor<4xf64>' are incompatible with return type(s) of operation 'tensor<4xf32>'}} + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor) -> tensor<4xf32> + + func.return %0: tensor<4xf32> +} + +// ----- + +func.func @reduce_c3_c7(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{inferred type(s) 'tensor', 'tensor' are incompatible with return type(s) of operation 'tensor', 'tensor', 'tensor'}} @@ -882,7 +970,7 @@ func.func @reduce_c7(%arg0: tensor, %arg1: tensor, // ----- -func.func @reduce_c7(%arg0: tensor, %arg1: tensor, +func.func @reduce_c7_c8(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{'stablehlo.reduce' op inferred type(s) 'tensor', 'tensor' are incompatible with return type(s) of operation 'tensor', 'tensor'}} @@ -900,7 +988,7 @@ func.func @reduce_c7(%arg0: tensor, %arg1: tensor, // ----- -func.func @reduce_c7(%arg0: tensor, %arg1 : tensor) +func.func @reduce_c8(%arg0: tensor, %arg1 : tensor) -> (tensor) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{'stablehlo.reduce' op inferred type(s) 'tensor' are incompatible with return type(s) of operation 'tensor'}} @@ -1023,6 +1111,25 @@ func.func @reduce_window_c16(%arg0: tensor<4x2xf32>, // ----- +func.func @reduce_window_c16(%arg0: tensor<4x2xf32>, %init0: tensor) -> + (tensor<2x2xf32>) { + // expected-error@+2 {{failed to infer returned types}} + // expected-error@+1 {{inferred type(s) 'tensor<2x2xf64>' are incompatible with return type(s) of operation 'tensor<2x2xf32>'}} + %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor, %b0: tensor): + %1 = stablehlo.add %a0, %b0 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2xf32>, tensor) -> (tensor<2x2xf32>) + func.return %0 : tensor<2x2xf32> +} + +// ----- + //===----------------------------------------------------------------------===// // Bounded Dynamism //===----------------------------------------------------------------------===// diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index c7b33ddd21e..2b7c6a4a3c3 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -27,6 +27,40 @@ func.func @all_reduce(%operand: tensor<10xf32>) -> tensor<10xf32> { func.return %0 : tensor<10xf32> } +// ----- + +// CHECK-LABEL: func @all_reduce_with_promotable_types +func.func @all_reduce_with_promotable_types(%operand: tensor) -> tensor { + + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor) -> tensor + + func.return %result : tensor +} + +// ----- + +// CHECK-LABEL: func @all_reduce_with_promotable_quantized_types +func.func @all_reduce_with_promotable_quantized_types(%operand: tensor>) + -> tensor> { + + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor>, %arg1: tensor>): + %0 = stablehlo.add %arg0, %arg1 : tensor> + "stablehlo.return"(%0) : (tensor>) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor>) -> tensor> + + func.return %result : tensor> +} // ----- @@ -186,7 +220,7 @@ func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> { // ----- func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0 = "stablehlo.all_reduce"(%operand) ({ ^bb0(%arg0: tensor, %arg1: tensor): %max = stablehlo.maximum %arg0, %arg1 : tensor @@ -201,7 +235,7 @@ func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> { // ----- func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<4xf32>' vs 'tensor'}} + // expected-error@+1 {{The shape of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<4xf32>' vs 'tensor'}} %0 = "stablehlo.all_reduce"(%operand) ({ ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): %max = stablehlo.maximum %arg0, %arg1 : tensor<4xf32> @@ -215,6 +249,40 @@ func.func @all_reduce_c5(%operand: tensor<10xf32>) -> tensor<10xf32> { // ----- +func.func @all_reduce_c5(%operand: tensor) -> tensor { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor) -> tensor + + func.return %result : tensor +} +// ----- + +func.func @all_reduce_c5(%operand: tensor>) + -> tensor> { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor>' vs 'tensor>'}} + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor>, %arg1: tensor>): + %0 = stablehlo.add %arg0, %arg1 : tensor> + "stablehlo.return"(%0) : (tensor>) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor>) -> tensor> + + func.return %result : tensor> +} + +// ----- + // CHECK-LABEL: func @reduce_scatter func.func @reduce_scatter(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { %0 = "stablehlo.reduce_scatter"(%data) ({ @@ -245,6 +313,38 @@ func.func @reduce_scatter_dynamic(%data: tensor) -> tensor { // ----- +// CHECK-LABEL: func @reduce_scatter_with_promotable_types +func.func @reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + +// ----- + +// CHECK-LABEL: func @reduce_scatter_with_promotable_quantized_types +func.func @reduce_scatter_with_promotable_quantized_types( + %data: tensor<4x16x!quant.uniform>) -> + tensor<4x4x!quant.uniform> { + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = stablehlo.add %arg2, %arg3 : tensor> + "stablehlo.return"(%1) : (tensor>) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> + func.return %0 : tensor<4x4x!quant.uniform> +} + +// ----- + func.func @reduce_scatter_c2(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{expects scatter_dimension >= 0}} %0 = "stablehlo.reduce_scatter"(%data) ({ @@ -404,7 +504,7 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // ----- func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0 = "stablehlo.reduce_scatter"(%data) ({ ^bb0(%arg2: tensor, %arg3: tensor): %1 = stablehlo.add %arg2, %arg3 : tensor @@ -416,6 +516,39 @@ func.func @reduce_scatter_c7(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // ----- +func.func @reduce_scatter_c7(%data: tensor<4x16xi32>) -> tensor<4x4xi8> { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xi32>) -> tensor<4x4xi8> + func.return %0 : tensor<4x4xi8> +} + +// ----- + +func.func @reduce_scatter_c7(%data: tensor<4x16x!quant.uniform>) + -> tensor<4x4x!quant.uniform> { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor>' vs 'tensor>'}} + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor>, %arg3: tensor>): + %1 = stablehlo.add %arg2, %arg3 : tensor> + "stablehlo.return"(%1) : (tensor>) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16x!quant.uniform>) -> tensor<4x4x!quant.uniform> + func.return %0 : tensor<4x4x!quant.uniform> +} + +// ----- + func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<4xf32> { // expected-error@+1 {{operand and result should have same rank}} %0 = "stablehlo.reduce_scatter"(%data) ({ @@ -455,6 +588,21 @@ func.func @reduce_scatter_c8(%data: tensor<4x16xf32>) -> tensor<3x4xf32> { // ----- +func.func @reduce_scatter_c9(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { + // expected-error@+1 {{result element-type is expected to be 'f64', but got 'f32'}} + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// ----- + func.func @reduce_scatter_i3(%data: tensor<4x16xf32>) -> tensor<4x4xf32> { // expected-error@+1 {{replica groups should be a rank 2 tensor}} %0 = "stablehlo.reduce_scatter"(%data) ({ diff --git a/stablehlo/tests/verify_reduce.mlir b/stablehlo/tests/verify_reduce.mlir index 8915f4941aa..f48a2db01f7 100644 --- a/stablehlo/tests/verify_reduce.mlir +++ b/stablehlo/tests/verify_reduce.mlir @@ -96,6 +96,35 @@ func.func @reduce_mix_rank_and_unranked(%arg0: tensor<4x4xf32>, %arg1: tensor<*x // ----- +// CHECK-LABEL: func @reduce_with_promotable_types +func.func @reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> + + func.return %0: tensor<4xf64> +} + +// ----- + +// CHECK-LABEL: func @reduce_with_promotable_quantized_types +func.func @reduce_with_promotable_quantized_types(%arg0: tensor<4x4x!quant.uniform>, + %arg1: tensor>) -> tensor<4x!quant.uniform> { + %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform>, tensor>) -> tensor<4x!quant.uniform> + reducer(%arg2: tensor>, %arg3: tensor>) { + %1 = stablehlo.add %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> + } + return %0 : tensor<4x!quant.uniform> +} + +// ----- + func.func @reduce_c1(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor, %arg3: tensor) -> (tensor<2xf32>, tensor<2xf32>) { @@ -307,7 +336,7 @@ func.func @reduce_c6(%arg0: tensor, %arg1: tensor, func.func @reduce_c6(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0:2 = "stablehlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({ ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor): @@ -325,7 +354,7 @@ func.func @reduce_c6(%arg0: tensor, %arg1: tensor, func.func @reduce_c6(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - // expected-error@+1 {{The type of reduction-region's result type at index 1 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 1 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0:2 = "stablehlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({ ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor): @@ -343,7 +372,7 @@ func.func @reduce_c6(%arg0: tensor, %arg1: tensor, func.func @reduce_c6(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { - // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be 'i32', but got 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be promotable from 'i32', but got 'f32'}} %0:2 = "stablehlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({ ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor): @@ -392,6 +421,73 @@ func.func @reduce_c6(%arg0: tensor<8x5xf32>, %arg1 : tensor<4xf32>) // ----- + +func.func @reduce_c6(%arg0: tensor<4x4xi32>, %arg1 : tensor) + -> (tensor<4xi8>) { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xi32>, tensor) -> tensor<4xi8> + + func.return %0: tensor<4xi8> +} + +// ----- + +func.func @reduce_c6(%arg0: tensor<4x4xi32>, %arg1 : tensor) + -> (tensor<4xi8>) { + + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'i8'}} + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xi32>, tensor) -> tensor<4xi8> + + func.return %0: tensor<4xi8> +} + +// ----- + +func.func @reduce_c6(%arg0: tensor<4x4x!quant.uniform>, + %arg1: tensor>) -> tensor<4x!quant.uniform> { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor>' vs 'tensor>'}} + %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform>, + tensor>) -> tensor<4x!quant.uniform> + + reducer(%arg2: tensor>, %arg3: tensor>) { + %1 = stablehlo.add %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> + } + return %0 : tensor<4x!quant.uniform> +} + +// ----- + +func.func @reduce_c6(%arg0: tensor<4x4x!quant.uniform>, + %arg1: tensor>) -> tensor<4x!quant.uniform> { + + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} + %0 = stablehlo.reduce(%arg0 init: %arg1) across dimensions = [0] : (tensor<4x4x!quant.uniform>, + tensor>) -> tensor<4x!quant.uniform> + + reducer(%arg2: tensor>, %arg3: tensor>) { + %1 = stablehlo.add %arg2, %arg3 : tensor> + stablehlo.return %1 : tensor> + } + return %0 : tensor<4x!quant.uniform> +} + +// ----- + func.func @reduce_i3(%input: tensor<1x6xi64>, %init_value: tensor) -> tensor<1xi64> { // expected-error@+1 {{dimensions must be rank 1}} %0 = "stablehlo.reduce"(%input, %init_value) ({ diff --git a/stablehlo/tests/verify_reduce_window.mlir b/stablehlo/tests/verify_reduce_window.mlir index f410243b1b9..62d9b4d80e0 100644 --- a/stablehlo/tests/verify_reduce_window.mlir +++ b/stablehlo/tests/verify_reduce_window.mlir @@ -82,6 +82,46 @@ func.func @reduce_window_with_unranked_dynamic_dims(%arg0: tensor<*xf32>, // ----- +// CHECK-LABEL: func @reduce_window_with_promotable_types +func.func @reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = stablehlo.add %a0, %b0 : tensor + %3 = stablehlo.add %a1, %b1 : tensor + "stablehlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @reduce_window_with_promotable_quantized_types +func.func @reduce_window_with_promotable_quantized_types(%arg0: tensor<4x2x!quant.uniform>, + %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { + + %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor>, %b0: tensor>): + %1 = stablehlo.add %a0, %b0 : tensor> + "stablehlo.return"(%1) : (tensor>) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) + func.return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + func.func @reduce_window_c1(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { @@ -123,7 +163,7 @@ func.func @reduce_window_c2(%arg0: tensor<4x2xf32>, func.func @reduce_window_c3(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xf32>) { - // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be 'i32', but got 'tensor' as its type.}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 3 is expected to be promotable from 'i32', but got 'f32'}} %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -490,7 +530,7 @@ func.func @reduce_window_c13(%arg0: tensor<4x2xf32>, func.func @reduce_window_c13(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error@+1 {{The type of reduction-region's result type at index 1 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 1 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): @@ -544,10 +584,86 @@ func.func @reduce_window_c13(%arg0: tensor<4x2xf32>, %init0: tensor<4x2xf32>) // ----- +func.func @reduce_window_c13(%arg0: tensor<4x2xi32>, %init0: tensor) -> + (tensor<2x2xi8>) { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} + %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor, %b0: tensor): + %1 = stablehlo.add %a0, %b0 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2xi32>, tensor) -> (tensor<2x2xi8>) + func.return %0 : tensor<2x2xi8> +} + +// ----- + +func.func @reduce_window_c13(%arg0: tensor<4x2xi32>, %init0: tensor) -> + (tensor<2x2xi8>) { + + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'i8'}} + %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor, %b0: tensor): + %1 = stablehlo.add %a0, %b0 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2xi32>, tensor) -> (tensor<2x2xi8>) + func.return %0 : tensor<2x2xi8> +} + +// ----- + +func.func @reduce_window_c13(%arg0: tensor<4x2x!quant.uniform>, + %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor>' vs 'tensor>'}} + %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor>, %b0: tensor>): + %1 = stablehlo.add %a0, %b0 : tensor> + "stablehlo.return"(%1) : (tensor>) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) + func.return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + +func.func @reduce_window_c13(%arg0: tensor<4x2x!quant.uniform>, + %init0: tensor>) -> (tensor<2x2x!quant.uniform>) { + + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} + %0 = "stablehlo.reduce_window"(%arg0, %init0) ({ + ^bb0(%a0: tensor>, %b0: tensor>): + %1 = stablehlo.add %a0, %b0 : tensor> + "stablehlo.return"(%1) : (tensor>) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> + } + : (tensor<4x2x!quant.uniform>, tensor>) -> (tensor<2x2x!quant.uniform>) + func.return %0 : tensor<2x2x!quant.uniform> +} + +// ----- + func.func @reduce_window_i2(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor<1xf32>, %init1: tensor<1xi32>) -> (tensor<2x2xf32>, tensor<2x2xi32>) { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor<1xf32>'}} + // expected-error@+1 {{The shape of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor<1xf32>'}} %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, %b1: tensor): diff --git a/stablehlo/tests/verify_scatter.mlir b/stablehlo/tests/verify_scatter.mlir index 1b7ec22fdcd..cc7284a6892 100644 --- a/stablehlo/tests/verify_scatter.mlir +++ b/stablehlo/tests/verify_scatter.mlir @@ -71,6 +71,55 @@ func.func @valid_scatter_dimensions_with_dynamic_index_vector_dim( // ----- +// CHECK: func @scatter_with_promotable_types +func.func @scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> +} + +// ----- + +// CHECK: func @scatter_with_promotable_quantized_types +func.func @scatter_with_promotable_quantized_types(%input_tensor: tensor<200x100x300x!quant.uniform>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> { + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor>, %rhs: tensor>): + %add = stablehlo.add %lhs, %rhs : tensor> + "stablehlo.return"(%add) : (tensor>) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi32>, + tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> + func.return %0 : tensor<200x100x300x!quant.uniform> +} + +// ----- + func.func @scatter_c1(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xi32>) -> tensor<3xi32> { // expected-error @+1 {{Not all inputs have compatible shapes.}} @@ -233,7 +282,7 @@ func.func @scatter_c4(%input_tensor: tensor<200x100x300xf32>, func.func @scatter_c6_c15(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xi32>) -> tensor<200x100x300xf32> { - // expected-error@+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): %add = stablehlo.add %lhs, %rhs : tensor @@ -257,7 +306,7 @@ func.func @scatter_c6_c15(%input_tensor: tensor<200x100x300xf32>, func.func @scatter_c6_c15_c16(%input_tensor: tensor<200x100x300xi32>, %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> tensor<200x100x300xf32> { - // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be 'i32', but got 'tensor' as its type.}} + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'f32'}} %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ ^bb0(%lhs: tensor, %rhs: tensor): %add = stablehlo.add %lhs, %rhs : tensor @@ -794,3 +843,103 @@ func.func @scatter_c15(%input_tensor: tensor<200x100x300xf32>, tensor<200x100x300xf32> func.return %0 : tensor<200x100x300xf32> } + +// ----- + +func.func @scatter_c15(%input_tensor: tensor<200x100x300xi32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xi32>) -> + tensor<200x100x300xi8> { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xi32>, tensor<10x2xi32>, tensor<10x300xi32>) -> + tensor<200x100x300xi8> + func.return %0 : tensor<200x100x300xi8> +} + +// ----- + +func.func @scatter_c15(%input_tensor: tensor<200x100x300xi32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xi8>) -> + tensor<200x100x300xi8> { + + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'i32', but got 'i8'}} + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xi32>, tensor<10x2xi32>, tensor<10x300xi8>) -> + tensor<200x100x300xi8> + func.return %0 : tensor<200x100x300xi8> +} + +// ----- + +func.func @scatter_c15(%input_tensor: tensor<200x100x300x!quant.uniform>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor>' vs 'tensor>'}} + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor>, %rhs: tensor>): + %add = stablehlo.add %lhs, %rhs : tensor> + "stablehlo.return"(%add) : (tensor>) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi32>, tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> + func.return %0 : tensor<200x100x300x!quant.uniform> +} + +// ----- + +func.func @scatter_c15(%input_tensor: tensor<200x100x300x!quant.uniform>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> { + + // expected-error@+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from '!quant.uniform', but got '!quant.uniform'}} + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor>, %rhs: tensor>): + %add = stablehlo.add %lhs, %rhs : tensor> + "stablehlo.return"(%add) : (tensor>) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300x!quant.uniform>, tensor<10x2xi32>, tensor<10x300x!quant.uniform>) -> + tensor<200x100x300x!quant.uniform> + func.return %0 : tensor<200x100x300x!quant.uniform> +} diff --git a/stablehlo/tests/verify_select_and_scatter.mlir b/stablehlo/tests/verify_select_and_scatter.mlir index bede69572aa..5b98b6c6afc 100644 --- a/stablehlo/tests/verify_select_and_scatter.mlir +++ b/stablehlo/tests/verify_select_and_scatter.mlir @@ -26,6 +26,62 @@ func.func @select_and_scatter( // ----- +// CHECK: func @select_and_scatter_with_promotable_types +func.func @select_and_scatter_with_promotable_types( + %arg0: tensor<10x24x24x64xf32>, + %arg1: tensor<10x12x12x64xf32>) -> () { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "stablehlo.compare"(%arg3, %arg4) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf64> + func.return +} + +// ----- + +// CHECK: func @select_and_scatter_with_promotable_quantized_types +func.func @select_and_scatter_with_promotable_quantized_types( + %arg0: tensor<10x24x24x64x!quant.uniform>, + %arg1: tensor<10x12x12x64x!quant.uniform>, + %arg2 : tensor>) -> + tensor<10x24x24x64x!quant.uniform> { + + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = "stablehlo.compare"(%arg3, %arg4) { + compare_type = #stablehlo, + comparison_direction = #stablehlo + } : (tensor>, tensor>) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = stablehlo.add %arg3, %arg4 : tensor> + "stablehlo.return"(%2) : (tensor>) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64x!quant.uniform>, + tensor<10x12x12x64x!quant.uniform>, + tensor>) -> + tensor<10x24x24x64x!quant.uniform> + func.return %1 : tensor<10x24x24x64x!quant.uniform> +} + +// ----- + // CHECK: func @select_and_scatter_with_unranked_dims func.func @select_and_scatter_with_unranked_dims( %arg0: tensor<4x5x1x1xbf16>, @@ -607,7 +663,7 @@ func.func @select_and_scatter_c10( %arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> () { %0 = stablehlo.constant dense<0.000000e+00> : tensor - // expected-error @+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor' vs 'tensor'}} + // expected-error @+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ ^bb0(%arg3: tensor, %arg4: tensor): %2 = "stablehlo.compare"(%arg3, %arg4) { @@ -633,7 +689,7 @@ func.func @select_and_scatter_c10( %arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> () { %0 = stablehlo.constant dense<0> : tensor - // expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be 'f32', but got 'tensor' as its type.}} + // expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'f32', but got 'i32'}} %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ ^bb0(%arg3: tensor, %arg4: tensor): %2 = "stablehlo.compare"(%arg3, %arg4) { @@ -678,3 +734,86 @@ func.func @select_and_scatter_c10( tensor<10x24x24x64xf32> func.return } + +// ----- + +func.func @select_and_scatter_c10( + %arg0: tensor<10x24x24x64xi32>, + %arg1: tensor<10x12x12x64xi32>) -> tensor<10x24x24x64xi8> { + %0 = stablehlo.constant dense<0> : tensor + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor' vs 'tensor'}} + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "stablehlo.compare"(%arg3, %arg4) { + compare_type = #stablehlo, + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xi32>, tensor<10x12x12x64xi32>, tensor) -> + tensor<10x24x24x64xi8> + func.return %1 : tensor<10x24x24x64xi8> +} + +// ----- + +func.func @select_and_scatter_c10( + %arg0: tensor<10x24x24x64xf32>, + %arg1: tensor<10x12x12x64xf32>) -> () { + %0 = stablehlo.constant dense<0> : tensor + // expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be promotable from 'f32', but got 'i8'}} + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "stablehlo.compare"(%arg3, %arg4) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf32> + func.return +} + +// ----- + +func.func @select_and_scatter_c10( + %arg0: tensor<10x24x24x64x!quant.uniform>, + %arg1: tensor<10x12x12x64x!quant.uniform>, + %arg2 : tensor>) -> + tensor<10x24x24x64x!quant.uniform> { + + // expected-error@+1 {{The element-type of reduction-region's result type at index 0 is expected to be promotable from the op's corresponding init-value element-type: 'tensor>' vs 'tensor>'}} + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = "stablehlo.compare"(%arg3, %arg4) { + compare_type = #stablehlo, + comparison_direction = #stablehlo + } : (tensor>, tensor>) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor>, %arg4: tensor>): + %2 = stablehlo.add %arg3, %arg4 : tensor> + "stablehlo.return"(%2) : (tensor>) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64x!quant.uniform>, + tensor<10x12x12x64x!quant.uniform>, + tensor>) -> + tensor<10x24x24x64x!quant.uniform> + func.return %1 : tensor<10x24x24x64x!quant.uniform> +} From 10d2d4a795ef6c366e936ef0bf8da187190d1a30 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Fri, 15 Dec 2023 17:41:56 +0000 Subject: [PATCH 2/4] compatibility constraint checks --- stablehlo/dialect/TypeInference.cpp | 2 +- stablehlo/dialect/Version.h | 2 +- stablehlo/dialect/VhloDialect.td | 1 + stablehlo/dialect/VhloOps.cpp | 96 + stablehlo/dialect/VhloOps.td | 29 +- .../stablehlo_legalize_to_vhlo.0_17_0.mlir | 2410 +++++++++++++++++ .../stablehlo_legalize_to_vhlo.0_17_0.mlir.bc | Bin 0 -> 17593 bytes .../tests/stablehlo_legalize_to_vhlo.mlir | 125 + ...o_to_version_downgrade_invalid.0_16_0.mlir | 123 + stablehlo/transforms/VhloToVersion.cpp | 9 +- 10 files changed, 2788 insertions(+), 9 deletions(-) create mode 100644 stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir create mode 100644 stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc create mode 100644 stablehlo/tests/vhlo_to_version_downgrade_invalid.0_16_0.mlir diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 85a1ab2f234..57fb26d1853 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -119,7 +119,7 @@ bool isPromotableElementType(Type type1, Type type2, Type tensorEl2 = tensorTy2.getElementType(); if (ignoreFpPrecision && tensorEl1.isa() && - tensorTy2.getElementType().isa()) + tensorEl2.isa()) return true; bool isSameType = diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h index f8c92b7f1c7..0fe2385f4fa 100644 --- a/stablehlo/dialect/Version.h +++ b/stablehlo/dialect/Version.h @@ -38,7 +38,7 @@ class Version { static FailureOr fromString(llvm::StringRef versionRef); /// Return a Version representing the current VHLO dialect version. - static Version getCurrentVersion() { return Version(0, 16, 3); } + static Version getCurrentVersion() { return Version(0, 17, 0); } /// Return a Version representing the minimum supported VHLO dialect version. static Version getMinimumVersion() { return Version(0, 9, 0); } diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td index fe1d4d5d174..608f2a17855 100644 --- a/stablehlo/dialect/VhloDialect.td +++ b/stablehlo/dialect/VhloDialect.td @@ -34,6 +34,7 @@ def VHLO_Dialect : Dialect { 0.14.0: MLIR bytecode version 3 => 5 (revised to 4 in #1827). 0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO. 0.16.0: Introduce `collective_broadcast` operation. + 0.17.0: Allow reduce operations to promote to higher bitwidth. }]; let useDefaultAttributePrinterParser = 0; diff --git a/stablehlo/dialect/VhloOps.cpp b/stablehlo/dialect/VhloOps.cpp index a3428b1b465..f28432642c9 100644 --- a/stablehlo/dialect/VhloOps.cpp +++ b/stablehlo/dialect/VhloOps.cpp @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" @@ -37,6 +38,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "stablehlo/dialect/AssemblyFormat.h" #include "stablehlo/dialect/VhloBytecode.h" +#include "stablehlo/dialect/VhloTypes.h" namespace mlir { namespace vhlo { @@ -296,5 +298,99 @@ void VhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const { assert(succeeded(result)); } +/////////////////////////// +// Op Constraint Versioning +/////////////////////////// +// These could be migrated to ODS in VhloOps.td if we figured out a better way +// to represent this sort of constraint in tablegen. + +namespace { +bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes, + TypeRange resultTypes) { + SmallVector inputShapedTypes{ + llvm::map_range(operandTypes, [](Type t) { + return convertTypeToBuiltinForPrint(t).cast(); + })}; + SmallVector resultShapedTypes{ + llvm::map_range(resultTypes, [](Type t) { + return convertTypeToBuiltinForPrint(t).cast(); + })}; + + int64_t numInputs = inputShapedTypes.size(); + for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { + if (inputShapedTypes[inputIdx].getElementType() != + resultShapedTypes[inputIdx].getElementType()) + return true; + } + return false; +} +} // namespace + +LogicalResult AllReduceOpV1::validateConstraint(mlir::Operation* op, + Version targetVersion) { + // Allow mismatched operand and result types in v0.17.0 + if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), + getResult().getType()) && + targetVersion < Version(0, 17, 0)) + return failure(); + + return success(); +} + +LogicalResult ReduceOpV1::validateConstraint(mlir::Operation* op, + Version targetVersion) { + // Allow mismatched operand and result types in v0.17.0 + if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), + getResultTypes()) && + targetVersion < Version(0, 17, 0)) + return failure(); + + return success(); +} + +LogicalResult ReduceScatterOpV1::validateConstraint(mlir::Operation* op, + Version targetVersion) { + // Allow mismatched operand and result types in v0.17.0 + if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), + getResult().getType()) && + targetVersion < Version(0, 17, 0)) + return failure(); + + return success(); +} + +LogicalResult ReduceWindowOpV1::validateConstraint(mlir::Operation* op, + Version targetVersion) { + // Allow mismatched operand and result types in v0.17.0 + if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), + getResultTypes()) && + targetVersion < Version(0, 17, 0)) + return failure(); + + return success(); +} + +LogicalResult ScatterOpV1::validateConstraint(mlir::Operation* op, + Version targetVersion) { + // Allow mismatched operand and result types in v0.17.0 + if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), + getResultTypes()) && + targetVersion < Version(0, 17, 0)) + return failure(); + + return success(); +} + +LogicalResult SelectAndScatterOpV1::validateConstraint(mlir::Operation* op, + Version targetVersion) { + // Allow mismatched operand and result types in v0.17.0 + if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), + getResult().getType()) && + targetVersion < Version(0, 17, 0)) + return failure(); + + return success(); +} + } // namespace vhlo } // namespace mlir diff --git a/stablehlo/dialect/VhloOps.td b/stablehlo/dialect/VhloOps.td index 0456360e199..edd515a2279 100644 --- a/stablehlo/dialect/VhloOps.td +++ b/stablehlo/dialect/VhloOps.td @@ -35,6 +35,17 @@ def VHLO_VersionedOpInterface : OpInterface<"VersionedOpInterface"> { ]; } +def VHLO_VersionedOpConstraintInterface : OpInterface<"VersionedOpConstraintInterface"> { + let cppNamespace = "::mlir::vhlo"; + let methods = [ + InterfaceMethod< + [{Validate versioned constraints on a versioned op. + Used if the spec'ed constraints of an op change over time.}], + "mlir::LogicalResult", "validateConstraint", + (ins "mlir::Operation*":$op, "mlir::vhlo::Version":$targetVersion)>, + ]; +} + class VHLO_Op traits = []> : Op] # traits> { @@ -92,7 +103,8 @@ def VHLO_AllGatherOpV1 : VHLO_Op<"all_gather_v1", "0.9.0", "current"> { let results = (outs VHLO_AnyType:$result); } -def VHLO_AllReduceOpV1 : VHLO_Op<"all_reduce_v1", "0.9.0", "current"> { +def VHLO_AllReduceOpV1 : VHLO_Op<"all_reduce_v1", "0.9.0", "current", + [DeclareOpInterfaceMethods]> { let arguments = (ins VHLO_AnyType:$operand, VHLO_AnyAttr:$replica_groups, @@ -754,7 +766,8 @@ def VHLO_RecvOpV1 : VHLO_Op<"recv_v1", "0.9.0", "current"> { let results = (outs Variadic:$results); } -def VHLO_ReduceOpV1 : VHLO_Op<"reduce_v1", "0.9.0", "current", [SameVariadicOperandSize]> { +def VHLO_ReduceOpV1 : VHLO_Op<"reduce_v1", "0.9.0", "current", + [SameVariadicOperandSize, DeclareOpInterfaceMethods]> { let arguments = (ins Variadic:$inputs, Variadic:$init_values, @@ -773,7 +786,8 @@ def VHLO_ReducePrecisionOpV1 : VHLO_Op<"reduce_precision_v1", "0.9.0", "current" let results = (outs VHLO_AnyType:$output); } -def VHLO_ReduceScatterOpV1 : VHLO_Op<"reduce_scatter_v1", "0.9.0", "current"> { +def VHLO_ReduceScatterOpV1 : VHLO_Op<"reduce_scatter_v1", "0.9.0", "current", + [DeclareOpInterfaceMethods]> { let arguments = (ins VHLO_AnyType:$operand, VHLO_AnyAttr:$scatter_dimension, @@ -785,7 +799,8 @@ def VHLO_ReduceScatterOpV1 : VHLO_Op<"reduce_scatter_v1", "0.9.0", "current"> { let results = (outs VHLO_AnyType:$result); } -def VHLO_ReduceWindowOpV1 : VHLO_Op<"reduce_window_v1", "0.9.0", "current", [SameVariadicOperandSize]> { +def VHLO_ReduceWindowOpV1 : VHLO_Op<"reduce_window_v1", "0.9.0", "current", + [SameVariadicOperandSize, DeclareOpInterfaceMethods]> { let arguments = (ins Variadic:$inputs, Variadic:$init_values, @@ -864,7 +879,8 @@ def VHLO_RsqrtOpV1 : VHLO_Op<"rsqrt_v1", "0.9.0", "current"> { let results = (outs VHLO_AnyType:$result); } -def VHLO_ScatterOpV1 : VHLO_Op<"scatter_v1", "0.9.0", "current", [SameVariadicOperandSize]> { +def VHLO_ScatterOpV1 : VHLO_Op<"scatter_v1", "0.9.0", "current", + [SameVariadicOperandSize, DeclareOpInterfaceMethods]> { let arguments = (ins Variadic:$inputs, VHLO_AnyType:$scatter_indices, @@ -880,7 +896,8 @@ def VHLO_ScatterOpV1 : VHLO_Op<"scatter_v1", "0.9.0", "current", [SameVariadicOp let results = (outs Variadic:$results); } -def VHLO_SelectAndScatterOpV1 : VHLO_Op<"select_and_scatter_v1", "0.9.0", "current"> { +def VHLO_SelectAndScatterOpV1 : VHLO_Op<"select_and_scatter_v1", "0.9.0", "current", + [DeclareOpInterfaceMethods]> { let arguments = (ins VHLO_AnyType:$operand, VHLO_AnyType:$source, diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir new file mode 100644 index 00000000000..70efa90c0b3 --- /dev/null +++ b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir @@ -0,0 +1,2410 @@ +// RUN: stablehlo-opt --mlir-print-op-generic %s.bc | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-translate --serialize --target=0.17.0 | stablehlo-opt --mlir-print-op-generic | FileCheck %s +// RUN: diff <(stablehlo-translate --deserialize %s.bc | stablehlo-opt) <(stablehlo-opt --strip-debuginfo %s) +// RUN: diff %s.bc <(stablehlo-translate --serialize --target=0.17.0 --strip-debuginfo %s) + +// CHECK-WARN-NOT: Not Implemented + +// ============ ATTRIBUTES ============ + +// CHECK-LABEL: "attr_comparison_direction_eq" +func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ne" +func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ge" +func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_gt" +func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_le" +func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_lt" +func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_notype" +func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + // CHECK: compare_type = #vhlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_float" +func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_totalorder" +func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_signed" +func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_unsigned" +func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ConvDimensionNumbers aka #stablehlo.conv is covered below. + +// CHECK-LABEL: "attr_custom_call_api_version_unspecified" +func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 0 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_original" +func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 1 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning" +func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 2 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" +func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 3 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_dict" +// CHECK: #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = #vhlo.integer_v1<1 : i32>, #vhlo.string_v1<"attr2"> = #vhlo.integer_v1<2 : i32>} +func.func @attr_dict() attributes {stablehlo.attr = {attr1 = 1 : i32, attr2 = 2 : i32}} { + return +} + +// DotDimensionNumbers aka #stablehlo.dot is covered below. + +// CHECK-LABEL: "attr_fft_type_fft" +func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_ifft" +func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_rfft" +func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xf32>) -> tensor<9xcomplex> + func.return %0 : tensor<9xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_irfft" +func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<9xcomplex>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// GatherDimensionNumbers aka #stablehlo.gather is covered below. + +// CHECK-LABEL: "attr_precision_config_default" +func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_high" +func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_highest" +func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_rng_algorithm_default" +func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_three_fry" +func.func @attr_rng_algorithm_three_fry(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_philox" +func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_distribution_uniform" +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_rng_distribution_normal" +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ScatterDimensionNumbers aka #stablehlo.scatter is covered below. + +// CHECK-LABEL: "attr_transpose_no_transpose" +func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_transpose" +func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_adjoint" +func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. + +// CHECK-LABEL: "attr_type_extensions_bounds" +func.func @attr_type_extensions_bounds( + %arg0: tensor>) + -> tensor> { + // CHECK: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1>) -> () + func.return %arg0 : tensor> +} + + +// ============ DEFAULTS ============ + +// CHECK-LABEL: "default_all_gather" +func.func @default_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v1"(%arg0) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "default_all_reduce" +func.func @default_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v1"(%arg0) + // CHECK-SAME: <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_all_to_all" +func.func @default_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "default_cholesky" +func.func @default_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%arg0) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "default_collective_permute" +func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_collective_broadcast" +func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_broadcast_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_compare" +func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%arg0, %arg1) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_convolution" +func.func @default_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { + // CHECK: "vhlo.convolution_v1"(%arg0, %arg1) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x6x6x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> + func.return %0 : tensor<1x6x6x16xf32> +} + +// CHECK-LABEL: "default_custom_call" +func.func @default_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%arg0) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dot_general" +func.func @default_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v1"(%arg0, %arg1) <{ + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "default_dot" +func.func @default_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%arg0, %arg1) <{ + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "default_dynamic_broadcast_in_dim" +func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%arg0, %arg1) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dynamic_conv" +func.func @default_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "default_dynamic_gather" +func.func @default_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + > + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +func.func @default_func(%arg0: tensor) -> tensor { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"default_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<""> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%arg0: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + func.return %arg0 : tensor +} + +// CHECK-LABEL: "dynamic_gather" +func.func @dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v1"(%arg0, %arg1) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = dense<1> : tensor<3xi64> + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "default_infeed" +func.func @default_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%arg0) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_outfeed" +func.func @default_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%arg0, %arg1) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<""> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_recv" +func.func @default_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_send" +func.func @default_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%arg0, %arg1) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_reduce_scatter" +func.func @default_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "default_reduce_window" +func.func @default_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x16x30x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + func.return %0 : tensor<2x16x30x7xf32> +} + +// CHECK-LABEL: "default_scatter" +func.func @default_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + > + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "default_select_and_scatter" +func.func @default_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<10x23x23x64x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "default_sort" +func.func @default_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%arg0) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<-1 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// ============ OPS ============ + +// CHECK-LABEL: "op_abs" +func.func @op_abs(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_add" +func.func @op_add(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_after_all" +func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.after_all_v1"(%arg0) : (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_all_gather" +func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v1"(%arg0) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_all_reduce" +func.func @op_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_all_reduce_with_promotable_types" +func.func @op_all_reduce_with_promotable_types(%operand: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v1"(%arg0) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + + func.return %result : tensor +} + +// CHECK-LABEL: "op_all_to_all" +func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "op_and" +func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_atan2" +func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.atan2_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_batch_norm_grad" +func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_grad_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_batch_norm_inference" +func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { + // CHECK: "vhlo.batch_norm_inference_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1> + %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> + func.return %0 : tensor<16x16x16x16xf32> +} + +// CHECK-LABEL: "op_batch_norm_training" +func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_training_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_bitcast_convert" +func.func @op_bitcast_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.bitcast_convert_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_broadcast_in_dim" +func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_in_dim_v1"(%arg0) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<1> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_broadcast" +func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_v1"(%arg0) <{ + // CHECK-SAME: broadcast_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast"(%arg0) { + broadcast_sizes = array + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_case" +func.func @op_case(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%arg0) ({ + // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cbrt" +func.func @op_cbrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cbrt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cbrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_ceil" +func.func @op_ceil(%arg0: tensor) -> tensor { + // CHECK: "vhlo.ceil_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cholesky" +func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%arg0) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) { + lower = true + } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "op_clamp" +func.func @op_clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.clamp_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_count_leading_zeros" +func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { + // CHECK: "vhlo.count_leading_zeros_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_collective_permute" +func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "op_compare" +func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%arg0, %arg1) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_complex" +func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> { + // CHECK: "vhlo.complex_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_compute_reshape_shape" +func.func @op_compute_reshape_shape(%arg0: index, %arg1: tensor<1xindex>) -> tensor<1xindex> { + // CHECK: "vhlo.compute_reshape_shape_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1<1x!vhlo.index_v1> + %0 = "stablehlo.compute_reshape_shape"(%arg0, %arg1) : (index, tensor<1xindex>) -> tensor<1xindex> + func.return %0 : tensor<1xindex> +} + +// CHECK-LABEL: "op_concatenate" +func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.concatenate_v1"(%arg0, %arg1) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.concatenate"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_constant" +func.func @op_constant(%arg0: tensor) -> tensor { + // CHECK: "vhlo.constant_v1"() <{ + // CHECK-SAME: value = #vhlo.tensor_v1 : tensor> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1 + %0 = "stablehlo.constant"() { + value = dense<0.0> : tensor + } : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convert" +func.func @op_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.convert_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convolution" +func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> { + // CHECK: "vhlo.convolution_v1"(%arg0, %arg1) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x7x7x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = dense<2> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<2> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> + func.return %0 : tensor<1x7x7x16xf32> +} + +// CHECK-LABEL: "op_cosine" +func.func @op_cosine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cosine_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cosine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_create_token" +func.func @op_create_token() -> !stablehlo.token { + // CHECK: "vhlo.create_token_v1"() : () -> !vhlo.token_v1 + %0 = "stablehlo.create_token"() : () -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_cross_replica_sum" +func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cross-replica-sum_v1"(%arg0) <{ + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cross-replica-sum"(%arg0) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cstr_reshapable" +func.func @op_cstr_reshapable(%arg0: index, %arg1: tensor<1xindex>) -> !shape.witness { + // CHECK: "vhlo.cstr_reshapable_v1"(%arg0, %arg1) : (!vhlo.index_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.witness_v1 + %0 = "stablehlo.cstr_reshapable"(%arg0, %arg1) : (index, tensor<1xindex>) -> !shape.witness + func.return %0 : !shape.witness +} + +// CHECK-LABEL: "op_custom_call" +func.func @op_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%arg0) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"\08\03\1A\02">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[#vhlo.string_v1<"foo">]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[ + // CHECK-SAME: #vhlo.output_operand_alias_v1< + // CHECK-SAME: outputTupleIndices = [], + // CHECK-SAME: operandIndex = 0, + // CHECK-SAME: operandTupleIndices = []>]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + has_side_effect = true, + backend_config = "\08\03\1A\02", + api_version = 2 : i32, + called_computations = [@foo], + operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #stablehlo.output_operand_alias], + result_layouts = [dense<> : tensor<0xindex>] + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_divide" +func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.divide_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dot_general" +func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v1"(%arg0, %arg1) <{ + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "op_dot" +func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%arg0, %arg1) <{ + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) { + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_dynamic_broadcast_in_dim" +func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%arg0, %arg1) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, + known_expanding_dimensions = dense<0> : tensor<1xi64>, + known_nonexpanding_dimensions = dense<1> : tensor<1xi64> + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_conv" +func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<4xi32>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + window_strides = dense<2> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = dense<2> : tensor<2xi64>, + rhs_dilation = dense<2> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<4xi32>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "op_dynamic_gather" +func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +// CHECK-LABEL: "op_dynamic_iota" +func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_iota_v1"(%arg0) <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_iota"(%arg0) { + iota_dimension = 0 : i64 + } : (tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_pad" +func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_pad_v1"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_reshape" +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_slice" +func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { + // CHECK: "vhlo.dynamic_slice_v1"(%arg0, %arg1) <{ + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) { + slice_sizes = array + } : (tensor<16xf32>, tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_dynamic_update_slice" +func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.dynamic_update_slice_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_einsum" +func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.einsum_v1"(%arg0, %arg1) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab,bc->ac"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.einsum"(%arg0, %arg1) { + einsum_config = "ab,bc->ac" + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_exponential_minus_one" +func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_minus_one_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_exponential" +func.func @op_exponential(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_fft" +func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + // CHECK: "vhlo.fft_v1"(%arg0) <{ + // CHECK-SAME: fft_length = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: fft_type = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.complex_v1>) -> !vhlo.tensor_v1<16x!vhlo.complex_v1> + %0 = "stablehlo.fft"(%arg0) { + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "op_floor" +func.func @op_floor(%arg0: tensor) -> tensor { + // CHECK: "vhlo.floor_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +func.func private @op_func(%arg0: tensor {stablehlo.arg = "0"}) -> (tensor {stablehlo.result = "0"}) { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.arg"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.result"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"op_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<"private"> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%arg0: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%arg0) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + + func.return %arg0 : tensor +} + +// CHECK-LABEL: "op_gather" +func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v1"(%arg0, %arg1) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = dense<1> : tensor<3xi64>, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "op_get_dimension_size" +func.func @op_get_dimension_size(%arg0: tensor) -> tensor { + // CHECK: "vhlo.get_dimension_size_v1"(%arg0) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_dimension_size"(%arg0) { + dimension = 0 : i64 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_get_tuple_element" +func.func @op_get_tuple_element(%arg0: tuple, tensor>) -> tensor { + // CHECK: "vhlo.get_tuple_element_v1"(%arg0) <{ + // CHECK-SAME: index = #vhlo.integer_v1<0 : i32> + // CHECK-SAME: }> : (!vhlo.tuple_v1, !vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_tuple_element"(%arg0) { + index = 0 : i32 + } : (tuple, tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_if" +func.func @op_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.if_v1"(%arg0) ({ + // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: "vhlo.return_v1"(%arg2) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.if"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + "stablehlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_imag" +func.func @op_imag(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.imag_v1"(%arg0) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.imag"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_infeed" +func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%arg0) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"foo">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[#vhlo.array_v1<[]>]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) { + infeed_config = "foo", + layout = [[]] + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_iota" +func.func @op_iota() -> tensor<16xf32> { + // CHECK: "vhlo.iota_v1"() <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.iota"() { + iota_dimension = 0 : i64 + } : () -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_is_finite" +func.func @op_is_finite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.is_finite_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.is_finite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log" +func.func @op_log(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log_plus_one" +func.func @op_log_plus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_plus_one_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log_plus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_logistic" +func.func @op_logistic(%arg0: tensor) -> tensor { + // CHECK: "vhlo.logistic_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.logistic"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_map" +func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.map_v1"(%arg0) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.map"(%arg0) ({ + ^bb0(%arg1: tensor): + %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = dense<0> : tensor<1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_maximum" +func.func @op_maximum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.maximum_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_minimum" +func.func @op_minimum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.minimum_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_multiply" +func.func @op_multiply(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.multiply_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_negate" +func.func @op_negate(%arg0: tensor) -> tensor { + // CHECK: "vhlo.negate_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.negate"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_not" +func.func @op_not(%arg0: tensor) -> tensor { + // CHECK: "vhlo.not_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.not"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_optimization_barrier" +func.func @op_optimization_barrier(%arg0: tensor) -> tensor { + // CHECK: "vhlo.optimization_barrier_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_or" +func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.or_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_outfeed" +func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%arg0, %arg1) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<"foo"> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) { + outfeed_config = "foo" + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_pad" +func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.pad_v1"(%arg0, %arg1) <{ + // CHECK-SAME: edge_padding_high = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: edge_padding_low = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: interior_padding = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.pad"(%arg0, %arg1) { + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array + } : (tensor<8xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_popcnt" +func.func @op_popcnt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.popcnt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_power" +func.func @op_power(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.power_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real_dynamic_slice" +func.func @op_real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.real_dynamic_slice_v1"(%arg0, %arg1, %arg2, %arg3) : (!vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real" +func.func @op_real(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.real_v1"(%arg0) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_recv" +func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_reduce" +func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = dense<0> : tensor<1xi64> + } : (tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reduce_precision" +func.func @op_reduce_precision(%arg0: tensor) -> tensor { + // CHECK: "vhlo.reduce_precision_v1"(%arg0) <{ + // CHECK-SAME: exponent_bits = #vhlo.integer_v1<8 : i32> + // CHECK-SAME: mantissa_bits = #vhlo.integer_v1<10 : i32> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.reduce_precision"(%arg0) { + exponent_bits = 8 : i32, + mantissa_bits = 10 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK_lABEL: "op_reduce_with_promotable_types" +func.func @op_reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + // CHECK: "vhlo.reduce_v1"(%arg0, %arg1) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f64_v1> + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> + + func.return %0: tensor<4xf64> +} + +// CHECK-LABEL: "op_reduce_scatter" +func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%arg0) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK_lABEL: "op_reduce_scatter_with_promotable_types" +func.func @op_reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + // CHECK: "vhlo.reduce_scatter_v1"(%arg0) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f64_v1> + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + + +// CHECK-LABEL: "op_reduce_window" +func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x9x16x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x9x16x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>, + base_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x9x16x7xf32> + func.return %0 : tensor<2x9x16x7xf32> +} + +// CHECK_lABEL: "op_reduce_window_with_promotable_types" +func.func @op_reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1, %arg2, %arg3) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]], %[[VAL2:.*]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1, !vhlo.tensor_v1) -> (!vhlo.tensor_v1<2x2x!vhlo.f64_v1>, !vhlo.tensor_v1<2x2x!vhlo.f32_v1>) + %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = stablehlo.add %a0, %b0 : tensor + %3 = stablehlo.add %a1, %b1 : tensor + "stablehlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + +// CHECK-LABEL: "op_remainder" +func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.remainder_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_replica_id" +func.func @op_replica_id() -> tensor { + // CHECK: "vhlo.replica_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.replica_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_partition_id" +func.func @op_partition_id() -> tensor { + // CHECK: "vhlo.partition_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.partition_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reshape" +func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { + // CHECK: "vhlo.reshape_v1"(%arg0) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f32_v1> + %0 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// CHECK-LABEL: "op_return" +func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%arg0) ({ + // CHECK-NEXT: "vhlo.return_v1"(%arg1) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reverse" +func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reverse_v1"(%arg0) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reverse"(%arg0) { + dimensions = array + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_rng_bit_generator" +func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { + // CHECK: "vhlo.rng_bit_generator_v1"(%arg0) <{ + // CHECK-SAME: rng_algorithm = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> (!vhlo.tensor_v1, !vhlo.tensor_v1) + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "op_rng" +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.rng_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: rng_distribution = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + rng_distribution = #stablehlo + } : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_afz" +func.func @op_round_nearest_afz(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_afz_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_even" +func.func @op_round_nearest_even(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_even_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_rsqrt" +func.func @op_rsqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.rsqrt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_scatter" +func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK_lABEL: "op_scatter_with_promotable_types" +func.func @op_scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { + // CHECK: "vhlo.scatter_v1"(%arg0, %arg1, %arg2) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f64_v1> + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> +} + +// CHECK-LABEL: "op_select_and_scatter" +func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "op_select_and_scatter_with_promotable_types" +func.func @op_select_and_scatter_with_promotable_types(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf64> { + // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f64_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf64> + func.return %0 : tensor<10x24x24x64xf64> +} + +// CHECK-LABEL: "op_select" +func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.select_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_send" +func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%arg0, %arg1) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_set_dimension_size" +func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.set_dimension_size_v1"(%arg0, %arg1) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_shift_left" +func.func @op_shift_left(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_left_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_arithmetic" +func.func @op_shift_right_arithmetic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_arithmetic_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_logical" +func.func @op_shift_right_logical(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_logical_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sign" +func.func @op_sign(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sign_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sign"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sine" +func.func @op_sine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sine_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_slice" +func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { + // CHECK: "vhlo.slice_v1"(%arg0) <{ + // CHECK-SAME: limit_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: strides = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<16xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_sort" +func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%arg0) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimension = 0 : i64, + is_stable = true + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_sqrt" +func.func @op_sqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sqrt_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_subtract" +func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.subtract_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_tanh" +func.func @op_tanh(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tanh_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_torch_index_select" +func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { + // CHECK: "vhlo.torch_index_select_v1"(%arg0, %arg1) <{ + // CHECK-SAME: batch_dims = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: dim = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x1x5x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<2x1x5x!vhlo.f32_v1> + %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + func.return %0 : tensor<2x1x5xf32> +} + +// CHECK-LABEL: "op_trace" +func.func @op_trace(%arg0: tensor) { + // CHECK: "vhlo.trace_v1"(%arg0) <{ + // CHECK-SAME: tag = #vhlo.string_v1<"foo"> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> () + "stablehlo.trace"(%arg0) { + tag = "foo" + } : (tensor) -> () + func.return +} + +// CHECK-LABEL: "op_transpose" +func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { + // CHECK: "vhlo.transpose_v1"(%arg0) <{ + // CHECK-SAME: permutation = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x16x!vhlo.f32_v1> + %0 = "stablehlo.transpose"(%arg0) { + permutation = array + } : (tensor<16x8xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "op_triangular_solve" +func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.triangular_solve_v1"(%arg0, %arg1) <{ + // CHECK-SAME: left_side = #vhlo.bool_v1, + // CHECK-SAME: lower = #vhlo.bool_v1, + // CHECK-SAME: transpose_a = #vhlo, + // CHECK-SAME: unit_diagonal = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_tuple" +func.func @op_tuple(%arg0: tensor) -> tuple> { + // CHECK: "vhlo.tuple_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tuple_v1> + %0 = "stablehlo.tuple"(%arg0) : (tensor) -> tuple> + func.return %0 : tuple> +} + +// CHECK-LABEL: "op_unary_einsum" +func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { + // CHECK: "vhlo.unary_einsum_v1"(%arg0) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab->a"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x!vhlo.f32_v1> + %0 = "stablehlo.unary_einsum"(%arg0) { + einsum_config = "ab->a" + } : (tensor<8x16xf32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} + +// CHECK-LABEL: "op_uniform_dequantize" +func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.uniform_dequantize_v1"(%arg0) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_uniform_quantize" +func.func @op_uniform_quantize(%arg0: tensor) -> tensor> { + // CHECK: "vhlo.uniform_quantize_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_while" +func.func @op_while(%arg0: tensor) -> tensor { + // CHECK: "vhlo.while_v1"(%arg0) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1) + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.while"(%arg0) ({ + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0: tensor +} + +// CHECK-LABEL: "op_xor" +func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.xor_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ============ TYPES ============ + +// CHECK-LABEL: "type_i1" +func.func @type_i1(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i4" +func.func @type_i4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i8" +func.func @type_i8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i16" +func.func @type_i16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i32" +func.func @type_i32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i64" +func.func @type_i64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui4" +func.func @type_ui4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui8" +func.func @type_ui8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui16" +func.func @type_ui16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui32" +func.func @type_ui32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui64" +func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FN" +func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2" +func.func @type_f8E5M2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FNUZ" +func.func @type_f8E4M3FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3B11FNUZ" +func.func @type_f8E4M3B11FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2FNUZ" +func.func @type_f8E5M2FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_bf16" +func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f16" +func.func @type_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f32" +func.func @type_f32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f64" +func.func @type_f64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_complex_f32" +func.func @type_complex_f32(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_complex_f64" +func.func @type_complex_f64(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_dynamism_ranked" +func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_dynamism_unranked" +func.func @type_dynamism_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "vhlo.abs_v1"(%arg0) : (!vhlo.unranked_tensor_v1) -> !vhlo.unranked_tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK-LABEL: "type_quantization" +func.func @type_quantization(%arg0: tensor>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%arg0, %arg1) : (!vhlo.tensor_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_callee" +func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.return_v1"(%arg0) : (!vhlo.token_v1) -> () + return %arg0 : !stablehlo.token +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_caller" +func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.call_v1"(%arg0) <{callee = #vhlo.string_v1<"type_token_callee">} + // CHECK-SAME: (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token + return %0 : !stablehlo.token +} + +// CHECK-LABEL: "type_tuple" +func.func @type_tuple(%arg0: tuple>) -> tuple { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + // CHECK: (!vhlo.tuple_v1>) -> !vhlo.tuple_v1 + } : (tuple>) -> tuple + return %0 : tuple +} diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc new file mode 100644 index 0000000000000000000000000000000000000000..162bc8b661b4429d7705a6043c6f8856c0d5e6f3 GIT binary patch literal 17593 zcmch9e~eVcw(j0lyQc1@)AZYPn$FNEI?c{7X}H6jVJ>iifEjo+54gYq!tlUAr)Q=I z8fUu4o*xKCd}8!OjTkXv)QC@v7&U6th!G=3y=v5mSIohP5u-+p8Z}3AG-8gvZ`Izr zd-n|JdGC)m*j-iYtF>y?s#U92?e2E%e|)zwx@n})Gf?~=|FJ$l_NO@=3+DXEFh60! z6yKRV!oP{dlc}`ZI=j7d!QxAnuUxZ!)0Ru~y*n?f4DY$}>TCaU(=E5%areFVKlsSw zPd)S83opO+=G*VS_rXV>e)08pKm7E|Z%u46o64rMnQRW5&z7e$^ z*RzA{FgwDIvB%hPc9OlyPO&rW9Q&Mo%YJ0PvL-&6Pvz73Og@Lt=Sz8xujAc(2k+-) zKF0U+>-j-`m>=QC_+$JyKgnO^eEOuZrZH|x8ISPYO_4|p|E;H@JrcX6y&cs}kw|-c zdn6u@&|j<_i1&9}+OL3$KktpuPQ1k<@pogJ0zSfPe<`X_Z;VnSQnNE?HIQ}RfTGKw z5p@4&%Xkv2^Urqv|JrZ;E$i1$XSeuc#^dpHOxxD8xvpXXWAT{8zhpArDbq3;b7S$i z+|n9Px~`0Orcx;|f)x5oCP7OllP<*Nle$Mfo=PR-NtsT&X&H|vfwsD-cuL0NsU&2h zS2~4)DyeiDZ*XX$q}p3s@t!6+`X^ffC&3;E(WPFkahXgL2L>hnHOZ7rdH>0JK~?Wx zDir}MIzo$NoGty=H0RsAX-xg>Ze(sZb9?HJbhbj z%#>&r=^kEd$ISE{wFb=o0{<-{IG@p!-pvc+)d`Fds1=lfdhw)_E>E-_h7TlRu%MbW zF_@9Aax=-*Kh$rEG1bJ>B&H@aHHE3EOkpn5nVP}WOr~ZrHHWFWEGpjp>wE93`7E9i zY7yRT_Z792sbx&%nESd~&D=NDI;J)<)y>p4rgkvZ!&E<0yO}C8Rb^_7seR1-Ym4YI ze)Ip-ex?pEbv;uzGIfxtLrfiJ>MrJfq>eCkl&ND(JS6;t0b^*vKRGWQpCo~d7%Vq7(G_ct|(y9|@% zZWHFq)l{ygaW$Q*8C=ceY8F>>xH}b7=4w7yi?~|K-RYPocV}R#+?~nPYOdCCwUN8C zFj=m)ad!^p%T*6|=VHoS?dGb?Rh6qTuJ&=YpQ{7hosW5PcM+z_-KCfD2|V-)nlk1M*S|-pXBN|S0}hS$=%g319#WK5?sB()hVt{ za~Ebm!`&UQ23O~}`h=^`x%!H$Z@K!Os~@>K&(*J7F-tXBYLeyl!w#03VyUT?yBq6n zsp*!QVYy{2y`^SZZWXI&$!rMfM( z%~Cro)nloCOYOE)*>VqJ=`HsVR^M_D!vQR{-%aeBmvfQIg z9kJ9=OC7V+!eEPTRp&%rG$_j7oL<$lZ5dCUC~-eI}t;UK~#9}(^( zt|ke03RjbbJB_O;!i8^4749s!ig4$`TZEb++(mF0;Vy;G2zMFWM!2iFnk!uR#(d#! zg!2fsRJh$-EfXq7p*7gZ!Cp+?(G~3FpbVg(1efsZ0XIm&&aaVfyCz@=%8LoS7(Eot zgQyvl!6@-mCZ5W2F5@Xvtrlt>B33Hh&rQx}H6={bOq(_hA5Zd(%!m|k=FKJ-PdtHR zRK$2hMoinbrHn{h@~kPdyh~=yJMFBQYRj4nvZ4!Z(Uk!pvRPh>5>jRDF8s+;OGdKXPKE}{ zW~J6c+D;4DMZ(6E9gK96Wz9G6dM<08LGcM*U*YvVUcchigc(i7YbsvT@tTR(9K7b^ zmE&3AKyi_UFp<^j2n^bSiP$TyQa~IGAspywcg&kNg~@hqc1kX1H?a#&OLU>X%#bx? zhAf6U-nk9CIVsGsW`bG;U3oM`GotPyA@qRI!bxU@Z6{4|H8-P|ZFlX3N_KYpb~3Ro zl9&`9r{3?e#R4ekU9lnGwp{C=A&UyT9 zM_h9)zs?b}|H5ZGvUA=I{3b_U(seVx#Sxcu@mn1p5oQEKM|gxr!$1|2-{Z)(oh9Dp z$jrSa&uGn{k?v#is3QdO3MbNx{>>8aUKI_8TAG3k0vsG_37cjFov8rX*5cBR$i{ix z5$)|XD0IEg!No;W?{|2dC!9#aL<)(D;ESQX;4WY8NWRe}Dv&*&sNLH*7bu zAB;%)iFRt+i9iT-6PKI(L9*0CdZLdw_VBQ2ibo#@Z%Yd)e7QIGCouO>n!7HZa^UTx zr%n2!*K$WZ{S1E=qhV-=V{uHr2PJQJ_zG@z@aLo-ZVvpfu3!0it=9@itX#=gVnVA{ z@l}pkvxcu}c7*G4*AZ*i^0hF^`t^K048CCl-vDAu3vY46rcHcPt0OjV=9^*rEnD~& zSas`GzIA~kF1?gry4VrhxAW}_9g)xTJSJW!@Pc$iG)ilgPIJiZ?d83E0og6#@V$Jm zBUWg$?cvX7h0u;p`!0%?nazqQcD@j_{mrEHeUC8zY?{(K)x*twP2$fcm)2}9;X0;a z=S&z|z|-b^{h^z@iOUtO3IkhGfPYutWovY`~fuuv7!qsHbYc(hYo>2Fz{1 zS{ks{1}xiv&2GTjq=n8AZW@uCi6Dl5uo#EMI>Jb27?~GF7KD+7VFb}2kha0(OQaP= zpx}lH*_ZhJ)*8{A*@d7|8+CaYSrJB7`iLN&RR}*do;5!5J|Szd4{FGIf9AL)#d78b z-Vs{8O%2%Q25d_MwzUDf)Wcx0M9$o<_c=^PtN3R#AI1ueSZ^3BHex%a)z0k*!qKRa z%9$l@O1LepgoM<~8d1zVj8=SIl``Tv^Ku>w&2hK^8*RY$G+=uhuq&mNM9Abn8>9oR z*j2ubGseoCc{PvKZG4TiHhUiAcAM9NjxCw>x1nqzZhsdBj&+?c>UCp9&io6%x-RMl zAK7ax&zU#j(5Nls&9&jXa^@{MZuxF{YYq7(XWkY@Zucj6p>S}Aj}=SCB{}m>es`cG zrh2z8QE!O`$G?Xb0s)4xNz7Q1Gwf97w;D>zps%z#lV?oY*gx?Z|ANAMDtgY4Kz9Ls$W6aB$Pw-nqV?E_z zJYjU^%%|yuB^!I*`;5Zu76+976}udm$;l z)4ngQ7H`iuG@-hme{d14q40-3Pr~>bJl?wLK>bl|2ViS{9NwCrhPUQtG|olD2V3)t zFvFL)#s&<{%?3?98hAv&@U=g--pk*F_wslCP_}KffHQn<{;jUXmw;*k0+akfT9+Vk zi*z8F>oC5^nSZCeEkes_V@w%|#?b30X)WMUua_`BLoeQ3>jU@uv-z{vLqd-6sV~Lf zG*a5s;}_qYj`1<+P+x+7_)`+bM>+GiusM-kBeLm*{xD}Yu`o6X*g|p=JSCbvRSOh! z+|h#58~N#E^zoPdgT^s&fe4WI8Nv@)8TLX zcH`}wxs<_g>!D>CP~R+XEi-ehuITCjaf~-P6Ba z#5V@4gM1}8NA@uuK1lWh*K6RNB#r6E0f71t@{f}1S;JX!BjC>PVFJ@MoF)f@4#9D9 zh&7%khtWY>qs4eZuT`RMjk|!&qn(#ETUuI-=P|KFnuDlmV@+Z)xJ79h598<3D0rU!^=+}g-9 zBNz$i-Yvuw(qbks?I~V@H46t@2%Qt4+fC!wsQL5lomy zn_f8@l{_6Z=pn;;&|q}@ZPXp$?xt~V&g`aZTFtMv1p+?~1U^_7c%kHj1nvNnpLij; z?D5^_qB_u}wm)Ebo$}G$oImu3DNZ?1&C5uuoIg@m4RgI_yU}jZDLlH4g)$gUY)L7qRkMw$RIc z7r^kPcHFh~2w;EP?P+wqqjhAo5jj>zh8vNGf#fM#`gZ_)4D-)i0-l`6_5__ezS%G2 zcyxC-4mR)d@TI;xoZ$cOmxq%RWj#eI{N6?6VtTy? zzR?(dP6cSzF)9;sw9`PfCpwNX2*GwJ$eambmyN^DhOp8&>>M!P19pw&DO^UMLF zqpfYMMLL4hFal>o%(A@)582~f4Wxw+!}%d>`~`7QfZcqvK1x=GEVwkl zZoO4wt3uea0K5Hmjjah`xd4L)Q@9O#@ah0tu|i{Rh;Lnh-Fc_R)`qZ+0d~(l8e1R2 zx&!RK`!u#8gl!A3RjV}C62f)_*aHt}Y*Ps939yGA(%9w@)*oPxKB}=TA#8VmJ@JIb zwuZ2BfIa=R#x4zE)c||;S&iXoXYDdO7GTdmud#dx+ZSLjzNoQ62-_cE2vH4b0-b3A zHa@`7(H3A!mT124-Sfr(d*v04^@hef2uyE~tl?nULzkID0rEGE?5(@{;XvLE6Lq`` zn7#odYR}t_=&Kv`#ZyDdv1NKh-ygy141r%f5|J*1diD_ zR<|S3nZ^X77&De>jwPEDv5sggfoWz_Dw9mk#!TAOymp#%Gp3b_XWFvyn4?s4TNqkhxVJDeQuL_xDsiD-wb(b>Tg>k% zmHR4th%i{VqBJ-<7{WsQq?*rxt|Xj>s#LX#!$q&%TNw8A@K3 zQXR}+K3XV`l=d~y`ihOzLUnJxSSk;DwjCKA@lj|#9 zk@uI9sD<(_ZSm0_>gieia^Ly`C69HuRMsMfOFQ+rxym>d_mi zh5mgsLKSE=aFXF&g(0ml_Q6O=uUARCNVPbGITi9jYp_7}(0=Uu;83;LThe-Iq29!* zMU;V~3U0c=i{gjV-IL^ppIZxEc?qsag;W7pkqx_0tT>)Xkd6K_CT@FS3>aHSA-s#JWv?a zi`QE$d7-1n8?GK@z+7+CLk^;@4niRndkQ1ixMd0o*rHfxFU*F>6#1&fvbR}7^qp1D z9}Ae_T+*x1Pf^>CN1^yt*94tl5117v%^U>zs$uk3D#o9Tk-gYbo-HYK6^jO;dGD|( z4Pztsm9TfTvllP_9kBd+oI8IP9QhrwI1w+d(7;4aJs)gEOltz8#;^?v7hxO078|Hk z5jif*N^!E-cVTQ4Q7?ps4(P&ADNjL+f+BhiArSSKYGN+Zv4SXEDi;R)!7kDZOFDx6 z>ZJ_upZ~kw(59{loCxba*OSm!>h(?3kI?R!t$#eF1kCGR>r|acD%ENA7>r%eU)q`P zEA|%j4##-~@K%sAUrBwDOrtMz50+EY+XE@(O$L6r% zTTO3wFYjm|c?Vtt$;(x0q+UOjsS_q-JQsp$!6XFgj5O4)B~o=--F0iUQ2rIB)$4wC zR~ymK-6nLPXgYy7&L)&a8^NhF4)apWngl=TthEvM>f$M94Iw%g_VeBGWG_XoQOEb1 zm~hv-B0)APLwV%zy6EfEMG)R~qn2(4Eq}#{t)6u6!EVS8RV#ytX>^*BGkG#=$*bna zq==DD5~_?lWyXt!D~_iqHCel#A_peL3sAEQ{*q2 z*qG;NQfQMvQ!ghXAiPt&rlwwOe^3`6(_16d_B}sDp$sf64GCJf>8D3wl1}r&gm6Ac zbT4DnwNQ>(ODe|!)T9H03d&8x8A^3$s0CH%n4o^NnTM*SG30X8uFq;>&a3h9=`j&Y z{rVkwv@Z4aQrCQg)Y5JuA{F^5M0;XD0RMBVlUr%!FwAF$-1j*u}fjl<*9Mt~RBA^UHq^==1$b=Ul%s-|sRS z^c^kvokvR@e(%vz(0#NN^zZ-o75}!jGq=tYgVwyYzQ_RCysd_z=dzBD+`146-Rte1 z{$QwjB^a_l7%{k^_6JCf5cKfE|L|ZQspn<5WM~`Ih@<5&**Bnmkl}R<-2;6y>Sqf^ zp{H$W0q_6X(~C;4@h=?T5B--D_%RheOO42+e99JWv#Ij5E!(P1RC@-b^;CHll{{6R z16J19^Qhr>%Amb~&q9;(MO$pw>@VA5v97#g%l;-^d(D4pAK7BT?Ns^LmJ7O}!l&{x)Hb5_h5Qn=b*Od7 zuWfz_mEYK6wI+OL^Cc14CI4oN90*^z@oVQrO}hRY)#o?GVv;dAiP=$El#`ilJCXK86tkMbY^kd|VsZ|%T?aoX$!414ttcM4r1jBOk~)_L zx=hZ8L^bac)e^R&(rYxNmg1Opkk}Mo{FNpk?iHP$Ed?WbVoHN4>5U| z>JNatQ{IK~9M5*Eo>iZ7v~Lb$sJM&0FHHI?=G|9IH+gtU6J!>gW~E#OC4m z>8rV1N2B4ZFx^yc<8lX;JzVx<*tIdao69n>b;v4aFcmY1$}u8DlW7&37(tfkP`T8wV#?-%H5vyG8J92O2uY5Mg`WuEB1TE0V+o2^(gQ`T|9z(Q9~mB@p4Z5uRIHT8Xnq-}`miM*BhniAq*omGiW5{su2i3GlC$v4P+ zla^<^u4k$1IZJ**ICkRami&rH-&*o}kiNz2m&+fCa^8}^Qpxb; zBVmiLL`+T+a) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "op_all_reduce_with_promotable_types" +func.func @op_all_reduce_with_promotable_types(%operand: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v1"(%arg0) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + + func.return %result : tensor +} + // CHECK-LABEL: "op_all_to_all" func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { // CHECK: "vhlo.all_to_all_v1"(%arg0) <{ @@ -1646,6 +1665,23 @@ func.func @op_reduce_precision(%arg0: tensor) -> tensor { func.return %0 : tensor } +// CHECK_lABEL: "op_reduce_with_promotable_types" +func.func @op_reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + // CHECK: "vhlo.reduce_v1"(%arg0, %arg1) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f64_v1> + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> + + func.return %0: tensor<4xf64> +} + // CHECK-LABEL: "op_reduce_scatter" func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { // CHECK: "vhlo.reduce_scatter_v1"(%arg0) <{ @@ -1671,6 +1707,24 @@ func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { func.return %0 : tensor<16xf32> } +// CHECK_lABEL: "op_reduce_scatter_with_promotable_types" +func.func @op_reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + // CHECK: "vhlo.reduce_scatter_v1"(%arg0) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f64_v1> + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + + // CHECK-LABEL: "op_reduce_window" func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x9x16x7xf32> { // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1) <{ @@ -1698,6 +1752,29 @@ func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> func.return %0 : tensor<2x9x16x7xf32> } +// CHECK_lABEL: "op_reduce_window_with_promotable_types" +func.func @op_reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + // CHECK: "vhlo.reduce_window_v1"(%arg0, %arg1, %arg2, %arg3) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]], %[[VAL2:.*]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1, !vhlo.tensor_v1) -> (!vhlo.tensor_v1<2x2x!vhlo.f64_v1>, !vhlo.tensor_v1<2x2x!vhlo.f32_v1>) + %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = stablehlo.add %a0, %b0 : tensor + %3 = stablehlo.add %a1, %b1 : tensor + "stablehlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + // CHECK-LABEL: "op_remainder" func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "vhlo.remainder_v1"(%arg0, %arg1) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 @@ -1822,6 +1899,32 @@ func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, % func.return %0 : tensor<200x100x300xf32> } +// CHECK_lABEL: "op_scatter_with_promotable_types" +func.func @op_scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { + // CHECK: "vhlo.scatter_v1"(%arg0, %arg1, %arg2) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f64_v1> + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> +} + // CHECK-LABEL: "op_select_and_scatter" func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) <{ @@ -1853,6 +1956,28 @@ func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<1 func.return %0 : tensor<10x24x24x64xf32> } +// CHECK-LABEL: "op_select_and_scatter_with_promotable_types" +func.func @op_select_and_scatter_with_promotable_types(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf64> { + // CHECK: "vhlo.select_and_scatter_v1"(%arg0, %arg1, %arg2) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f64_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf64> + func.return %0 : tensor<10x24x24x64xf64> +} + // CHECK-LABEL: "op_select" func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { // CHECK: "vhlo.select_v1"(%arg0, %arg1, %arg2) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 diff --git a/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_16_0.mlir b/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_16_0.mlir new file mode 100644 index 00000000000..3befa48e1f4 --- /dev/null +++ b/stablehlo/tests/vhlo_to_version_downgrade_invalid.0_16_0.mlir @@ -0,0 +1,123 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=0.16.0' --verify-diagnostics --split-input-file %s + +func.func @reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + + // expected-error @+1 {{failed to legalize operation 'vhlo.reduce_v1' that was explicitly marked illegal}} + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> + + func.return %0: tensor<4xf64> +} + +// ----- + +func.func @all_reduce_with_promotable_types(%operand: tensor) -> tensor { + + // expected-error @+1 {{failed to legalize operation 'vhlo.all_reduce_v1' that was explicitly marked illegal}} + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor) -> tensor + + func.return %result : tensor +} + +// ----- + +func.func @reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + + // expected-error @+1 {{failed to legalize operation 'vhlo.reduce_scatter_v1' that was explicitly marked illegal}} + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + +// ----- + +func.func @reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + + // expected-error @+1 {{failed to legalize operation 'vhlo.reduce_window_v1' that was explicitly marked illegal}} + %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = stablehlo.add %a0, %b0 : tensor + %3 = stablehlo.add %a1, %b1 : tensor + "stablehlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = dense<[5, 1]> : tensor<2xi64>, + window_strides = dense<[3, 1]> : tensor<2xi64> } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + +// ----- + +func.func @scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { + + // expected-error @+1 {{failed to legalize operation 'vhlo.scatter_v1' that was explicitly marked illegal}} + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> +} + +// ----- + +func.func @select_and_scatter_with_promotable_types( + %arg0: tensor<10x24x24x64xf32>, + %arg1: tensor<10x12x12x64xf32>) -> () { + %0 = stablehlo.constant dense<0.000000e+00> : tensor + + // expected-error @+1 {{failed to legalize operation 'vhlo.select_and_scatter_v1' that was explicitly marked illegal}} + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "stablehlo.compare"(%arg3, %arg4) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf64> + func.return +} diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp index 6534ebe7899..4e4b106651c 100644 --- a/stablehlo/transforms/VhloToVersion.cpp +++ b/stablehlo/transforms/VhloToVersion.cpp @@ -181,7 +181,14 @@ bool isLegalOperation(Operation* op, const Version& targetVersion) { auto opInterface = dyn_cast(op); if (!opInterface) return false; if (!isLegalVersion(opInterface, targetVersion)) return false; - LLVM_DEBUG(llvm::dbgs() << "Legal version for target. " << op << '\n'); + LLVM_DEBUG(llvm::dbgs() << "Legal op version for target. " << op << '\n'); + + // Validate op constraints + auto constraintInterface = dyn_cast(op); + if (constraintInterface && + failed(constraintInterface.validateConstraint(op, targetVersion))) + return false; + LLVM_DEBUG(llvm::dbgs() << "Legal constraints for target. " << op << '\n'); // Validate attributes auto isLegalAttrFn = [&](const NamedAttribute& attr) { From 6fbb8ef803b751c1e9458ac6a0457a1051c4970f Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Sat, 16 Dec 2023 02:54:22 +0000 Subject: [PATCH 3/4] address feedback:1 --- stablehlo/dialect/TypeInference.cpp | 15 ++++++++------ stablehlo/dialect/VhloOps.cpp | 32 +++++++++++++++-------------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 57fb26d1853..82fba18aaa8 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -103,6 +103,11 @@ unsigned potentiallyComplexBitWidth(Type type) { : type.getIntOrFloatBitWidth(); } +template +bool matchesType(Type a, Type b) { + return a.isa() && b.isa(); +} + // Returns true if the element-type of type1 can be promoted to that of type2. // An element-type 'x' is promotatble to element-type 'y' is they have the same // base type and bitwidth(x) <= bitwidth(y). When 'x' and 'y' are quantized @@ -122,12 +127,10 @@ bool isPromotableElementType(Type type1, Type type2, tensorEl2.isa()) return true; - bool isSameType = - (tensorEl1.isa() and tensorEl2.isa()) || - (tensorEl1.isa() and tensorEl2.isa()) || - (tensorEl1.isa() and tensorEl2.isa()) || - (tensorEl1.isa() and - tensorEl2.isa()); + bool isSameType = matchesType(tensorEl1, tensorEl2) || + matchesType(tensorEl1, tensorEl2) || + matchesType(tensorEl1, tensorEl2) || + matchesType(tensorEl1, tensorEl2); if (!isSameType) return false; diff --git a/stablehlo/dialect/VhloOps.cpp b/stablehlo/dialect/VhloOps.cpp index f28432642c9..2c2ea9e0220 100644 --- a/stablehlo/dialect/VhloOps.cpp +++ b/stablehlo/dialect/VhloOps.cpp @@ -305,23 +305,25 @@ void VhloDialect::printAttribute(Attribute attr, DialectAsmPrinter& os) const { // to represent this sort of constraint in tablegen. namespace { +Type getVhloElementType(Type tensorType) { + if (auto ranked = tensorType.dyn_cast()) { + return ranked.getElementType(); + } + return tensorType.cast().getElementType(); +} + bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes, TypeRange resultTypes) { - SmallVector inputShapedTypes{ - llvm::map_range(operandTypes, [](Type t) { - return convertTypeToBuiltinForPrint(t).cast(); - })}; - SmallVector resultShapedTypes{ - llvm::map_range(resultTypes, [](Type t) { - return convertTypeToBuiltinForPrint(t).cast(); - })}; - - int64_t numInputs = inputShapedTypes.size(); - for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) { - if (inputShapedTypes[inputIdx].getElementType() != - resultShapedTypes[inputIdx].getElementType()) - return true; - } + SmallVector inputElementTypes{llvm::map_range( + operandTypes, [](Type t) { return getVhloElementType(t); })}; + SmallVector resultElementTypes{llvm::map_range( + resultTypes, [](Type t) { return getVhloElementType(t); })}; + + if (llvm::any_of( + llvm::zip(inputElementTypes, resultElementTypes), + [&](auto pair) { return std::get<0>(pair) != std::get<1>(pair); })) + return true; + return false; } } // namespace From e497deb430a9b967daea1b675d7195e891350ce6 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Wed, 20 Dec 2023 20:49:01 +0000 Subject: [PATCH 4/4] addressed feedback:II --- stablehlo/dialect/StablehloOps.td | 2 - stablehlo/dialect/TypeInference.cpp | 48 +++++++-------- stablehlo/dialect/VhloOps.cpp | 58 +++++-------------- .../stablehlo_legalize_to_vhlo.0_17_0.mlir | 10 ++-- stablehlo/transforms/VhloToVersion.cpp | 5 +- 5 files changed, 47 insertions(+), 76 deletions(-) diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 1d37c7e3628..e5dbfed5936 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -1474,8 +1474,6 @@ def StableHLO_ReduceOp: StableHLO_ShapedInterfaceOp<"reduce", [ Variadic:$init_values, /*reduce_i2*/ I64ElementsAttr:$dimensions /*reduce_i3*/ ); - // TODO(hinsu): Verify that the attached body arguments and results are - // compatible with reduce op's operands. let regions = (region SizedRegion<1>:$body /*reduce_i4*/); let results = (outs Variadic); diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 82fba18aaa8..67fb6d821df 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -97,15 +97,23 @@ bool tensorsHaveSameElType(Type type1, Type type2, return tensorsHaveSameElType({type1, type2}, ignoreFpPrecision); } -unsigned potentiallyComplexBitWidth(Type type) { - auto complexTy = type.dyn_cast(); - return complexTy ? 2 * complexTy.getElementType().getIntOrFloatBitWidth() - : type.getIntOrFloatBitWidth(); +unsigned getBitWidth(Type type) { + if (auto complexTy = type.dyn_cast()) + return 2 * getBitWidth(complexTy.getElementType()); + if (auto quantTy = type.dyn_cast()) + return getBitWidth(quantTy.getStorageType()); + return type.getIntOrFloatBitWidth(); } template bool matchesType(Type a, Type b) { - return a.isa() && b.isa(); + bool matches = a.isa() && b.isa(); + // Check that expressed type matches for quantized types + if constexpr (std::is_same::value) { + return matches && (a.cast().getExpressedType() == + b.cast().getExpressedType()); + } + return matches; } // Returns true if the element-type of type1 can be promoted to that of type2. @@ -123,10 +131,6 @@ bool isPromotableElementType(Type type1, Type type2, Type tensorEl1 = tensorTy1.getElementType(); Type tensorEl2 = tensorTy2.getElementType(); - if (ignoreFpPrecision && tensorEl1.isa() && - tensorEl2.isa()) - return true; - bool isSameType = matchesType(tensorEl1, tensorEl2) || matchesType(tensorEl1, tensorEl2) || matchesType(tensorEl1, tensorEl2) || @@ -134,15 +138,9 @@ bool isPromotableElementType(Type type1, Type type2, if (!isSameType) return false; - if (!tensorEl1.isa()) - return potentiallyComplexBitWidth(tensorEl1) <= - potentiallyComplexBitWidth(tensorEl2); + if (ignoreFpPrecision && tensorEl1.isa()) return true; - auto quantType1 = tensorEl1.cast(); - auto quantType2 = tensorEl2.cast(); - return quantType1.getExpressedType() == quantType2.getExpressedType() && - potentiallyComplexBitWidth(quantType1.getStorageType()) <= - potentiallyComplexBitWidth(quantType2.getStorageType()); + return getBitWidth(tensorEl1) <= getBitWidth(tensorEl2); } // Return true if type1 and type2 are shape-compatible and have same element @@ -577,7 +575,7 @@ LogicalResult verifyReduceOpInputsAndInferShape( SmallVector getAccumulatorTypes(Block& block) { SmallVector accumulatorSubShapes; for (Value retOperand : block.getTerminator()->getOperands()) { - auto shapedTy = retOperand.getType().dyn_cast(); + auto shapedTy = retOperand.getType().cast(); accumulatorSubShapes.push_back(shapedTy); } return accumulatorSubShapes; @@ -678,10 +676,8 @@ LogicalResult verifyReducerShape(std::optional loc, Block& block, loc, "The element-type of reduction-region's argument at index ", numInputs + inputIdx, " is expected to be promotable from ", inputTypes[inputIdx].getElementType(), ", but got ", - block.getArgument(numInputs + inputIdx) - .getType() - .cast() - .getElementType()); + getElementTypeOrSelf( + block.getArgument(numInputs + inputIdx).getType())); Type blockArgType = block.getArgument(numInputs + inputIdx).getType(); auto blockArgTensorTy = blockArgType.cast(); @@ -2741,11 +2737,11 @@ LogicalResult inferRngOp( } LogicalResult inferScatterOp(std::optional, ValueRange inputs, - Region& update_computation, + Region& updateComputation, SmallVectorImpl& inferredReturnTypes) { // scatter_c16, scatter_c17 SmallVector accumulatorTypes = - getAccumulatorTypes(update_computation.front()); + getAccumulatorTypes(updateComputation.front()); for (uint64_t inputIdx = 0; inputIdx < inputs.size(); ++inputIdx) { auto inputShapedTy = inputs[inputIdx].getType().cast(); inferredReturnTypes.push_back(getSameShapeTensorType( @@ -3232,8 +3228,8 @@ LogicalResult verifyBitcastConvertOp(std::optional location, location, "cannot convert between real and complex types, but got: ", operandShapedType, " and ", targetShapedType); - auto targetEltBitWidth = potentiallyComplexBitWidth(targetElt); - auto operandEltBitWidth = potentiallyComplexBitWidth(operandElt); + auto targetEltBitWidth = getBitWidth(targetElt); + auto operandEltBitWidth = getBitWidth(operandElt); auto operandType = operandShapedType.dyn_cast(); auto targetType = targetShapedType.dyn_cast(); diff --git a/stablehlo/dialect/VhloOps.cpp b/stablehlo/dialect/VhloOps.cpp index 2c2ea9e0220..e8c651e58a6 100644 --- a/stablehlo/dialect/VhloOps.cpp +++ b/stablehlo/dialect/VhloOps.cpp @@ -326,72 +326,46 @@ bool checkIfOperandAndResultElementTypesMatch(TypeRange operandTypes, return false; } -} // namespace -LogicalResult AllReduceOpV1::validateConstraint(mlir::Operation* op, - Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), - getResult().getType()) && +// Allow mismatched operand and result types in reduce ops in v0.17.0 +LogicalResult verifyConstraint_0_17_0(mlir::Operation* op, + Version targetVersion) { + if (checkIfOperandAndResultElementTypesMatch(op->getOperandTypes(), + op->getResultTypes()) && targetVersion < Version(0, 17, 0)) return failure(); - return success(); } +} // namespace + +LogicalResult AllReduceOpV1::validateConstraint(mlir::Operation* op, + Version targetVersion) { + return verifyConstraint_0_17_0(op, targetVersion); +} LogicalResult ReduceOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), - getResultTypes()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } LogicalResult ReduceScatterOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), - getResult().getType()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } LogicalResult ReduceWindowOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), - getResultTypes()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } LogicalResult ScatterOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getInputs().getTypes(), - getResultTypes()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } LogicalResult SelectAndScatterOpV1::validateConstraint(mlir::Operation* op, Version targetVersion) { - // Allow mismatched operand and result types in v0.17.0 - if (checkIfOperandAndResultElementTypesMatch(getOperand().getType(), - getResult().getType()) && - targetVersion < Version(0, 17, 0)) - return failure(); - - return success(); + return verifyConstraint_0_17_0(op, targetVersion); } } // namespace vhlo diff --git a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir index 70efa90c0b3..acd729402ff 100644 --- a/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir +++ b/stablehlo/tests/stablehlo_legalize_to_vhlo.0_17_0.mlir @@ -509,7 +509,7 @@ func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tenso // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { - broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> + broadcast_dimensions = array } : (tensor, tensor<2xindex>) -> tensor func.return %0 : tensor } @@ -933,7 +933,7 @@ func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> %0 = "stablehlo.broadcast_in_dim"(%arg0) { - broadcast_dimensions = dense<1> : tensor<1xi64> + broadcast_dimensions = array } : (tensor<16xf32>) -> tensor<16x16xf32> func.return %0 : tensor<16x16xf32> } @@ -1214,9 +1214,9 @@ func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xi // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { - broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, - known_expanding_dimensions = dense<0> : tensor<1xi64>, - known_nonexpanding_dimensions = dense<1> : tensor<1xi64> + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array } : (tensor, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp index 4e4b106651c..a8cf57d1aa6 100644 --- a/stablehlo/transforms/VhloToVersion.cpp +++ b/stablehlo/transforms/VhloToVersion.cpp @@ -186,8 +186,11 @@ bool isLegalOperation(Operation* op, const Version& targetVersion) { // Validate op constraints auto constraintInterface = dyn_cast(op); if (constraintInterface && - failed(constraintInterface.validateConstraint(op, targetVersion))) + failed(constraintInterface.validateConstraint(op, targetVersion))) { + LLVM_DEBUG(llvm::dbgs() + << "Op failed to satisfy versioned constraints. " << op << '\n'); return false; + } LLVM_DEBUG(llvm::dbgs() << "Legal constraints for target. " << op << '\n'); // Validate attributes